diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/express/_imshow.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/plotly/express/_imshow.py | 605 |
1 files changed, 605 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/express/_imshow.py b/venv/lib/python3.8/site-packages/plotly/express/_imshow.py new file mode 100644 index 0000000..ce6ddb8 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/_imshow.py @@ -0,0 +1,605 @@ +import plotly.graph_objs as go +from _plotly_utils.basevalidators import ColorscaleValidator +from ._core import apply_default_cascade, init_figure, configure_animation_controls +from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types +import narwhals.stable.v1 as nw +import numpy as np +import itertools +from plotly.utils import image_array_to_data_uri + +try: + import xarray + + xarray_imported = True +except ImportError: + xarray_imported = False + +_float_types = [] + + +def _vectorize_zvalue(z, mode="max"): + alpha = 255 if mode == "max" else 0 + if z is None: + return z + elif np.isscalar(z): + return [z] * 3 + [alpha] + elif len(z) == 1: + return list(z) * 3 + [alpha] + elif len(z) == 3: + return list(z) + [alpha] + elif len(z) == 4: + return z + else: + raise ValueError( + "zmax can be a scalar, or an iterable of length 1, 3 or 4. " + "A value of %s was passed for zmax." % str(z) + ) + + +def _infer_zmax_from_type(img): + dt = img.dtype.type + rtol = 1.05 + if dt in _integer_types: + return _integer_ranges[dt][1] + else: + im_max = img[np.isfinite(img)].max() + if im_max <= 1 * rtol: + return 1 + elif im_max <= 255 * rtol: + return 255 + elif im_max <= 65535 * rtol: + return 65535 + else: + return 2**32 + + +def imshow( + img, + zmin=None, + zmax=None, + origin=None, + labels={}, + x=None, + y=None, + animation_frame=None, + facet_col=None, + facet_col_wrap=None, + facet_col_spacing=None, + facet_row_spacing=None, + color_continuous_scale=None, + color_continuous_midpoint=None, + range_color=None, + title=None, + template=None, + width=None, + height=None, + aspect=None, + contrast_rescaling=None, + binary_string=None, + binary_backend="auto", + binary_compression_level=4, + binary_format="png", + text_auto=False, +) -> go.Figure: + """ + Display an image, i.e. data on a 2D regular raster. + + Parameters + ---------- + + img: array-like image, or xarray + The image data. Supported array shapes are + + - (M, N): an image with scalar data. The data is visualized + using a colormap. + - (M, N, 3): an image with RGB values. + - (M, N, 4): an image with RGBA values, i.e. including transparency. + + zmin, zmax : scalar or iterable, optional + zmin and zmax define the scalar range that the colormap covers. By default, + zmin and zmax correspond to the min and max values of the datatype for integer + datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For + a multichannel image of floats, the max of the image is computed and zmax is the + smallest power of 256 (1, 255, 65535) greater than this max value, + with a 5% tolerance. For a single-channel image, the max of the image is used. + Overridden by range_color. + + origin : str, 'upper' or 'lower' (default 'upper') + position of the [0, 0] pixel of the image array, in the upper left or lower left + corner. The convention 'upper' is typically used for matrices and images. + + labels : dict with str keys and str values (default `{}`) + Sets names used in the figure for axis titles (keys ``x`` and ``y``), + colorbar title and hoverlabel (key ``color``). The values should correspond + to the desired label to be displayed. If ``img`` is an xarray, dimension + names are used for axis titles, and long name for the colorbar title + (unless overridden in ``labels``). Possible keys are: x, y, and color. + + x, y: list-like, optional + x and y are used to label the axes of single-channel heatmap visualizations and + their lengths must match the lengths of the second and first dimensions of the + img argument. They are auto-populated if the input is an xarray. + + animation_frame: int or str, optional (default None) + axis number along which the image array is sliced to create an animation plot. + If `img` is an xarray, `animation_frame` can be the name of one the dimensions. + + facet_col: int or str, optional (default None) + axis number along which the image array is sliced to create a facetted plot. + If `img` is an xarray, `facet_col` can be the name of one the dimensions. + + facet_col_wrap: int + Maximum number of facet columns. Wraps the column variable at this width, + so that the column facets span multiple rows. + Ignored if `facet_col` is None. + + facet_col_spacing: float between 0 and 1 + Spacing between facet columns, in paper units. Default is 0.02. + + facet_row_spacing: float between 0 and 1 + Spacing between facet rows created when ``facet_col_wrap`` is used, in + paper units. Default is 0.0.7. + + color_continuous_scale : str or list of str + colormap used to map scalar data to colors (for a 2D image). This parameter is + not used for RGB or RGBA images. If a string is provided, it should be the name + of a known color scale, and if a list is provided, it should be a list of CSS- + compatible colors. + + color_continuous_midpoint : number + If set, computes the bounds of the continuous color scale to have the desired + midpoint. Overridden by range_color or zmin and zmax. + + range_color : list of two numbers + If provided, overrides auto-scaling on the continuous color scale, including + overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only + for single-channel images. + + title : str + The figure title. + + template : str or dict or plotly.graph_objects.layout.Template instance + The figure template name or definition. + + width : number + The figure width in pixels. + + height: number + The figure height in pixels. + + aspect: 'equal', 'auto', or None + - 'equal': Ensures an aspect ratio of 1 or pixels (square pixels) + - 'auto': The axes is kept fixed and the aspect ratio of pixels is + adjusted so that the data fit in the axes. In general, this will + result in non-square pixels. + - if None, 'equal' is used for numpy arrays and 'auto' for xarrays + (which have typically heterogeneous coordinates) + + contrast_rescaling: 'minmax', 'infer', or None + how to determine data values corresponding to the bounds of the color + range, when zmin or zmax are not passed. If `minmax`, the min and max + values of the image are used. If `infer`, a heuristic based on the image + data type is used. + + binary_string: bool, default None + if True, the image data are first rescaled and encoded as uint8 and + then passed to plotly.js as a b64 PNG string. If False, data are passed + unchanged as a numerical array. Setting to True may lead to performance + gains, at the cost of a loss of precision depending on the original data + type. If None, use_binary_string is set to True for multichannel (eg) RGB + arrays, and to False for single-channel (2D) arrays. 2D arrays are + represented as grayscale and with no colorbar if use_binary_string is + True. + + binary_backend: str, 'auto' (default), 'pil' or 'pypng' + Third-party package for the transformation of numpy arrays to + png b64 strings. If 'auto', Pillow is used if installed, otherwise + pypng. + + binary_compression_level: int, between 0 and 9 (default 4) + png compression level to be passed to the backend when transforming an + array to a png b64 string. Increasing `binary_compression` decreases the + size of the png string, but the compression step takes more time. For most + images it is not worth using levels greater than 5, but it's possible to + test `len(fig.data[0].source)` and to time the execution of `imshow` to + tune the level of compression. 0 means no compression (not recommended). + + binary_format: str, 'png' (default) or 'jpg' + compression format used to generate b64 string. 'png' is recommended + since it uses lossless compression, but 'jpg' (lossy) compression can + result if smaller binary strings for natural images. + + text_auto: bool or str (default `False`) + If `True` or a string, single-channel `img` values will be displayed as text. + A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive. + + Returns + ------- + fig : graph_objects.Figure containing the displayed image + + See also + -------- + + plotly.graph_objects.Image : image trace + plotly.graph_objects.Heatmap : heatmap trace + + Notes + ----- + + In order to update and customize the returned figure, use + `go.Figure.update_traces` or `go.Figure.update_layout`. + + If an xarray is passed, dimensions names and coordinates are used for + axes labels and ticks. + """ + args = locals() + apply_default_cascade(args) + labels = labels.copy() + nslices_facet = 1 + if facet_col is not None: + if isinstance(facet_col, str): + facet_col = img.dims.index(facet_col) + nslices_facet = img.shape[facet_col] + facet_slices = range(nslices_facet) + ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet + nrows = ( + nslices_facet // ncols + 1 + if nslices_facet % ncols + else nslices_facet // ncols + ) + else: + nrows = 1 + ncols = 1 + if animation_frame is not None: + if isinstance(animation_frame, str): + animation_frame = img.dims.index(animation_frame) + nslices_animation = img.shape[animation_frame] + animation_slices = range(nslices_animation) + slice_dimensions = (facet_col is not None) + ( + animation_frame is not None + ) # 0, 1, or 2 + facet_label = None + animation_label = None + img_is_xarray = False + # ----- Define x and y, set labels if img is an xarray ------------------- + if xarray_imported and isinstance(img, xarray.DataArray): + dims = list(img.dims) + img_is_xarray = True + pop_indexes = [] + if facet_col is not None: + facet_slices = img.coords[img.dims[facet_col]].values + pop_indexes.append(facet_col) + facet_label = img.dims[facet_col] + if animation_frame is not None: + animation_slices = img.coords[img.dims[animation_frame]].values + pop_indexes.append(animation_frame) + animation_label = img.dims[animation_frame] + # Remove indices in sorted order. + for index in sorted(pop_indexes, reverse=True): + _ = dims.pop(index) + y_label, x_label = dims[0], dims[1] + # np.datetime64 is not handled correctly by go.Heatmap + for ax in [x_label, y_label]: + if np.issubdtype(img.coords[ax].dtype, np.datetime64): + img.coords[ax] = img.coords[ax].astype(str) + if x is None: + x = img.coords[x_label].values + if y is None: + y = img.coords[y_label].values + if aspect is None: + aspect = "auto" + if labels.get("x", None) is None: + labels["x"] = x_label + if labels.get("y", None) is None: + labels["y"] = y_label + if labels.get("animation_frame", None) is None: + labels["animation_frame"] = animation_label + if labels.get("facet_col", None) is None: + labels["facet_col"] = facet_label + if labels.get("color", None) is None: + labels["color"] = xarray.plot.utils.label_from_attrs(img) + labels["color"] = labels["color"].replace("\n", "<br>") + else: + if hasattr(img, "columns") and hasattr(img.columns, "__len__"): + if x is None: + x = img.columns + if labels.get("x", None) is None and hasattr(img.columns, "name"): + labels["x"] = img.columns.name or "" + if hasattr(img, "index") and hasattr(img.index, "__len__"): + if y is None: + y = img.index + if labels.get("y", None) is None and hasattr(img.index, "name"): + labels["y"] = img.index.name or "" + + if labels.get("x", None) is None: + labels["x"] = "" + if labels.get("y", None) is None: + labels["y"] = "" + if labels.get("color", None) is None: + labels["color"] = "" + if aspect is None: + aspect = "equal" + + # --- Set the value of binary_string (forbidden for pandas) + img = nw.from_native(img, pass_through=True) + if isinstance(img, nw.DataFrame): + if binary_string: + raise ValueError("Binary strings cannot be used with pandas arrays") + is_dataframe = True + else: + is_dataframe = False + + # --------------- Starting from here img is always a numpy array -------- + img = np.asanyarray(img) + # Reshape array so that animation dimension comes first, then facets, then images + if facet_col is not None: + img = np.moveaxis(img, facet_col, 0) + if animation_frame is not None and animation_frame < facet_col: + animation_frame += 1 + facet_col = True + if animation_frame is not None: + img = np.moveaxis(img, animation_frame, 0) + animation_frame = True + args["animation_frame"] = ( + "animation_frame" + if labels.get("animation_frame") is None + else labels["animation_frame"] + ) + iterables = () + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) + + # Default behaviour of binary_string: True for RGB images, False for 2D + if binary_string is None: + binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe + + # Cast bools to uint8 (also one byte) + if img.dtype == bool: + img = 255 * img.astype(np.uint8) + + if range_color is not None: + zmin = range_color[0] + zmax = range_color[1] + + # -------- Contrast rescaling: either minmax or infer ------------------ + if contrast_rescaling is None: + contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer" + + # We try to set zmin and zmax only if necessary, because traces have good defaults + if contrast_rescaling == "minmax": + # When using binary_string and minmax we need to set zmin and zmax to rescale the image + if (zmin is not None or binary_string) and zmax is None: + zmax = img.max() + if (zmax is not None or binary_string) and zmin is None: + zmin = img.min() + else: + # For uint8 data and infer we let zmin and zmax to be None if passed as None + if zmax is None and img.dtype != np.uint8: + zmax = _infer_zmax_from_type(img) + if zmin is None and zmax is not None: + zmin = 0 + + # For 2d data, use Heatmap trace, unless binary_string is True + if img.ndim == 2 + slice_dimensions and not binary_string: + y_index = slice_dimensions + if y is not None and img.shape[y_index] != len(y): + raise ValueError( + "The length of the y vector must match the length of the first " + + "dimension of the img matrix." + ) + x_index = slice_dimensions + 1 + if x is not None and img.shape[x_index] != len(x): + raise ValueError( + "The length of the x vector must match the length of the second " + + "dimension of the img matrix." + ) + + texttemplate = None + if text_auto is True: + texttemplate = "%{z}" + elif text_auto is not False: + texttemplate = "%{z:" + text_auto + "}" + + traces = [ + go.Heatmap( + x=x, + y=y, + z=img[index_tup], + coloraxis="coloraxis1", + name=str(i), + texttemplate=texttemplate, + ) + for i, index_tup in enumerate(itertools.product(*iterables)) + ] + autorange = True if origin == "lower" else "reversed" + layout = dict(yaxis=dict(autorange=autorange)) + if aspect == "equal": + layout["xaxis"] = dict(scaleanchor="y", constrain="domain") + layout["yaxis"]["constrain"] = "domain" + colorscale_validator = ColorscaleValidator("colorscale", "imshow") + layout["coloraxis1"] = dict( + colorscale=colorscale_validator.validate_coerce( + args["color_continuous_scale"] + ), + cmid=color_continuous_midpoint, + cmin=zmin, + cmax=zmax, + ) + if labels["color"]: + layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) + + # For 2D+RGB data, use Image trace + elif ( + img.ndim >= 3 + and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string) + ) or (img.ndim == 2 and binary_string): + rescale_image = True # to check whether image has been modified + if zmin is not None and zmax is not None: + zmin, zmax = ( + _vectorize_zvalue(zmin, mode="min"), + _vectorize_zvalue(zmax, mode="max"), + ) + x0, y0, dx, dy = (None,) * 4 + error_msg_xarray = ( + "Non-numerical coordinates were passed with xarray `img`, but " + "the Image trace cannot handle it. Please use `binary_string=False` " + "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." + ) + if x is not None: + x = np.asanyarray(x) + if np.issubdtype(x.dtype, np.number): + x0 = x[0] + dx = x[1] - x[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `x` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) + if y is not None: + y = np.asanyarray(y) + if np.issubdtype(y.dtype, np.number): + y0 = y[0] + dy = y[1] - y[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `y` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) + if binary_string: + if zmin is None and zmax is None: # no rescaling, faster + img_rescaled = img + rescale_image = False + elif img.ndim == 2 + slice_dimensions: # single-channel image + img_rescaled = rescale_intensity( + img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 + ) + else: + img_rescaled = np.stack( + [ + rescale_intensity( + img[..., ch], + in_range=(zmin[ch], zmax[ch]), + out_range=np.uint8, + ) + for ch in range(img.shape[-1]) + ], + axis=-1, + ) + img_str = [ + image_array_to_data_uri( + img_rescaled[index_tup], + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + for index_tup in itertools.product(*iterables) + ] + + traces = [ + go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy) + for i, img_str_slice in enumerate(img_str) + ] + else: + colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" + traces = [ + go.Image( + z=img[index_tup], + zmin=zmin, + zmax=zmax, + colormodel=colormodel, + x0=x0, + y0=y0, + dx=dx, + dy=dy, + ) + for index_tup in itertools.product(*iterables) + ] + layout = {} + if origin == "lower" or (dy is not None and dy < 0): + layout["yaxis"] = dict(autorange=True) + if dx is not None and dx < 0: + layout["xaxis"] = dict(autorange="reversed") + else: + raise ValueError( + "px.imshow only accepts 2D single-channel, RGB or RGBA images. " + "An image of shape %s was provided. " + "Alternatively, 3- or 4-D single or multichannel datasets can be " + "visualized using the `facet_col` or/and `animation_frame` arguments." + % str(img.shape) + ) + + # Now build figure + col_labels = [] + if facet_col is not None: + slice_label = ( + "facet_col" if labels.get("facet_col") is None else labels["facet_col"] + ) + col_labels = [f"{slice_label}={i}" for i in facet_slices] + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) + for attr_name in ["height", "width"]: + if args[attr_name]: + layout[attr_name] = args[attr_name] + if args["title"]: + layout["title_text"] = args["title"] + elif args["template"].layout.margin.t is None: + layout["margin"] = {"t": 60} + + frame_list = [] + for index, trace in enumerate(traces): + if (facet_col and index < nrows * ncols) or index == 0: + fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if animation_frame is not None: + for i, index in zip(range(nslices_animation), animation_slices): + frame_list.append( + dict( + data=traces[nslices_facet * i : nslices_facet * (i + 1)], + layout=layout, + name=str(index), + ) + ) + if animation_frame: + fig.frames = frame_list + fig.update_layout(layout) + # Hover name, z or color + if binary_string and rescale_image and not np.all(img == img_rescaled): + # we rescaled the image, hence z is not displayed in hover since it does + # not correspond to img values + hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % ( + labels["x"] or "x", + labels["y"] or "y", + ) + else: + if trace["type"] == "heatmap": + hover_name = "%{z}" + elif img.ndim == 2: + hover_name = "%{z[0]}" + elif img.ndim == 3 and img.shape[-1] == 3: + hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]" + else: + hover_name = "%{z}" + hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % ( + labels["x"] or "x", + labels["y"] or "y", + labels["color"] or "color", + hover_name, + ) + fig.update_traces(hovertemplate=hovertemplate) + if labels["x"]: + fig.update_xaxes(title_text=labels["x"], row=1) + if labels["y"]: + fig.update_yaxes(title_text=labels["y"], col=1) + configure_animation_controls(args, go.Image, fig) + fig.update_layout(template=args["template"], overwrite=True) + return fig |