diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter')
12 files changed, 1796 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py new file mode 100644 index 0000000..296a47e --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa: F401 +from .renderers import Renderer +from .exporter import Exporter diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py new file mode 100644 index 0000000..bbd1756 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py @@ -0,0 +1,317 @@ +""" +Matplotlib Exporter +=================== +This submodule contains tools for crawling a matplotlib figure and exporting +relevant pieces to a renderer. +""" + +import warnings +import io +from . import utils + +import matplotlib +from matplotlib import transforms +from matplotlib.backends.backend_agg import FigureCanvasAgg + + +class Exporter(object): + """Matplotlib Exporter + + Parameters + ---------- + renderer : Renderer object + The renderer object called by the exporter to create a figure + visualization. See mplexporter.Renderer for information on the + methods which should be defined within the renderer. + close_mpl : bool + If True (default), close the matplotlib figure as it is rendered. This + is useful for when the exporter is used within the notebook, or with + an interactive matplotlib backend. + """ + + def __init__(self, renderer, close_mpl=True): + self.close_mpl = close_mpl + self.renderer = renderer + + def run(self, fig): + """ + Run the exporter on the given figure + + Parmeters + --------- + fig : matplotlib.Figure instance + The figure to export + """ + # Calling savefig executes the draw() command, putting elements + # in the correct place. + if fig.canvas is None: + FigureCanvasAgg(fig) + fig.savefig(io.BytesIO(), format="png", dpi=fig.dpi) + if self.close_mpl: + import matplotlib.pyplot as plt + + plt.close(fig) + self.crawl_fig(fig) + + @staticmethod + def process_transform( + transform, ax=None, data=None, return_trans=False, force_trans=None + ): + """Process the transform and convert data to figure or data coordinates + + Parameters + ---------- + transform : matplotlib Transform object + The transform applied to the data + ax : matplotlib Axes object (optional) + The axes the data is associated with + data : ndarray (optional) + The array of data to be transformed. + return_trans : bool (optional) + If true, return the final transform of the data + force_trans : matplotlib.transform instance (optional) + If supplied, first force the data to this transform + + Returns + ------- + code : string + Code is either "data", "axes", "figure", or "display", indicating + the type of coordinates output. + transform : matplotlib transform + the transform used to map input data to output data. + Returned only if return_trans is True + new_data : ndarray + Data transformed to match the given coordinate code. + Returned only if data is specified + """ + if isinstance(transform, transforms.BlendedGenericTransform): + warnings.warn( + "Blended transforms not yet supported. " + "Zoom behavior may not work as expected." + ) + + if force_trans is not None: + if data is not None: + data = (transform - force_trans).transform(data) + transform = force_trans + + code = "display" + if ax is not None: + for c, trans in [ + ("data", ax.transData), + ("axes", ax.transAxes), + ("figure", ax.figure.transFigure), + ("display", transforms.IdentityTransform()), + ]: + if transform.contains_branch(trans): + code, transform = (c, transform - trans) + break + + if data is not None: + if return_trans: + return code, transform.transform(data), transform + else: + return code, transform.transform(data) + else: + if return_trans: + return code, transform + else: + return code + + def crawl_fig(self, fig): + """Crawl the figure and process all axes""" + with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)): + for ax in fig.axes: + self.crawl_ax(ax) + + def crawl_ax(self, ax): + """Crawl the axes and process all elements within""" + with self.renderer.draw_axes(ax=ax, props=utils.get_axes_properties(ax)): + for line in ax.lines: + self.draw_line(ax, line) + for text in ax.texts: + self.draw_text(ax, text) + for text, ttp in zip( + [ax.xaxis.label, ax.yaxis.label, ax.title], + ["xlabel", "ylabel", "title"], + ): + if hasattr(text, "get_text") and text.get_text(): + self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp) + for artist in ax.artists: + # TODO: process other artists + if isinstance(artist, matplotlib.text.Text): + self.draw_text(ax, artist) + for patch in ax.patches: + self.draw_patch(ax, patch) + for collection in ax.collections: + self.draw_collection(ax, collection) + for image in ax.images: + self.draw_image(ax, image) + + legend = ax.get_legend() + if legend is not None: + props = utils.get_legend_properties(ax, legend) + with self.renderer.draw_legend(legend=legend, props=props): + if props["visible"]: + self.crawl_legend(ax, legend) + + def crawl_legend(self, ax, legend): + """ + Recursively look through objects in legend children + """ + legendElements = list( + utils.iter_all_children(legend._legend_box, skipContainers=True) + ) + legendElements.append(legend.legendPatch) + for child in legendElements: + # force a large zorder so it appears on top + child.set_zorder(1e6 + child.get_zorder()) + + # reorder border box to make sure marks are visible + if isinstance(child, matplotlib.patches.FancyBboxPatch): + child.set_zorder(child.get_zorder() - 1) + + try: + # What kind of object... + if isinstance(child, matplotlib.patches.Patch): + self.draw_patch(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.text.Text): + if child.get_text() != "None": + self.draw_text(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.lines.Line2D): + self.draw_line(ax, child, force_trans=ax.transAxes) + elif isinstance(child, matplotlib.collections.Collection): + self.draw_collection(ax, child, force_pathtrans=ax.transAxes) + else: + warnings.warn("Legend element %s not impemented" % child) + except NotImplementedError: + warnings.warn("Legend element %s not impemented" % child) + + def draw_line(self, ax, line, force_trans=None): + """Process a matplotlib line and call renderer.draw_line""" + coordinates, data = self.process_transform( + line.get_transform(), ax, line.get_xydata(), force_trans=force_trans + ) + linestyle = utils.get_line_style(line) + if linestyle["dasharray"] is None and linestyle["drawstyle"] == "default": + linestyle = None + markerstyle = utils.get_marker_style(line) + if ( + markerstyle["marker"] in ["None", "none", None] + or markerstyle["markerpath"][0].size == 0 + ): + markerstyle = None + label = line.get_label() + if markerstyle or linestyle: + self.renderer.draw_marked_line( + data=data, + coordinates=coordinates, + linestyle=linestyle, + markerstyle=markerstyle, + label=label, + mplobj=line, + ) + + def draw_text(self, ax, text, force_trans=None, text_type=None): + """Process a matplotlib text object and call renderer.draw_text""" + content = text.get_text() + if content: + transform = text.get_transform() + position = text.get_position() + coords, position = self.process_transform( + transform, ax, position, force_trans=force_trans + ) + style = utils.get_text_style(text) + self.renderer.draw_text( + text=content, + position=position, + coordinates=coords, + text_type=text_type, + style=style, + mplobj=text, + ) + + def draw_patch(self, ax, patch, force_trans=None): + """Process a matplotlib patch object and call renderer.draw_path""" + vertices, pathcodes = utils.SVG_path(patch.get_path()) + transform = patch.get_transform() + coordinates, vertices = self.process_transform( + transform, ax, vertices, force_trans=force_trans + ) + linestyle = utils.get_path_style(patch, fill=patch.get_fill()) + self.renderer.draw_path( + data=vertices, + coordinates=coordinates, + pathcodes=pathcodes, + style=linestyle, + mplobj=patch, + ) + + def draw_collection( + self, ax, collection, force_pathtrans=None, force_offsettrans=None + ): + """Process a matplotlib collection and call renderer.draw_collection""" + (transform, transOffset, offsets, paths) = collection._prepare_points() + + offset_coords, offsets = self.process_transform( + transOffset, ax, offsets, force_trans=force_offsettrans + ) + path_coords = self.process_transform(transform, ax, force_trans=force_pathtrans) + + processed_paths = [utils.SVG_path(path) for path in paths] + processed_paths = [ + ( + self.process_transform( + transform, ax, path[0], force_trans=force_pathtrans + )[1], + path[1], + ) + for path in processed_paths + ] + + path_transforms = collection.get_transforms() + try: + # matplotlib 1.3: path_transforms are transform objects. + # Convert them to numpy arrays. + path_transforms = [t.get_matrix() for t in path_transforms] + except AttributeError: + # matplotlib 1.4: path transforms are already numpy arrays. + pass + + styles = { + "linewidth": collection.get_linewidths(), + "facecolor": collection.get_facecolors(), + "edgecolor": collection.get_edgecolors(), + "alpha": collection._alpha, + "zorder": collection.get_zorder(), + } + + # TODO: When matplotlib's minimum version is bumped to 3.8, this can be + # simplified since collection.get_offset_position no longer exists. + offset_dict = {"data": "before", "screen": "after"} + offset_order = ( + offset_dict[collection.get_offset_position()] + if hasattr(collection, "get_offset_position") + else "after" + ) + + self.renderer.draw_path_collection( + paths=processed_paths, + path_coordinates=path_coords, + path_transforms=path_transforms, + offsets=offsets, + offset_coordinates=offset_coords, + offset_order=offset_order, + styles=styles, + mplobj=collection, + ) + + def draw_image(self, ax, image): + """Process a matplotlib image object and call renderer.draw_image""" + self.renderer.draw_image( + imdata=utils.image_to_base64(image), + extent=image.get_extent(), + coordinates="data", + style={"alpha": image.get_alpha(), "zorder": image.get_zorder()}, + mplobj=image, + ) diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py new file mode 100644 index 0000000..21113ad --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py @@ -0,0 +1,14 @@ +# ruff: noqa F401 + +""" +Matplotlib Renderers +==================== +This submodule contains renderer objects which define renderer behavior used +within the Exporter class. The base renderer class is :class:`Renderer`, an +abstract base class +""" + +from .base import Renderer +from .vega_renderer import VegaRenderer, fig_to_vega +from .vincent_renderer import VincentRenderer, fig_to_vincent +from .fake_renderer import FakeRenderer, FullFakeRenderer diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py new file mode 100644 index 0000000..fbb8819 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py @@ -0,0 +1,428 @@ +import warnings +import itertools +from contextlib import contextmanager +from packaging.version import Version + +import numpy as np +import matplotlib as mpl +from matplotlib import transforms + +from .. import utils + + +class Renderer(object): + @staticmethod + def ax_zoomable(ax): + return bool(ax and ax.get_navigate()) + + @staticmethod + def ax_has_xgrid(ax): + return bool(ax and ax.xaxis._gridOnMajor and ax.yaxis.get_gridlines()) + + @staticmethod + def ax_has_ygrid(ax): + return bool(ax and ax.yaxis._gridOnMajor and ax.yaxis.get_gridlines()) + + @property + def current_ax_zoomable(self): + return self.ax_zoomable(self._current_ax) + + @property + def current_ax_has_xgrid(self): + return self.ax_has_xgrid(self._current_ax) + + @property + def current_ax_has_ygrid(self): + return self.ax_has_ygrid(self._current_ax) + + @contextmanager + def draw_figure(self, fig, props): + if hasattr(self, "_current_fig") and self._current_fig is not None: + warnings.warn("figure embedded in figure: something is wrong") + self._current_fig = fig + self._fig_props = props + self.open_figure(fig=fig, props=props) + yield + self.close_figure(fig=fig) + self._current_fig = None + self._fig_props = {} + + @contextmanager + def draw_axes(self, ax, props): + if hasattr(self, "_current_ax") and self._current_ax is not None: + warnings.warn("axes embedded in axes: something is wrong") + self._current_ax = ax + self._ax_props = props + self.open_axes(ax=ax, props=props) + yield + self.close_axes(ax=ax) + self._current_ax = None + self._ax_props = {} + + @contextmanager + def draw_legend(self, legend, props): + self._current_legend = legend + self._legend_props = props + self.open_legend(legend=legend, props=props) + yield + self.close_legend(legend=legend) + self._current_legend = None + self._legend_props = {} + + # Following are the functions which should be overloaded in subclasses + + def open_figure(self, fig, props): + """ + Begin commands for a particular figure. + + Parameters + ---------- + fig : matplotlib.Figure + The Figure which will contain the ensuing axes and elements + props : dictionary + The dictionary of figure properties + """ + pass + + def close_figure(self, fig): + """ + Finish commands for a particular figure. + + Parameters + ---------- + fig : matplotlib.Figure + The figure which is finished being drawn. + """ + pass + + def open_axes(self, ax, props): + """ + Begin commands for a particular axes. + + Parameters + ---------- + ax : matplotlib.Axes + The Axes which will contain the ensuing axes and elements + props : dictionary + The dictionary of axes properties + """ + pass + + def close_axes(self, ax): + """ + Finish commands for a particular axes. + + Parameters + ---------- + ax : matplotlib.Axes + The Axes which is finished being drawn. + """ + pass + + def open_legend(self, legend, props): + """ + Beging commands for a particular legend. + + Parameters + ---------- + legend : matplotlib.legend.Legend + The Legend that will contain the ensuing elements + props : dictionary + The dictionary of legend properties + """ + pass + + def close_legend(self, legend): + """ + Finish commands for a particular legend. + + Parameters + ---------- + legend : matplotlib.legend.Legend + The Legend which is finished being drawn + """ + pass + + def draw_marked_line( + self, data, coordinates, linestyle, markerstyle, label, mplobj=None + ): + """Draw a line that also has markers. + + If this isn't reimplemented by a renderer object, by default, it will + make a call to BOTH draw_line and draw_markers when both markerstyle + and linestyle are not None in the same Line2D object. + + """ + if linestyle is not None: + self.draw_line(data, coordinates, linestyle, label, mplobj) + if markerstyle is not None: + self.draw_markers(data, coordinates, markerstyle, label, mplobj) + + def draw_line(self, data, coordinates, style, label, mplobj=None): + """ + Draw a line. By default, draw the line via the draw_path() command. + Some renderers might wish to override this and provide more + fine-grained behavior. + + In matplotlib, lines are generally created via the plt.plot() command, + though this command also can create marker collections. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the line. + mplobj : matplotlib object + the matplotlib plot element which generated this line + """ + pathcodes = ["M"] + (data.shape[0] - 1) * ["L"] + pathstyle = dict(facecolor="none", **style) + pathstyle["edgecolor"] = pathstyle.pop("color") + pathstyle["edgewidth"] = pathstyle.pop("linewidth") + self.draw_path( + data=data, + coordinates=coordinates, + pathcodes=pathcodes, + style=pathstyle, + mplobj=mplobj, + ) + + @staticmethod + def _iter_path_collection(paths, path_transforms, offsets, styles): + """Build an iterator over the elements of the path collection""" + N = max(len(paths), len(offsets)) + + # Before mpl 1.4.0, path_transform can be a false-y value, not a valid + # transformation matrix. + if Version(mpl.__version__) < Version("1.4.0"): + if path_transforms is None: + path_transforms = [np.eye(3)] + + edgecolor = styles["edgecolor"] + if np.size(edgecolor) == 0: + edgecolor = ["none"] + facecolor = styles["facecolor"] + if np.size(facecolor) == 0: + facecolor = ["none"] + + elements = [ + paths, + path_transforms, + offsets, + edgecolor, + styles["linewidth"], + facecolor, + ] + + it = itertools + return it.islice(zip(*map(it.cycle, elements)), N) + + def draw_path_collection( + self, + paths, + path_coordinates, + path_transforms, + offsets, + offset_coordinates, + offset_order, + styles, + mplobj=None, + ): + """ + Draw a collection of paths. The paths, offsets, and styles are all + iterables, and the number of paths is max(len(paths), len(offsets)). + + By default, this is implemented via multiple calls to the draw_path() + function. For efficiency, Renderers may choose to customize this + implementation. + + Examples of path collections created by matplotlib are scatter plots, + histograms, contour plots, and many others. + + Parameters + ---------- + paths : list + list of tuples, where each tuple has two elements: + (data, pathcodes). See draw_path() for a description of these. + path_coordinates: string + the coordinates code for the paths, which should be either + 'data' for data coordinates, or 'figure' for figure (pixel) + coordinates. + path_transforms: array_like + an array of shape (*, 3, 3), giving a series of 2D Affine + transforms for the paths. These encode translations, rotations, + and scalings in the standard way. + offsets: array_like + An array of offsets of shape (N, 2) + offset_coordinates : string + the coordinates code for the offsets, which should be either + 'data' for data coordinates, or 'figure' for figure (pixel) + coordinates. + offset_order : string + either "before" or "after". This specifies whether the offset + is applied before the path transform, or after. The matplotlib + backend equivalent is "before"->"data", "after"->"screen". + styles: dictionary + A dictionary in which each value is a list of length N, containing + the style(s) for the paths. + mplobj : matplotlib object + the matplotlib plot element which generated this collection + """ + if offset_order == "before": + raise NotImplementedError("offset before transform") + + for tup in self._iter_path_collection(paths, path_transforms, offsets, styles): + (path, path_transform, offset, ec, lw, fc) = tup + vertices, pathcodes = path + path_transform = transforms.Affine2D(path_transform) + vertices = path_transform.transform(vertices) + # This is a hack: + if path_coordinates == "figure": + path_coordinates = "points" + style = { + "edgecolor": utils.export_color(ec), + "facecolor": utils.export_color(fc), + "edgewidth": lw, + "dasharray": "10,0", + "alpha": styles["alpha"], + "zorder": styles["zorder"], + } + self.draw_path( + data=vertices, + coordinates=path_coordinates, + pathcodes=pathcodes, + style=style, + offset=offset, + offset_coordinates=offset_coordinates, + mplobj=mplobj, + ) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + """ + Draw a set of markers. By default, this is done by repeatedly + calling draw_path(), but renderers should generally overload + this method to provide a more efficient implementation. + + In matplotlib, markers are created using the plt.plot() command. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the markers. + mplobj : matplotlib object + the matplotlib plot element which generated this marker collection + """ + vertices, pathcodes = style["markerpath"] + pathstyle = dict( + (key, style[key]) + for key in ["alpha", "edgecolor", "facecolor", "zorder", "edgewidth"] + ) + pathstyle["dasharray"] = "10,0" + for vertex in data: + self.draw_path( + data=vertices, + coordinates="points", + pathcodes=pathcodes, + style=pathstyle, + offset=vertex, + offset_coordinates=coordinates, + mplobj=mplobj, + ) + + def draw_text( + self, text, position, coordinates, style, text_type=None, mplobj=None + ): + """ + Draw text on the image. + + Parameters + ---------- + text : string + The text to draw + position : tuple + The (x, y) position of the text + coordinates : string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the text. + text_type : string or None + if specified, a type of text such as "xlabel", "ylabel", "title" + mplobj : matplotlib object + the matplotlib plot element which generated this text + """ + raise NotImplementedError() + + def draw_path( + self, + data, + coordinates, + pathcodes, + style, + offset=None, + offset_coordinates="data", + mplobj=None, + ): + """ + Draw a path. + + In matplotlib, paths are created by filled regions, histograms, + contour plots, patches, etc. + + Parameters + ---------- + data : array_like + A shape (N, 2) array of datapoints. + coordinates : string + A string code, which should be either 'data' for data coordinates, + 'figure' for figure (pixel) coordinates, or "points" for raw + point coordinates (useful in conjunction with offsets, below). + pathcodes : list + A list of single-character SVG pathcodes associated with the data. + Path codes are one of ['M', 'm', 'L', 'l', 'Q', 'q', 'T', 't', + 'S', 's', 'C', 'c', 'Z', 'z'] + See the SVG specification for details. Note that some path codes + consume more than one datapoint (while 'Z' consumes none), so + in general, the length of the pathcodes list will not be the same + as that of the data array. + style : dictionary + a dictionary specifying the appearance of the line. + offset : list (optional) + the (x, y) offset of the path. If not given, no offset will + be used. + offset_coordinates : string (optional) + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + mplobj : matplotlib object + the matplotlib plot element which generated this path + """ + raise NotImplementedError() + + def draw_image(self, imdata, extent, coordinates, style, mplobj=None): + """ + Draw an image. + + Parameters + ---------- + imdata : string + base64 encoded png representation of the image + extent : list + the axes extent of the image: [xmin, xmax, ymin, ymax] + coordinates: string + A string code, which should be either 'data' for data coordinates, + or 'figure' for figure (pixel) coordinates. + style : dictionary + a dictionary specifying the appearance of the image + mplobj : matplotlib object + the matplotlib plot object which generated this image + """ + raise NotImplementedError() diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py new file mode 100644 index 0000000..de2ae40 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py @@ -0,0 +1,88 @@ +from .base import Renderer + + +class FakeRenderer(Renderer): + """ + Fake Renderer + + This is a fake renderer which simply outputs a text tree representing the + elements found in the plot(s). This is used in the unit tests for the + package. + + Below are the methods your renderer must implement. You are free to do + anything you wish within the renderer (i.e. build an XML or JSON + representation, call an external API, etc.) Here the renderer just + builds a simple string representation for testing purposes. + """ + + def __init__(self): + self.output = "" + + def open_figure(self, fig, props): + self.output += "opening figure\n" + + def close_figure(self, fig): + self.output += "closing figure\n" + + def open_axes(self, ax, props): + self.output += " opening axes\n" + + def close_axes(self, ax): + self.output += " closing axes\n" + + def open_legend(self, legend, props): + self.output += " opening legend\n" + + def close_legend(self, legend): + self.output += " closing legend\n" + + def draw_text( + self, text, position, coordinates, style, text_type=None, mplobj=None + ): + self.output += " draw text '{0}' {1}\n".format(text, text_type) + + def draw_path( + self, + data, + coordinates, + pathcodes, + style, + offset=None, + offset_coordinates="data", + mplobj=None, + ): + self.output += " draw path with {0} vertices\n".format(data.shape[0]) + + def draw_image(self, imdata, extent, coordinates, style, mplobj=None): + self.output += " draw image of size {0}\n".format(len(imdata)) + + +class FullFakeRenderer(FakeRenderer): + """ + Renderer with the full complement of methods. + + When the following are left undefined, they will be implemented via + other methods in the class. They can be defined explicitly for + more efficient or specialized use within the renderer implementation. + """ + + def draw_line(self, data, coordinates, style, label, mplobj=None): + self.output += " draw line with {0} points\n".format(data.shape[0]) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + self.output += " draw {0} markers\n".format(data.shape[0]) + + def draw_path_collection( + self, + paths, + path_coordinates, + path_transforms, + offsets, + offset_coordinates, + offset_order, + styles, + mplobj=None, + ): + self.output += " draw path collection with {0} offsets\n".format( + offsets.shape[0] + ) diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py new file mode 100644 index 0000000..eab02e1 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py @@ -0,0 +1,155 @@ +import warnings +import json +import random +from .base import Renderer +from ..exporter import Exporter + + +class VegaRenderer(Renderer): + def open_figure(self, fig, props): + self.props = props + self.figwidth = int(props["figwidth"] * props["dpi"]) + self.figheight = int(props["figheight"] * props["dpi"]) + self.data = [] + self.scales = [] + self.axes = [] + self.marks = [] + + def open_axes(self, ax, props): + if len(self.axes) > 0: + warnings.warn("multiple axes not yet supported") + self.axes = [ + dict(type="x", scale="x", ticks=10), + dict(type="y", scale="y", ticks=10), + ] + self.scales = [ + dict( + name="x", + domain=props["xlim"], + type="linear", + range="width", + ), + dict( + name="y", + domain=props["ylim"], + type="linear", + range="height", + ), + ] + + def draw_line(self, data, coordinates, style, label, mplobj=None): + if coordinates != "data": + warnings.warn("Only data coordinates supported. Skipping this") + dataname = "table{0:03d}".format(len(self.data) + 1) + + # TODO: respect the other style settings + self.data.append( + {"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]} + ) + self.marks.append( + { + "type": "line", + "from": {"data": dataname}, + "properties": { + "enter": { + "interpolate": {"value": "monotone"}, + "x": {"scale": "x", "field": "data.x"}, + "y": {"scale": "y", "field": "data.y"}, + "stroke": {"value": style["color"]}, + "strokeOpacity": {"value": style["alpha"]}, + "strokeWidth": {"value": style["linewidth"]}, + } + }, + } + ) + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + if coordinates != "data": + warnings.warn("Only data coordinates supported. Skipping this") + dataname = "table{0:03d}".format(len(self.data) + 1) + + # TODO: respect the other style settings + self.data.append( + {"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]} + ) + self.marks.append( + { + "type": "symbol", + "from": {"data": dataname}, + "properties": { + "enter": { + "interpolate": {"value": "monotone"}, + "x": {"scale": "x", "field": "data.x"}, + "y": {"scale": "y", "field": "data.y"}, + "fill": {"value": style["facecolor"]}, + "fillOpacity": {"value": style["alpha"]}, + "stroke": {"value": style["edgecolor"]}, + "strokeOpacity": {"value": style["alpha"]}, + "strokeWidth": {"value": style["edgewidth"]}, + } + }, + } + ) + + def draw_text( + self, text, position, coordinates, style, text_type=None, mplobj=None + ): + if text_type == "xlabel": + self.axes[0]["title"] = text + elif text_type == "ylabel": + self.axes[1]["title"] = text + + +class VegaHTML(object): + def __init__(self, renderer): + self.specification = dict( + width=renderer.figwidth, + height=renderer.figheight, + data=renderer.data, + scales=renderer.scales, + axes=renderer.axes, + marks=renderer.marks, + ) + + def html(self): + """Build the HTML representation for IPython.""" + id = random.randint(0, 2**16) + html = '<div id="vis%d"></div>' % id + html += "<script>\n" + html += VEGA_TEMPLATE % (json.dumps(self.specification), id) + html += "</script>\n" + return html + + def _repr_html_(self): + return self.html() + + +def fig_to_vega(fig, notebook=False): + """Convert a matplotlib figure to vega dictionary + + if notebook=True, then return an object which will display in a notebook + otherwise, return an HTML string. + """ + renderer = VegaRenderer() + Exporter(renderer).run(fig) + vega_html = VegaHTML(renderer) + if notebook: + return vega_html + else: + return vega_html.html() + + +VEGA_TEMPLATE = """ +( function() { + var _do_plot = function() { + if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) { + $([IPython.events]).on("vega_loaded.vincent", _do_plot); + return; + } + vg.parse.spec(%s, function(chart) { + chart({el: "#vis%d"}).update(); + }); + }; + _do_plot(); +})(); +""" diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py new file mode 100644 index 0000000..36074f6 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py @@ -0,0 +1,54 @@ +import warnings +from .base import Renderer +from ..exporter import Exporter + + +class VincentRenderer(Renderer): + def open_figure(self, fig, props): + self.chart = None + self.figwidth = int(props["figwidth"] * props["dpi"]) + self.figheight = int(props["figheight"] * props["dpi"]) + + def draw_line(self, data, coordinates, style, label, mplobj=None): + import vincent # only import if VincentRenderer is used + + if coordinates != "data": + warnings.warn("Only data coordinates supported. Skipping this") + linedata = {"x": data[:, 0], "y": data[:, 1]} + line = vincent.Line( + linedata, iter_idx="x", width=self.figwidth, height=self.figheight + ) + + # TODO: respect the other style settings + line.scales["color"].range = [style["color"]] + + if self.chart is None: + self.chart = line + else: + warnings.warn("Multiple plot elements not yet supported") + + def draw_markers(self, data, coordinates, style, label, mplobj=None): + import vincent # only import if VincentRenderer is used + + if coordinates != "data": + warnings.warn("Only data coordinates supported. Skipping this") + markerdata = {"x": data[:, 0], "y": data[:, 1]} + markers = vincent.Scatter( + markerdata, iter_idx="x", width=self.figwidth, height=self.figheight + ) + + # TODO: respect the other style settings + markers.scales["color"].range = [style["facecolor"]] + + if self.chart is None: + self.chart = markers + else: + warnings.warn("Multiple plot elements not yet supported") + + +def fig_to_vincent(fig): + """Convert a matplotlib figure to a vincent object""" + renderer = VincentRenderer() + exporter = Exporter(renderer) + exporter.run(fig) + return renderer.chart diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py new file mode 100644 index 0000000..290cc21 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py @@ -0,0 +1,3 @@ +import matplotlib + +matplotlib.use("Agg") diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py new file mode 100644 index 0000000..3739e13 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py @@ -0,0 +1,257 @@ +import matplotlib +import numpy as np +import pytest +from packaging.version import Version + +from ..exporter import Exporter +from ..renderers import FakeRenderer, FullFakeRenderer +import matplotlib.pyplot as plt + + +def fake_renderer_output(fig, Renderer): + renderer = Renderer() + exporter = Exporter(renderer) + exporter.run(fig) + return renderer.output + + +def _assert_output_equal(text1, text2): + for line1, line2 in zip(text1.strip().split(), text2.strip().split()): + assert line1 == line2 + + +def test_lines(): + fig, ax = plt.subplots() + ax.plot(range(20), "-k") + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 20 vertices + closing axes + closing figure + """, + ) + + _assert_output_equal( + fake_renderer_output(fig, FullFakeRenderer), + """ + opening figure + opening axes + draw line with 20 points + closing axes + closing figure + """, + ) + + +def test_markers(): + fig, ax = plt.subplots() + ax.plot(range(2), "ok") + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 25 vertices + draw path with 25 vertices + closing axes + closing figure + """, + ) + + _assert_output_equal( + fake_renderer_output(fig, FullFakeRenderer), + """ + opening figure + opening axes + draw 2 markers + closing axes + closing figure + """, + ) + + +def test_path_collection(): + fig, ax = plt.subplots() + ax.scatter(range(3), range(3)) + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 25 vertices + draw path with 25 vertices + draw path with 25 vertices + closing axes + closing figure + """, + ) + + _assert_output_equal( + fake_renderer_output(fig, FullFakeRenderer), + """ + opening figure + opening axes + draw path collection with 3 offsets + closing axes + closing figure + """, + ) + + +def test_text(): + fig, ax = plt.subplots() + ax.set_xlabel("my x label") + ax.set_ylabel("my y label") + ax.set_title("my title") + ax.text(0.5, 0.5, "my text") + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw text 'my text' None + draw text 'my x label' xlabel + draw text 'my y label' ylabel + draw text 'my title' title + closing axes + closing figure + """, + ) + + +def test_path(): + fig, ax = plt.subplots() + ax.add_patch(plt.Circle((0, 0), 1)) + ax.add_patch(plt.Rectangle((0, 0), 1, 2)) + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 25 vertices + draw path with 4 vertices + closing axes + closing figure + """, + ) + + +def test_Figure(): + """if the fig is not associated with a canvas, FakeRenderer shall + not fail.""" + fig = plt.Figure() + ax = fig.add_subplot(111) + ax.add_patch(plt.Circle((0, 0), 1)) + ax.add_patch(plt.Rectangle((0, 0), 1, 2)) + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 25 vertices + draw path with 4 vertices + closing axes + closing figure + """, + ) + + +def test_multiaxes(): + fig, ax = plt.subplots(2) + ax[0].plot(range(4)) + ax[1].plot(range(10)) + + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 4 vertices + closing axes + opening axes + draw path with 10 vertices + closing axes + closing figure + """, + ) + + +def test_image(): + # Test fails for matplotlib 1.5+ because the size of the image + # generated by matplotlib has changed. + if Version(matplotlib.__version__) == Version("3.4.1"): + image_size = 432 + else: + pytest.skip("Test fails for older matplotlib") + np.random.seed(0) # image size depends on the seed + fig, ax = plt.subplots(figsize=(2, 2)) + ax.imshow(np.random.random((10, 10)), cmap=plt.cm.jet, interpolation="nearest") + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + f""" + opening figure + opening axes + draw image of size {image_size} + closing axes + closing figure + """, + ) + + +def test_legend(): + fig, ax = plt.subplots() + ax.plot([1, 2, 3], label="label") + ax.legend().set_visible(False) + _assert_output_equal( + fake_renderer_output(fig, FakeRenderer), + """ + opening figure + opening axes + draw path with 3 vertices + opening legend + closing legend + closing axes + closing figure + """, + ) + + +def test_legend_dots(): + fig, ax = plt.subplots() + ax.plot([1, 2, 3], label="label") + ax.plot([2, 2, 2], "o", label="dots") + ax.legend().set_visible(True) + # legend draws 1 line and 1 marker + # path around legend now has 13 vertices?? + _assert_output_equal( + fake_renderer_output(fig, FullFakeRenderer), + """ + opening figure + opening axes + draw line with 3 points + draw 3 markers + opening legend + draw line with 2 points + draw text 'label' None + draw 1 markers + draw text 'dots' None + draw path with 13 vertices + closing legend + closing axes + closing figure + """, + ) + + +def test_blended(): + fig, ax = plt.subplots() + ax.axvline(0) diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py new file mode 100644 index 0000000..5659163 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py @@ -0,0 +1,40 @@ +from numpy.testing import assert_allclose, assert_equal +from . import plt +from .. import utils + + +def test_path_data(): + circle = plt.Circle((0, 0), 1) + vertices, codes = utils.SVG_path(circle.get_path()) + + assert_allclose(vertices.shape, (25, 2)) + assert_equal(codes, ["M", "C", "C", "C", "C", "C", "C", "C", "C", "Z"]) + + +def test_linestyle(): + linestyles = { + "solid": "none", + "-": "none", + "dashed": "5.550000000000001,2.4000000000000004", + "--": "5.550000000000001,2.4000000000000004", + "dotted": "1.5,2.4749999999999996", + ":": "1.5,2.4749999999999996", + "dashdot": "9.600000000000001,2.4000000000000004,1.5,2.4000000000000004", + "-.": "9.600000000000001,2.4000000000000004,1.5,2.4000000000000004", + "": None, + "None": None, + } + + for ls, result in linestyles.items(): + (line,) = plt.plot([1, 2, 3], linestyle=ls) + assert_equal(utils.get_dasharray(line), result) + + +def test_axis_w_fixed_formatter(): + positions, labels = [0, 1, 10], ["A", "B", "C"] + + plt.xticks(positions, labels) + props = utils.get_axis_properties(plt.gca().xaxis) + + assert_equal(props["tickvalues"], positions) + assert_equal(props["tickformat"], labels) diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py new file mode 100644 index 0000000..f66fdfb --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py @@ -0,0 +1,55 @@ +""" +Tools for matplotlib plot exporting +""" + + +def ipynb_vega_init(): + """Initialize the IPython notebook display elements + + This function borrows heavily from the excellent vincent package: + http://github.com/wrobstory/vincent + """ + try: + from IPython.core.display import display, HTML + except ImportError: + print("IPython Notebook could not be loaded.") + + require_js = """ + if (window['d3'] === undefined) {{ + require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }}); + require(["d3"], function(d3) {{ + window.d3 = d3; + {0} + }}); + }}; + if (window['topojson'] === undefined) {{ + require.config( + {{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }} + ); + require(["topojson"], function(topojson) {{ + window.topojson = topojson; + }}); + }}; + """ + d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js" + d3_layout_cloud_js_url = "http://wrobstory.github.io/d3-cloud/d3.layout.cloud.js" + topojson_js_url = "http://d3js.org/topojson.v1.min.js" + vega_js_url = "http://trifacta.github.com/vega/vega.js" + + dep_libs = """$.getScript("%s", function() { + $.getScript("%s", function() { + $.getScript("%s", function() { + $.getScript("%s", function() { + $([IPython.events]).trigger("vega_loaded.vincent"); + }) + }) + }) + });""" % ( + d3_geo_projection_js_url, + d3_layout_cloud_js_url, + topojson_js_url, + vega_js_url, + ) + load_js = require_js.format(dep_libs) + html = "<script>" + load_js + "</script>" + display(HTML(html)) diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py new file mode 100644 index 0000000..646e11e --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py @@ -0,0 +1,382 @@ +""" +Utility Routines for Working with Matplotlib Objects +==================================================== +""" + +import itertools +import io +import base64 + +import numpy as np + +import warnings + +import matplotlib +from matplotlib.colors import colorConverter +from matplotlib.path import Path +from matplotlib.markers import MarkerStyle +from matplotlib.transforms import Affine2D +from matplotlib import ticker + + +def export_color(color): + """Convert matplotlib color code to hex color or RGBA color""" + if color is None or colorConverter.to_rgba(color)[3] == 0: + return "none" + elif colorConverter.to_rgba(color)[3] == 1: + rgb = colorConverter.to_rgb(color) + return "#{0:02X}{1:02X}{2:02X}".format(*(int(255 * c) for c in rgb)) + else: + c = colorConverter.to_rgba(color) + return ( + "rgba(" + + ", ".join(str(int(np.round(val * 255))) for val in c[:3]) + + ", " + + str(c[3]) + + ")" + ) + + +def _many_to_one(input_dict): + """Convert a many-to-one mapping to a one-to-one mapping""" + return dict((key, val) for keys, val in input_dict.items() for key in keys) + + +LINESTYLES = _many_to_one( + { + ("solid", "-", (None, None)): "none", + ("dashed", "--"): "6,6", + ("dotted", ":"): "2,2", + ("dashdot", "-."): "4,4,2,4", + ("", " ", "None", "none"): None, + } +) + + +def get_dasharray(obj): + """Get an SVG dash array for the given matplotlib linestyle + + Parameters + ---------- + obj : matplotlib object + The matplotlib line or path object, which must have a get_linestyle() + method which returns a valid matplotlib line code + + Returns + ------- + dasharray : string + The HTML/SVG dasharray code associated with the object. + """ + if obj.__dict__.get("_dashSeq", None) is not None: + return ",".join(map(str, obj._dashSeq)) + else: + ls = obj.get_linestyle() + dasharray = LINESTYLES.get(ls, "not found") + if dasharray == "not found": + warnings.warn( + "line style '{0}' not understood: defaulting to solid line.".format(ls) + ) + dasharray = LINESTYLES["solid"] + return dasharray + + +PATH_DICT = { + Path.LINETO: "L", + Path.MOVETO: "M", + Path.CURVE3: "S", + Path.CURVE4: "C", + Path.CLOSEPOLY: "Z", +} + + +def SVG_path(path, transform=None, simplify=False): + """Construct the vertices and SVG codes for the path + + Parameters + ---------- + path : matplotlib.Path object + + transform : matplotlib transform (optional) + if specified, the path will be transformed before computing the output. + + Returns + ------- + vertices : array + The shape (M, 2) array of vertices of the Path. Note that some Path + codes require multiple vertices, so the length of these vertices may + be longer than the list of path codes. + path_codes : list + A length N list of single-character path codes, N <= M. Each code is + a single character, in ['L','M','S','C','Z']. See the standard SVG + path specification for a description of these. + """ + if transform is not None: + path = path.transformed(transform) + + vc_tuples = [ + (vertices if path_code != Path.CLOSEPOLY else [], PATH_DICT[path_code]) + for (vertices, path_code) in path.iter_segments(simplify=simplify) + ] + + if not vc_tuples: + # empty path is a special case + return np.zeros((0, 2)), [] + else: + vertices, codes = zip(*vc_tuples) + vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2) + return vertices, list(codes) + + +def get_path_style(path, fill=True): + """Get the style dictionary for matplotlib path objects""" + style = {} + style["alpha"] = path.get_alpha() + if style["alpha"] is None: + style["alpha"] = 1 + style["edgecolor"] = export_color(path.get_edgecolor()) + if fill: + style["facecolor"] = export_color(path.get_facecolor()) + else: + style["facecolor"] = "none" + style["edgewidth"] = path.get_linewidth() + style["dasharray"] = get_dasharray(path) + style["zorder"] = path.get_zorder() + return style + + +def get_line_style(line): + """Get the style dictionary for matplotlib line objects""" + style = {} + style["alpha"] = line.get_alpha() + if style["alpha"] is None: + style["alpha"] = 1 + style["color"] = export_color(line.get_color()) + style["linewidth"] = line.get_linewidth() + style["dasharray"] = get_dasharray(line) + style["zorder"] = line.get_zorder() + style["drawstyle"] = line.get_drawstyle() + return style + + +def get_marker_style(line): + """Get the style dictionary for matplotlib marker objects""" + style = {} + style["alpha"] = line.get_alpha() + if style["alpha"] is None: + style["alpha"] = 1 + + style["facecolor"] = export_color(line.get_markerfacecolor()) + style["edgecolor"] = export_color(line.get_markeredgecolor()) + style["edgewidth"] = line.get_markeredgewidth() + + style["marker"] = line.get_marker() + markerstyle = MarkerStyle(line.get_marker()) + markersize = line.get_markersize() + markertransform = markerstyle.get_transform() + Affine2D().scale( + markersize, -markersize + ) + style["markerpath"] = SVG_path(markerstyle.get_path(), markertransform) + style["markersize"] = markersize + style["zorder"] = line.get_zorder() + return style + + +def get_text_style(text): + """Return the text style dict for a text instance""" + style = {} + style["alpha"] = text.get_alpha() + if style["alpha"] is None: + style["alpha"] = 1 + style["fontsize"] = text.get_size() + style["color"] = export_color(text.get_color()) + style["halign"] = text.get_horizontalalignment() # left, center, right + style["valign"] = text.get_verticalalignment() # baseline, center, top + style["malign"] = text._multialignment # text alignment when '\n' in text + style["rotation"] = text.get_rotation() + style["zorder"] = text.get_zorder() + return style + + +def get_axis_properties(axis): + """Return the property dictionary for a matplotlib.Axis instance""" + props = {} + label1On = axis._major_tick_kw.get("label1On", True) + + if isinstance(axis, matplotlib.axis.XAxis): + if label1On: + props["position"] = "bottom" + else: + props["position"] = "top" + elif isinstance(axis, matplotlib.axis.YAxis): + if label1On: + props["position"] = "left" + else: + props["position"] = "right" + else: + raise ValueError("{0} should be an Axis instance".format(axis)) + + # Use tick values if appropriate + locator = axis.get_major_locator() + props["nticks"] = len(locator()) + if isinstance(locator, ticker.FixedLocator): + props["tickvalues"] = list(locator()) + else: + props["tickvalues"] = None + + # Find tick formats + formatter = axis.get_major_formatter() + if isinstance(formatter, ticker.NullFormatter): + props["tickformat"] = "" + elif isinstance(formatter, ticker.FixedFormatter): + props["tickformat"] = list(formatter.seq) + elif isinstance(formatter, ticker.FuncFormatter): + props["tickformat"] = list(formatter.func.args[0].values()) + elif not any(label.get_visible() for label in axis.get_ticklabels()): + props["tickformat"] = "" + else: + props["tickformat"] = None + + # Get axis scale + props["scale"] = axis.get_scale() + + # Get major tick label size (assumes that's all we really care about!) + labels = axis.get_ticklabels() + if labels: + props["fontsize"] = labels[0].get_fontsize() + else: + props["fontsize"] = None + + # Get associated grid + props["grid"] = get_grid_style(axis) + + # get axis visibility + props["visible"] = axis.get_visible() + + return props + + +def get_grid_style(axis): + gridlines = axis.get_gridlines() + if axis._major_tick_kw["gridOn"] and len(gridlines) > 0: + color = export_color(gridlines[0].get_color()) + alpha = gridlines[0].get_alpha() + dasharray = get_dasharray(gridlines[0]) + return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha) + else: + return {"gridOn": False} + + +def get_figure_properties(fig): + return { + "figwidth": fig.get_figwidth(), + "figheight": fig.get_figheight(), + "dpi": fig.dpi, + } + + +def get_axes_properties(ax): + props = { + "axesbg": export_color(ax.patch.get_facecolor()), + "axesbgalpha": ax.patch.get_alpha(), + "bounds": ax.get_position().bounds, + "dynamic": ax.get_navigate(), + "axison": ax.axison, + "frame_on": ax.get_frame_on(), + "patch_visible": ax.patch.get_visible(), + "axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)], + } + + for axname in ["x", "y"]: + axis = getattr(ax, axname + "axis") + domain = getattr(ax, "get_{0}lim".format(axname))() + lim = domain + if isinstance(axis.converter, matplotlib.dates.DateConverter): + scale = "date" + try: + import pandas as pd + from pandas.tseries.converter import PeriodConverter + except ImportError: + pd = None + + if pd is not None and isinstance(axis.converter, PeriodConverter): + _dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain] + domain = [ + (d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0) + for d in _dates + ] + else: + domain = [ + ( + d.year, + d.month - 1, + d.day, + d.hour, + d.minute, + d.second, + d.microsecond * 1e-3, + ) + for d in matplotlib.dates.num2date(domain) + ] + else: + scale = axis.get_scale() + + if scale not in ["date", "linear", "log"]: + raise ValueError("Unknown axis scale: {0}".format(axis.get_scale())) + + props[axname + "scale"] = scale + props[axname + "lim"] = lim + props[axname + "domain"] = domain + + return props + + +def iter_all_children(obj, skipContainers=False): + """ + Returns an iterator over all childen and nested children using + obj's get_children() method + + if skipContainers is true, only childless objects are returned. + """ + if hasattr(obj, "get_children") and len(obj.get_children()) > 0: + for child in obj.get_children(): + if not skipContainers: + yield child + # could use `yield from` in python 3... + for grandchild in iter_all_children(child, skipContainers): + yield grandchild + else: + yield obj + + +def get_legend_properties(ax, legend): + handles, labels = ax.get_legend_handles_labels() + visible = legend.get_visible() + return {"handles": handles, "labels": labels, "visible": visible} + + +def image_to_base64(image): + """ + Convert a matplotlib image to a base64 png representation + + Parameters + ---------- + image : matplotlib image object + The image to be converted. + + Returns + ------- + image_base64 : string + The UTF8-encoded base64 string representation of the png image. + """ + ax = image.axes + binary_buffer = io.BytesIO() + + # image is saved in axes coordinates: we need to temporarily + # set the correct limits to get the correct image + lim = ax.axis() + ax.axis(image.get_extent()) + image.write_png(binary_buffer) + ax.axis(lim) + + binary_buffer.seek(0) + return base64.b64encode(binary_buffer.read()).decode("utf-8") |