1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
|
"""
Matplotlib Exporter
===================
This submodule contains tools for crawling a matplotlib figure and exporting
relevant pieces to a renderer.
"""
import warnings
import io
from . import utils
import matplotlib
from matplotlib import transforms
from matplotlib.backends.backend_agg import FigureCanvasAgg
class Exporter(object):
"""Matplotlib Exporter
Parameters
----------
renderer : Renderer object
The renderer object called by the exporter to create a figure
visualization. See mplexporter.Renderer for information on the
methods which should be defined within the renderer.
close_mpl : bool
If True (default), close the matplotlib figure as it is rendered. This
is useful for when the exporter is used within the notebook, or with
an interactive matplotlib backend.
"""
def __init__(self, renderer, close_mpl=True):
self.close_mpl = close_mpl
self.renderer = renderer
def run(self, fig):
"""
Run the exporter on the given figure
Parmeters
---------
fig : matplotlib.Figure instance
The figure to export
"""
# Calling savefig executes the draw() command, putting elements
# in the correct place.
if fig.canvas is None:
FigureCanvasAgg(fig)
fig.savefig(io.BytesIO(), format="png", dpi=fig.dpi)
if self.close_mpl:
import matplotlib.pyplot as plt
plt.close(fig)
self.crawl_fig(fig)
@staticmethod
def process_transform(
transform, ax=None, data=None, return_trans=False, force_trans=None
):
"""Process the transform and convert data to figure or data coordinates
Parameters
----------
transform : matplotlib Transform object
The transform applied to the data
ax : matplotlib Axes object (optional)
The axes the data is associated with
data : ndarray (optional)
The array of data to be transformed.
return_trans : bool (optional)
If true, return the final transform of the data
force_trans : matplotlib.transform instance (optional)
If supplied, first force the data to this transform
Returns
-------
code : string
Code is either "data", "axes", "figure", or "display", indicating
the type of coordinates output.
transform : matplotlib transform
the transform used to map input data to output data.
Returned only if return_trans is True
new_data : ndarray
Data transformed to match the given coordinate code.
Returned only if data is specified
"""
if isinstance(transform, transforms.BlendedGenericTransform):
warnings.warn(
"Blended transforms not yet supported. "
"Zoom behavior may not work as expected."
)
if force_trans is not None:
if data is not None:
data = (transform - force_trans).transform(data)
transform = force_trans
code = "display"
if ax is not None:
for c, trans in [
("data", ax.transData),
("axes", ax.transAxes),
("figure", ax.figure.transFigure),
("display", transforms.IdentityTransform()),
]:
if transform.contains_branch(trans):
code, transform = (c, transform - trans)
break
if data is not None:
if return_trans:
return code, transform.transform(data), transform
else:
return code, transform.transform(data)
else:
if return_trans:
return code, transform
else:
return code
def crawl_fig(self, fig):
"""Crawl the figure and process all axes"""
with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)):
for ax in fig.axes:
self.crawl_ax(ax)
def crawl_ax(self, ax):
"""Crawl the axes and process all elements within"""
with self.renderer.draw_axes(ax=ax, props=utils.get_axes_properties(ax)):
for line in ax.lines:
self.draw_line(ax, line)
for text in ax.texts:
self.draw_text(ax, text)
for text, ttp in zip(
[ax.xaxis.label, ax.yaxis.label, ax.title],
["xlabel", "ylabel", "title"],
):
if hasattr(text, "get_text") and text.get_text():
self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp)
for artist in ax.artists:
# TODO: process other artists
if isinstance(artist, matplotlib.text.Text):
self.draw_text(ax, artist)
for patch in ax.patches:
self.draw_patch(ax, patch)
for collection in ax.collections:
self.draw_collection(ax, collection)
for image in ax.images:
self.draw_image(ax, image)
legend = ax.get_legend()
if legend is not None:
props = utils.get_legend_properties(ax, legend)
with self.renderer.draw_legend(legend=legend, props=props):
if props["visible"]:
self.crawl_legend(ax, legend)
def crawl_legend(self, ax, legend):
"""
Recursively look through objects in legend children
"""
legendElements = list(
utils.iter_all_children(legend._legend_box, skipContainers=True)
)
legendElements.append(legend.legendPatch)
for child in legendElements:
# force a large zorder so it appears on top
child.set_zorder(1e6 + child.get_zorder())
# reorder border box to make sure marks are visible
if isinstance(child, matplotlib.patches.FancyBboxPatch):
child.set_zorder(child.get_zorder() - 1)
try:
# What kind of object...
if isinstance(child, matplotlib.patches.Patch):
self.draw_patch(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.text.Text):
if child.get_text() != "None":
self.draw_text(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.lines.Line2D):
self.draw_line(ax, child, force_trans=ax.transAxes)
elif isinstance(child, matplotlib.collections.Collection):
self.draw_collection(ax, child, force_pathtrans=ax.transAxes)
else:
warnings.warn("Legend element %s not impemented" % child)
except NotImplementedError:
warnings.warn("Legend element %s not impemented" % child)
def draw_line(self, ax, line, force_trans=None):
"""Process a matplotlib line and call renderer.draw_line"""
coordinates, data = self.process_transform(
line.get_transform(), ax, line.get_xydata(), force_trans=force_trans
)
linestyle = utils.get_line_style(line)
if linestyle["dasharray"] is None and linestyle["drawstyle"] == "default":
linestyle = None
markerstyle = utils.get_marker_style(line)
if (
markerstyle["marker"] in ["None", "none", None]
or markerstyle["markerpath"][0].size == 0
):
markerstyle = None
label = line.get_label()
if markerstyle or linestyle:
self.renderer.draw_marked_line(
data=data,
coordinates=coordinates,
linestyle=linestyle,
markerstyle=markerstyle,
label=label,
mplobj=line,
)
def draw_text(self, ax, text, force_trans=None, text_type=None):
"""Process a matplotlib text object and call renderer.draw_text"""
content = text.get_text()
if content:
transform = text.get_transform()
position = text.get_position()
coords, position = self.process_transform(
transform, ax, position, force_trans=force_trans
)
style = utils.get_text_style(text)
self.renderer.draw_text(
text=content,
position=position,
coordinates=coords,
text_type=text_type,
style=style,
mplobj=text,
)
def draw_patch(self, ax, patch, force_trans=None):
"""Process a matplotlib patch object and call renderer.draw_path"""
vertices, pathcodes = utils.SVG_path(patch.get_path())
transform = patch.get_transform()
coordinates, vertices = self.process_transform(
transform, ax, vertices, force_trans=force_trans
)
linestyle = utils.get_path_style(patch, fill=patch.get_fill())
self.renderer.draw_path(
data=vertices,
coordinates=coordinates,
pathcodes=pathcodes,
style=linestyle,
mplobj=patch,
)
def draw_collection(
self, ax, collection, force_pathtrans=None, force_offsettrans=None
):
"""Process a matplotlib collection and call renderer.draw_collection"""
(transform, transOffset, offsets, paths) = collection._prepare_points()
offset_coords, offsets = self.process_transform(
transOffset, ax, offsets, force_trans=force_offsettrans
)
path_coords = self.process_transform(transform, ax, force_trans=force_pathtrans)
processed_paths = [utils.SVG_path(path) for path in paths]
processed_paths = [
(
self.process_transform(
transform, ax, path[0], force_trans=force_pathtrans
)[1],
path[1],
)
for path in processed_paths
]
path_transforms = collection.get_transforms()
try:
# matplotlib 1.3: path_transforms are transform objects.
# Convert them to numpy arrays.
path_transforms = [t.get_matrix() for t in path_transforms]
except AttributeError:
# matplotlib 1.4: path transforms are already numpy arrays.
pass
styles = {
"linewidth": collection.get_linewidths(),
"facecolor": collection.get_facecolors(),
"edgecolor": collection.get_edgecolors(),
"alpha": collection._alpha,
"zorder": collection.get_zorder(),
}
# TODO: When matplotlib's minimum version is bumped to 3.8, this can be
# simplified since collection.get_offset_position no longer exists.
offset_dict = {"data": "before", "screen": "after"}
offset_order = (
offset_dict[collection.get_offset_position()]
if hasattr(collection, "get_offset_position")
else "after"
)
self.renderer.draw_path_collection(
paths=processed_paths,
path_coordinates=path_coords,
path_transforms=path_transforms,
offsets=offsets,
offset_coordinates=offset_coords,
offset_order=offset_order,
styles=styles,
mplobj=collection,
)
def draw_image(self, ax, image):
"""Process a matplotlib image object and call renderer.draw_image"""
self.renderer.draw_image(
imdata=utils.image_to_base64(image),
extent=image.get_extent(),
coordinates="data",
style={"alpha": image.get_alpha(), "zorder": image.get_zorder()},
mplobj=image,
)
|