aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/plotly/figure_factory
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/figure_factory')
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_2d_density.py155
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/__init__.py69
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_annotated_heatmap.py307
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_bullet.py366
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_candlestick.py277
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_county_choropleth.py1013
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_dendrogram.py395
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_distplot.py441
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py1195
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py1034
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py526
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py295
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py265
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_scatterplot.py1135
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_streamline.py406
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py280
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py692
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py509
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/_violin.py704
-rw-r--r--venv/lib/python3.8/site-packages/plotly/figure_factory/utils.py249
20 files changed, 10313 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_2d_density.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_2d_density.py
new file mode 100644
index 0000000..3094d0b
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_2d_density.py
@@ -0,0 +1,155 @@
+from numbers import Number
+
+import plotly.exceptions
+
+import plotly.colors as clrs
+from plotly.graph_objs import graph_objs
+
+
+def make_linear_colorscale(colors):
+ """
+ Makes a list of colors into a colorscale-acceptable form
+
+ For documentation regarding to the form of the output, see
+ https://plot.ly/python/reference/#mesh3d-colorscale
+ """
+ scale = 1.0 / (len(colors) - 1)
+ return [[i * scale, color] for i, color in enumerate(colors)]
+
+
+def create_2d_density(
+ x,
+ y,
+ colorscale="Earth",
+ ncontours=20,
+ hist_color=(0, 0, 0.5),
+ point_color=(0, 0, 0.5),
+ point_size=2,
+ title="2D Density Plot",
+ height=600,
+ width=600,
+):
+ """
+ **deprecated**, use instead
+ :func:`plotly.express.density_heatmap`.
+
+ :param (list|array) x: x-axis data for plot generation
+ :param (list|array) y: y-axis data for plot generation
+ :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list or tuple of colors. An rgb
+ color is of the form 'rgb(x, y, z)' where x, y, z belong to the
+ interval [0, 255] and a color tuple is a tuple of the form
+ (a, b, c) where a, b and c belong to [0, 1]. If colormap is a
+ list, it must contain the valid color types aforementioned as its
+ members.
+ :param (int) ncontours: the number of 2D contours to draw on the plot
+ :param (str) hist_color: the color of the plotted histograms
+ :param (str) point_color: the color of the scatter points
+ :param (str) point_size: the color of the scatter points
+ :param (str) title: set the title for the plot
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Examples
+ --------
+
+ Example 1: Simple 2D Density Plot
+
+ >>> from plotly.figure_factory import create_2d_density
+ >>> import numpy as np
+
+ >>> # Make data points
+ >>> t = np.linspace(-1,1.2,2000)
+ >>> x = (t**3)+(0.3*np.random.randn(2000))
+ >>> y = (t**6)+(0.3*np.random.randn(2000))
+
+ >>> # Create a figure
+ >>> fig = create_2d_density(x, y)
+
+ >>> # Plot the data
+ >>> fig.show()
+
+ Example 2: Using Parameters
+
+ >>> from plotly.figure_factory import create_2d_density
+
+ >>> import numpy as np
+
+ >>> # Make data points
+ >>> t = np.linspace(-1,1.2,2000)
+ >>> x = (t**3)+(0.3*np.random.randn(2000))
+ >>> y = (t**6)+(0.3*np.random.randn(2000))
+
+ >>> # Create custom colorscale
+ >>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
+ ... (1, 1, 0.2), (0.98,0.98,0.98)]
+
+ >>> # Create a figure
+ >>> fig = create_2d_density(x, y, colorscale=colorscale,
+ ... hist_color='rgb(255, 237, 222)', point_size=3)
+
+ >>> # Plot the data
+ >>> fig.show()
+ """
+
+ # validate x and y are filled with numbers only
+ for array in [x, y]:
+ if not all(isinstance(element, Number) for element in array):
+ raise plotly.exceptions.PlotlyError(
+ "All elements of your 'x' and 'y' lists must be numbers."
+ )
+
+ # validate x and y are the same length
+ if len(x) != len(y):
+ raise plotly.exceptions.PlotlyError(
+ "Both lists 'x' and 'y' must be the same length."
+ )
+
+ colorscale = clrs.validate_colors(colorscale, "rgb")
+ colorscale = make_linear_colorscale(colorscale)
+
+ # validate hist_color and point_color
+ hist_color = clrs.validate_colors(hist_color, "rgb")
+ point_color = clrs.validate_colors(point_color, "rgb")
+
+ trace1 = graph_objs.Scatter(
+ x=x,
+ y=y,
+ mode="markers",
+ name="points",
+ marker=dict(color=point_color[0], size=point_size, opacity=0.4),
+ )
+ trace2 = graph_objs.Histogram2dContour(
+ x=x,
+ y=y,
+ name="density",
+ ncontours=ncontours,
+ colorscale=colorscale,
+ reversescale=True,
+ showscale=False,
+ )
+ trace3 = graph_objs.Histogram(
+ x=x, name="x density", marker=dict(color=hist_color[0]), yaxis="y2"
+ )
+ trace4 = graph_objs.Histogram(
+ y=y, name="y density", marker=dict(color=hist_color[0]), xaxis="x2"
+ )
+ data = [trace1, trace2, trace3, trace4]
+
+ layout = graph_objs.Layout(
+ showlegend=False,
+ autosize=False,
+ title=title,
+ height=height,
+ width=width,
+ xaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
+ yaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
+ margin=dict(t=50),
+ hovermode="closest",
+ bargap=0,
+ xaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
+ yaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
+ )
+
+ fig = graph_objs.Figure(data=data, layout=layout)
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/__init__.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/__init__.py
new file mode 100644
index 0000000..1919ca8
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/__init__.py
@@ -0,0 +1,69 @@
+# ruff: noqa: E402
+
+from plotly import optional_imports
+
+# Require that numpy exists for figure_factory
+np = optional_imports.get_module("numpy")
+if np is None:
+ raise ImportError(
+ """\
+The figure factory module requires the numpy package"""
+ )
+
+
+from plotly.figure_factory._2d_density import create_2d_density
+from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
+from plotly.figure_factory._bullet import create_bullet
+from plotly.figure_factory._candlestick import create_candlestick
+from plotly.figure_factory._dendrogram import create_dendrogram
+from plotly.figure_factory._distplot import create_distplot
+from plotly.figure_factory._facet_grid import create_facet_grid
+from plotly.figure_factory._gantt import create_gantt
+from plotly.figure_factory._ohlc import create_ohlc
+from plotly.figure_factory._quiver import create_quiver
+from plotly.figure_factory._scatterplot import create_scatterplotmatrix
+from plotly.figure_factory._streamline import create_streamline
+from plotly.figure_factory._table import create_table
+from plotly.figure_factory._trisurf import create_trisurf
+from plotly.figure_factory._violin import create_violin
+
+if optional_imports.get_module("pandas") is not None:
+ from plotly.figure_factory._county_choropleth import create_choropleth
+ from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox
+else:
+
+ def create_choropleth(*args, **kwargs):
+ raise ImportError("Please install pandas to use `create_choropleth`")
+
+ def create_hexbin_mapbox(*args, **kwargs):
+ raise ImportError("Please install pandas to use `create_hexbin_mapbox`")
+
+
+if optional_imports.get_module("skimage") is not None:
+ from plotly.figure_factory._ternary_contour import create_ternary_contour
+else:
+
+ def create_ternary_contour(*args, **kwargs):
+ raise ImportError("Please install scikit-image to use `create_ternary_contour`")
+
+
+__all__ = [
+ "create_2d_density",
+ "create_annotated_heatmap",
+ "create_bullet",
+ "create_candlestick",
+ "create_choropleth",
+ "create_dendrogram",
+ "create_distplot",
+ "create_facet_grid",
+ "create_gantt",
+ "create_hexbin_mapbox",
+ "create_ohlc",
+ "create_quiver",
+ "create_scatterplotmatrix",
+ "create_streamline",
+ "create_table",
+ "create_ternary_contour",
+ "create_trisurf",
+ "create_violin",
+]
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_annotated_heatmap.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_annotated_heatmap.py
new file mode 100644
index 0000000..5da24ae
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_annotated_heatmap.py
@@ -0,0 +1,307 @@
+import plotly.colors as clrs
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.validator_cache import ValidatorCache
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module("numpy")
+
+
+def validate_annotated_heatmap(z, x, y, annotation_text):
+ """
+ Annotated-heatmap-specific validations
+
+ Check that if a text matrix is supplied, it has the same
+ dimensions as the z matrix.
+
+ See FigureFactory.create_annotated_heatmap() for params
+
+ :raises: (PlotlyError) If z and text matrices do not have the same
+ dimensions.
+ """
+ if annotation_text is not None and isinstance(annotation_text, list):
+ utils.validate_equal_length(z, annotation_text)
+ for lst in range(len(z)):
+ if len(z[lst]) != len(annotation_text[lst]):
+ raise exceptions.PlotlyError(
+ "z and text should have the same dimensions"
+ )
+
+ if x:
+ if len(x) != len(z[0]):
+ raise exceptions.PlotlyError(
+ "oops, the x list that you "
+ "provided does not match the "
+ "width of your z matrix "
+ )
+
+ if y:
+ if len(y) != len(z):
+ raise exceptions.PlotlyError(
+ "oops, the y list that you "
+ "provided does not match the "
+ "length of your z matrix "
+ )
+
+
+def create_annotated_heatmap(
+ z,
+ x=None,
+ y=None,
+ annotation_text=None,
+ colorscale="Plasma",
+ font_colors=None,
+ showscale=False,
+ reversescale=False,
+ **kwargs,
+):
+ """
+ **deprecated**, use instead
+ :func:`plotly.express.imshow`.
+
+ Function that creates annotated heatmaps
+
+ This function adds annotations to each cell of the heatmap.
+
+ :param (list[list]|ndarray) z: z matrix to create heatmap.
+ :param (list) x: x axis labels.
+ :param (list) y: y axis labels.
+ :param (list[list]|ndarray) annotation_text: Text strings for
+ annotations. Should have the same dimensions as the z matrix. If no
+ text is added, the values of the z matrix are annotated. Default =
+ z matrix values.
+ :param (list|str) colorscale: heatmap colorscale.
+ :param (list) font_colors: List of two color strings: [min_text_color,
+ max_text_color] where min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. If font_colors is not
+ defined, the colors are defined logically as black or white
+ depending on the heatmap's colorscale.
+ :param (bool) showscale: Display colorscale. Default = False
+ :param (bool) reversescale: Reverse colorscale. Default = False
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple annotated heatmap with default configuration
+
+ >>> import plotly.figure_factory as ff
+
+ >>> z = [[0.300000, 0.00000, 0.65, 0.300000],
+ ... [1, 0.100005, 0.45, 0.4300],
+ ... [0.300000, 0.00000, 0.65, 0.300000],
+ ... [1, 0.100005, 0.45, 0.00000]]
+
+ >>> fig = ff.create_annotated_heatmap(z)
+ >>> fig.show()
+ """
+
+ # Avoiding mutables in the call signature
+ font_colors = font_colors if font_colors is not None else []
+ validate_annotated_heatmap(z, x, y, annotation_text)
+
+ # validate colorscale
+ colorscale_validator = ValidatorCache.get_validator("heatmap", "colorscale")
+ colorscale = colorscale_validator.validate_coerce(colorscale)
+
+ annotations = _AnnotatedHeatmap(
+ z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
+ ).make_annotations()
+
+ if x or y:
+ trace = dict(
+ type="heatmap",
+ z=z,
+ x=x,
+ y=y,
+ colorscale=colorscale,
+ showscale=showscale,
+ reversescale=reversescale,
+ **kwargs,
+ )
+ layout = dict(
+ annotations=annotations,
+ xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
+ yaxis=dict(ticks="", dtick=1, ticksuffix=" "),
+ )
+ else:
+ trace = dict(
+ type="heatmap",
+ z=z,
+ colorscale=colorscale,
+ showscale=showscale,
+ reversescale=reversescale,
+ **kwargs,
+ )
+ layout = dict(
+ annotations=annotations,
+ xaxis=dict(
+ ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
+ ),
+ yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False),
+ )
+
+ data = [trace]
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+def to_rgb_color_list(color_str, default):
+ color_str = color_str.strip()
+ if color_str.startswith("rgb"):
+ return [int(v) for v in color_str.strip("rgba()").split(",")]
+ elif color_str.startswith("#"):
+ return clrs.hex_to_rgb(color_str)
+ else:
+ return default
+
+
+def should_use_black_text(background_color):
+ return (
+ background_color[0] * 0.299
+ + background_color[1] * 0.587
+ + background_color[2] * 0.114
+ ) > 186
+
+
+class _AnnotatedHeatmap(object):
+ """
+ Refer to TraceFactory.create_annotated_heatmap() for docstring
+ """
+
+ def __init__(
+ self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
+ ):
+ self.z = z
+ if x:
+ self.x = x
+ else:
+ self.x = range(len(z[0]))
+ if y:
+ self.y = y
+ else:
+ self.y = range(len(z))
+ if annotation_text is not None:
+ self.annotation_text = annotation_text
+ else:
+ self.annotation_text = self.z
+ self.colorscale = colorscale
+ self.reversescale = reversescale
+ self.font_colors = font_colors
+
+ if np and isinstance(self.z, np.ndarray):
+ self.zmin = np.amin(self.z)
+ self.zmax = np.amax(self.z)
+ else:
+ self.zmin = min([v for row in self.z for v in row])
+ self.zmax = max([v for row in self.z for v in row])
+
+ if kwargs.get("zmin", None) is not None:
+ self.zmin = kwargs["zmin"]
+ if kwargs.get("zmax", None) is not None:
+ self.zmax = kwargs["zmax"]
+
+ self.zmid = (self.zmax + self.zmin) / 2
+
+ if kwargs.get("zmid", None) is not None:
+ self.zmid = kwargs["zmid"]
+
+ def get_text_color(self):
+ """
+ Get font color for annotations.
+
+ The annotated heatmap can feature two text colors: min_text_color and
+ max_text_color. The min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. The user can define these
+ two colors. Otherwise the colors are defined logically as black or
+ white depending on the heatmap's colorscale.
+
+ :rtype (string, string) min_text_color, max_text_color: text
+ color for annotations for heatmap values <
+ (max_value - min_value)/2 and text color for annotations for
+ heatmap values >= (max_value - min_value)/2
+ """
+ # Plotly colorscales ranging from a lighter shade to a darker shade
+ colorscales = [
+ "Greys",
+ "Greens",
+ "Blues",
+ "YIGnBu",
+ "YIOrRd",
+ "RdBu",
+ "Picnic",
+ "Jet",
+ "Hot",
+ "Blackbody",
+ "Earth",
+ "Electric",
+ "Viridis",
+ "Cividis",
+ ]
+ # Plotly colorscales ranging from a darker shade to a lighter shade
+ colorscales_reverse = ["Reds"]
+
+ white = "#FFFFFF"
+ black = "#000000"
+ if self.font_colors:
+ min_text_color = self.font_colors[0]
+ max_text_color = self.font_colors[-1]
+ elif self.colorscale in colorscales and self.reversescale:
+ min_text_color = black
+ max_text_color = white
+ elif self.colorscale in colorscales:
+ min_text_color = white
+ max_text_color = black
+ elif self.colorscale in colorscales_reverse and self.reversescale:
+ min_text_color = white
+ max_text_color = black
+ elif self.colorscale in colorscales_reverse:
+ min_text_color = black
+ max_text_color = white
+ elif isinstance(self.colorscale, list):
+ min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
+ max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])
+
+ # swap min/max colors if reverse scale
+ if self.reversescale:
+ min_col, max_col = max_col, min_col
+
+ if should_use_black_text(min_col):
+ min_text_color = black
+ else:
+ min_text_color = white
+
+ if should_use_black_text(max_col):
+ max_text_color = black
+ else:
+ max_text_color = white
+ else:
+ min_text_color = black
+ max_text_color = black
+ return min_text_color, max_text_color
+
+ def make_annotations(self):
+ """
+ Get annotations for each cell of the heatmap with graph_objs.Annotation
+
+ :rtype (list[dict]) annotations: list of annotations for each cell of
+ the heatmap
+ """
+ min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
+ annotations = []
+ for n, row in enumerate(self.z):
+ for m, val in enumerate(row):
+ font_color = min_text_color if val < self.zmid else max_text_color
+ annotations.append(
+ graph_objs.layout.Annotation(
+ text=str(self.annotation_text[n][m]),
+ x=self.x[m],
+ y=self.y[n],
+ xref="x1",
+ yref="y1",
+ font=dict(color=font_color),
+ showarrow=False,
+ )
+ )
+ return annotations
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_bullet.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_bullet.py
new file mode 100644
index 0000000..ce51e93
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_bullet.py
@@ -0,0 +1,366 @@
+import math
+
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.figure_factory import utils
+
+import plotly
+import plotly.graph_objs as go
+
+pd = optional_imports.get_module("pandas")
+
+
+def _bullet(
+ df,
+ markers,
+ measures,
+ ranges,
+ subtitles,
+ titles,
+ orientation,
+ range_colors,
+ measure_colors,
+ horizontal_spacing,
+ vertical_spacing,
+ scatter_options,
+ layout_options,
+):
+ num_of_lanes = len(df)
+ num_of_rows = num_of_lanes if orientation == "h" else 1
+ num_of_cols = 1 if orientation == "h" else num_of_lanes
+ if not horizontal_spacing:
+ horizontal_spacing = 1.0 / num_of_lanes
+ if not vertical_spacing:
+ vertical_spacing = 1.0 / num_of_lanes
+ fig = plotly.subplots.make_subplots(
+ num_of_rows,
+ num_of_cols,
+ print_grid=False,
+ horizontal_spacing=horizontal_spacing,
+ vertical_spacing=vertical_spacing,
+ )
+
+ # layout
+ fig["layout"].update(
+ dict(shapes=[]),
+ title="Bullet Chart",
+ height=600,
+ width=1000,
+ showlegend=False,
+ barmode="stack",
+ annotations=[],
+ margin=dict(l=120 if orientation == "h" else 80),
+ )
+
+ # update layout
+ fig["layout"].update(layout_options)
+
+ if orientation == "h":
+ width_axis = "yaxis"
+ length_axis = "xaxis"
+ else:
+ width_axis = "xaxis"
+ length_axis = "yaxis"
+
+ for key in fig["layout"]:
+ if "xaxis" in key or "yaxis" in key:
+ fig["layout"][key]["showgrid"] = False
+ fig["layout"][key]["zeroline"] = False
+ if length_axis in key:
+ fig["layout"][key]["tickwidth"] = 1
+ if width_axis in key:
+ fig["layout"][key]["showticklabels"] = False
+ fig["layout"][key]["range"] = [0, 1]
+
+ # narrow domain if 1 bar
+ if num_of_lanes <= 1:
+ fig["layout"][width_axis + "1"]["domain"] = [0.4, 0.6]
+
+ if not range_colors:
+ range_colors = ["rgb(200, 200, 200)", "rgb(245, 245, 245)"]
+ if not measure_colors:
+ measure_colors = ["rgb(31, 119, 180)", "rgb(176, 196, 221)"]
+
+ for row in range(num_of_lanes):
+ # ranges bars
+ for idx in range(len(df.iloc[row]["ranges"])):
+ inter_colors = clrs.n_colors(
+ range_colors[0], range_colors[1], len(df.iloc[row]["ranges"]), "rgb"
+ )
+ x = (
+ [sorted(df.iloc[row]["ranges"])[-1 - idx]]
+ if orientation == "h"
+ else [0]
+ )
+ y = (
+ [0]
+ if orientation == "h"
+ else [sorted(df.iloc[row]["ranges"])[-1 - idx]]
+ )
+ bar = go.Bar(
+ x=x,
+ y=y,
+ marker=dict(color=inter_colors[-1 - idx]),
+ name="ranges",
+ hoverinfo="x" if orientation == "h" else "y",
+ orientation=orientation,
+ width=2,
+ base=0,
+ xaxis="x{}".format(row + 1),
+ yaxis="y{}".format(row + 1),
+ )
+ fig.add_trace(bar)
+
+ # measures bars
+ for idx in range(len(df.iloc[row]["measures"])):
+ inter_colors = clrs.n_colors(
+ measure_colors[0],
+ measure_colors[1],
+ len(df.iloc[row]["measures"]),
+ "rgb",
+ )
+ x = (
+ [sorted(df.iloc[row]["measures"])[-1 - idx]]
+ if orientation == "h"
+ else [0.5]
+ )
+ y = (
+ [0.5]
+ if orientation == "h"
+ else [sorted(df.iloc[row]["measures"])[-1 - idx]]
+ )
+ bar = go.Bar(
+ x=x,
+ y=y,
+ marker=dict(color=inter_colors[-1 - idx]),
+ name="measures",
+ hoverinfo="x" if orientation == "h" else "y",
+ orientation=orientation,
+ width=0.4,
+ base=0,
+ xaxis="x{}".format(row + 1),
+ yaxis="y{}".format(row + 1),
+ )
+ fig.add_trace(bar)
+
+ # markers
+ x = df.iloc[row]["markers"] if orientation == "h" else [0.5]
+ y = [0.5] if orientation == "h" else df.iloc[row]["markers"]
+ markers = go.Scatter(
+ x=x,
+ y=y,
+ name="markers",
+ hoverinfo="x" if orientation == "h" else "y",
+ xaxis="x{}".format(row + 1),
+ yaxis="y{}".format(row + 1),
+ **scatter_options,
+ )
+
+ fig.add_trace(markers)
+
+ # titles and subtitles
+ title = df.iloc[row]["titles"]
+ if "subtitles" in df:
+ subtitle = "<br>{}".format(df.iloc[row]["subtitles"])
+ else:
+ subtitle = ""
+ label = "<b>{}</b>".format(title) + subtitle
+ annot = utils.annotation_dict_for_label(
+ label,
+ (num_of_lanes - row if orientation == "h" else row + 1),
+ num_of_lanes,
+ vertical_spacing if orientation == "h" else horizontal_spacing,
+ "row" if orientation == "h" else "col",
+ True if orientation == "h" else False,
+ False,
+ )
+ fig["layout"]["annotations"] += (annot,)
+
+ return fig
+
+
+def create_bullet(
+ data,
+ markers=None,
+ measures=None,
+ ranges=None,
+ subtitles=None,
+ titles=None,
+ orientation="h",
+ range_colors=("rgb(200, 200, 200)", "rgb(245, 245, 245)"),
+ measure_colors=("rgb(31, 119, 180)", "rgb(176, 196, 221)"),
+ horizontal_spacing=None,
+ vertical_spacing=None,
+ scatter_options={},
+ **layout_options,
+):
+ """
+ **deprecated**, use instead the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Indicator`.
+
+ :param (pd.DataFrame | list | tuple) data: either a list/tuple of
+ dictionaries or a pandas DataFrame.
+ :param (str) markers: the column name or dictionary key for the markers in
+ each subplot.
+ :param (str) measures: the column name or dictionary key for the measure
+ bars in each subplot. This bar usually represents the quantitative
+ measure of performance, usually a list of two values [a, b] and are
+ the blue bars in the foreground of each subplot by default.
+ :param (str) ranges: the column name or dictionary key for the qualitative
+ ranges of performance, usually a 3-item list [bad, okay, good]. They
+ correspond to the grey bars in the background of each chart.
+ :param (str) subtitles: the column name or dictionary key for the subtitle
+ of each subplot chart. The subplots are displayed right underneath
+ each title.
+ :param (str) titles: the column name or dictionary key for the main label
+ of each subplot chart.
+ :param (bool) orientation: if 'h', the bars are placed horizontally as
+ rows. If 'v' the bars are placed vertically in the chart.
+ :param (list) range_colors: a tuple of two colors between which all
+ the rectangles for the range are drawn. These rectangles are meant to
+ be qualitative indicators against which the marker and measure bars
+ are compared.
+ Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
+ :param (list) measure_colors: a tuple of two colors which is used to color
+ the thin quantitative bars in the bullet chart.
+ Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
+ :param (float) horizontal_spacing: see the 'horizontal_spacing' param in
+ plotly.tools.make_subplots. Ranges between 0 and 1.
+ :param (float) vertical_spacing: see the 'vertical_spacing' param in
+ plotly.tools.make_subplots. Ranges between 0 and 1.
+ :param (dict) scatter_options: describes attributes for the scatter trace
+ in each subplot such as name and marker size. Call
+ help(plotly.graph_objs.Scatter) for more information on valid params.
+ :param layout_options: describes attributes for the layout of the figure
+ such as title, height and width. Call help(plotly.graph_objs.Layout)
+ for more information on valid params.
+
+ Example 1: Use a Dictionary
+
+ >>> import plotly.figure_factory as ff
+
+ >>> data = [
+ ... {"label": "revenue", "sublabel": "us$, in thousands",
+ ... "range": [150, 225, 300], "performance": [220,270], "point": [250]},
+ ... {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
+ ... "performance": [21, 23], "point": [26]},
+ ... {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
+ ... "performance": [100,320],"point": [550]},
+ ... {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
+ ... "performance": [1000, 1650],"point": [2100]},
+ ... {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
+ ... "performance": [3.2, 4.7], "point": [4.4]}
+ ... ]
+
+ >>> fig = ff.create_bullet(
+ ... data, titles='label', subtitles='sublabel', markers='point',
+ ... measures='performance', ranges='range', orientation='h',
+ ... title='my simple bullet chart'
+ ... )
+ >>> fig.show()
+
+ Example 2: Use a DataFrame with Custom Colors
+
+ >>> import plotly.figure_factory as ff
+ >>> import pandas as pd
+ >>> data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')
+
+ >>> fig = ff.create_bullet(
+ ... data, titles='title', markers='markers', measures='measures',
+ ... orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
+ ... scatter_options={'marker': {'symbol': 'circle'}}, width=700)
+ >>> fig.show()
+ """
+ # validate df
+ if not pd:
+ raise ImportError("'pandas' must be installed for this figure factory.")
+
+ if utils.is_sequence(data):
+ if not all(isinstance(item, dict) for item in data):
+ raise exceptions.PlotlyError(
+ "Every entry of the data argument list, tuple, etc must "
+ "be a dictionary."
+ )
+
+ elif not isinstance(data, pd.DataFrame):
+ raise exceptions.PlotlyError(
+ "You must input a pandas DataFrame, or a list of dictionaries."
+ )
+
+ # make DataFrame from data with correct column headers
+ col_names = ["titles", "subtitle", "markers", "measures", "ranges"]
+ if utils.is_sequence(data):
+ df = pd.DataFrame(
+ [
+ [d[titles] for d in data] if titles else [""] * len(data),
+ [d[subtitles] for d in data] if subtitles else [""] * len(data),
+ [d[markers] for d in data] if markers else [[]] * len(data),
+ [d[measures] for d in data] if measures else [[]] * len(data),
+ [d[ranges] for d in data] if ranges else [[]] * len(data),
+ ],
+ index=col_names,
+ )
+ elif isinstance(data, pd.DataFrame):
+ df = pd.DataFrame(
+ [
+ data[titles].tolist() if titles else [""] * len(data),
+ data[subtitles].tolist() if subtitles else [""] * len(data),
+ data[markers].tolist() if markers else [[]] * len(data),
+ data[measures].tolist() if measures else [[]] * len(data),
+ data[ranges].tolist() if ranges else [[]] * len(data),
+ ],
+ index=col_names,
+ )
+ df = pd.DataFrame.transpose(df)
+
+ # make sure ranges, measures, 'markers' are not NAN or NONE
+ for needed_key in ["ranges", "measures", "markers"]:
+ for idx, r in enumerate(df[needed_key]):
+ try:
+ r_is_nan = math.isnan(r)
+ if r_is_nan or r is None:
+ df[needed_key][idx] = []
+ except TypeError:
+ pass
+
+ # validate custom colors
+ for colors_list in [range_colors, measure_colors]:
+ if colors_list:
+ if len(colors_list) != 2:
+ raise exceptions.PlotlyError(
+ "Both 'range_colors' or 'measure_colors' must be a list "
+ "of two valid colors."
+ )
+ clrs.validate_colors(colors_list)
+ colors_list = clrs.convert_colors_to_same_type(colors_list, "rgb")[0]
+
+ # default scatter options
+ default_scatter = {
+ "marker": {"size": 12, "symbol": "diamond-tall", "color": "rgb(0, 0, 0)"}
+ }
+
+ if scatter_options == {}:
+ scatter_options.update(default_scatter)
+ else:
+ # add default options to scatter_options if they are not present
+ for k in default_scatter["marker"]:
+ if k not in scatter_options["marker"]:
+ scatter_options["marker"][k] = default_scatter["marker"][k]
+
+ fig = _bullet(
+ df,
+ markers,
+ measures,
+ ranges,
+ subtitles,
+ titles,
+ orientation,
+ range_colors,
+ measure_colors,
+ horizontal_spacing,
+ vertical_spacing,
+ scatter_options,
+ layout_options,
+ )
+
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_candlestick.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_candlestick.py
new file mode 100644
index 0000000..572ccfe
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_candlestick.py
@@ -0,0 +1,277 @@
+from plotly.figure_factory import utils
+from plotly.figure_factory._ohlc import (
+ _DEFAULT_INCREASING_COLOR,
+ _DEFAULT_DECREASING_COLOR,
+ validate_ohlc,
+)
+from plotly.graph_objs import graph_objs
+
+
+def make_increasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for increasing candlesticks
+
+ _make_increasing_candle() and _make_decreasing_candle separate the
+ increasing traces from the decreasing traces so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_incr_data: list of the box trace for
+ increasing candlesticks.
+ """
+ increase_x, increase_y = _Candlestick(
+ open, high, low, close, dates, **kwargs
+ ).get_candle_increase()
+
+ if "line" in kwargs:
+ kwargs.setdefault("fillcolor", kwargs["line"]["color"])
+ else:
+ kwargs.setdefault("fillcolor", _DEFAULT_INCREASING_COLOR)
+ if "name" in kwargs:
+ kwargs.setdefault("showlegend", True)
+ else:
+ kwargs.setdefault("showlegend", False)
+ kwargs.setdefault("name", "Increasing")
+ kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR))
+
+ candle_incr_data = dict(
+ type="box",
+ x=increase_x,
+ y=increase_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs,
+ )
+
+ return [candle_incr_data]
+
+
+def make_decreasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for decreasing candlesticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to decreasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_decr_data: list of the box trace for
+ decreasing candlesticks.
+ """
+
+ decrease_x, decrease_y = _Candlestick(
+ open, high, low, close, dates, **kwargs
+ ).get_candle_decrease()
+
+ if "line" in kwargs:
+ kwargs.setdefault("fillcolor", kwargs["line"]["color"])
+ else:
+ kwargs.setdefault("fillcolor", _DEFAULT_DECREASING_COLOR)
+ kwargs.setdefault("showlegend", False)
+ kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR))
+ kwargs.setdefault("name", "Decreasing")
+
+ candle_decr_data = dict(
+ type="box",
+ x=decrease_x,
+ y=decrease_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs,
+ )
+
+ return [candle_decr_data]
+
+
+def create_candlestick(open, high, low, close, dates=None, direction="both", **kwargs):
+ """
+ **deprecated**, use instead the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Candlestick`
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all candlesticks where the close value is greater than
+ the corresponding open value, and when the direction is
+ 'decreasing', the returned figure consists of all candlesticks
+ where the close value is less than or equal to the corresponding
+ open value. When the direction is 'both', both increasing and
+ decreasing candlesticks are returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of candlestick chart figure.
+
+ Example 1: Simple candlestick chart from a Pandas DataFrame
+
+ >>> from plotly.figure_factory import create_candlestick
+ >>> from datetime import datetime
+ >>> import pandas as pd
+
+ >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
+ >>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
+ ... dates=df.index)
+ >>> fig.show()
+
+ Example 2: Customize the candlestick colors
+
+ >>> from plotly.figure_factory import create_candlestick
+ >>> from plotly.graph_objs import Line, Marker
+ >>> from datetime import datetime
+
+ >>> import pandas as pd
+ >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
+
+ >>> # Make increasing candlesticks and customize their color and name
+ >>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
+ ... dates=df.index,
+ ... direction='increasing', name='AAPL',
+ ... marker=Marker(color='rgb(150, 200, 250)'),
+ ... line=Line(color='rgb(150, 200, 250)'))
+
+ >>> # Make decreasing candlesticks and customize their color and name
+ >>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
+ ... dates=df.index,
+ ... direction='decreasing',
+ ... marker=Marker(color='rgb(128, 128, 128)'),
+ ... line=Line(color='rgb(128, 128, 128)'))
+
+ >>> # Initialize the figure
+ >>> fig = fig_increasing
+
+ >>> # Add decreasing data with .extend()
+ >>> fig.add_trace(fig_decreasing['data']) # doctest: +SKIP
+ >>> fig.show()
+
+ Example 3: Candlestick chart with datetime objects
+
+ >>> from plotly.figure_factory import create_candlestick
+
+ >>> from datetime import datetime
+
+ >>> # Add data
+ >>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
+ >>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
+ >>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
+ >>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
+ >>> dates = [datetime(year=2013, month=10, day=10),
+ ... datetime(year=2013, month=11, day=10),
+ ... datetime(year=2013, month=12, day=10),
+ ... datetime(year=2014, month=1, day=10),
+ ... datetime(year=2014, month=2, day=10)]
+
+ >>> # Create ohlc
+ >>> fig = create_candlestick(open_data, high_data,
+ ... low_data, close_data, dates=dates)
+ >>> fig.show()
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction == "increasing":
+ candle_incr_data = make_increasing_candle(
+ open, high, low, close, dates, **kwargs
+ )
+ data = candle_incr_data
+ elif direction == "decreasing":
+ candle_decr_data = make_decreasing_candle(
+ open, high, low, close, dates, **kwargs
+ )
+ data = candle_decr_data
+ else:
+ candle_incr_data = make_increasing_candle(
+ open, high, low, close, dates, **kwargs
+ )
+ candle_decr_data = make_decreasing_candle(
+ open, high, low, close, dates, **kwargs
+ )
+ data = candle_incr_data + candle_decr_data
+
+ layout = graph_objs.Layout()
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Candlestick(object):
+ """
+ Refer to FigureFactory.create_candlestick() for docstring.
+ """
+
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ if dates is not None:
+ self.x = dates
+ else:
+ self.x = [x for x in range(len(self.open))]
+ self.get_candle_increase()
+
+ def get_candle_increase(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ increase_y = []
+ increase_x = []
+ for index in range(len(self.open)):
+ if self.close[index] > self.open[index]:
+ increase_y.append(self.low[index])
+ increase_y.append(self.open[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.high[index])
+ increase_x.append(self.x[index])
+
+ increase_x = [[x, x, x, x, x, x] for x in increase_x]
+ increase_x = utils.flatten(increase_x)
+
+ return increase_x, increase_y
+
+ def get_candle_decrease(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ decrease_y = []
+ decrease_x = []
+ for index in range(len(self.open)):
+ if self.close[index] <= self.open[index]:
+ decrease_y.append(self.low[index])
+ decrease_y.append(self.open[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.high[index])
+ decrease_x.append(self.x[index])
+
+ decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
+ decrease_x = utils.flatten(decrease_x)
+
+ return decrease_x, decrease_y
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_county_choropleth.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_county_choropleth.py
new file mode 100644
index 0000000..7b397e7
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_county_choropleth.py
@@ -0,0 +1,1013 @@
+import io
+import numpy as np
+import os
+import pandas as pd
+import warnings
+
+from math import log, floor
+from numbers import Number
+
+from plotly import optional_imports
+import plotly.colors as clrs
+from plotly.figure_factory import utils
+from plotly.exceptions import PlotlyError
+import plotly.graph_objs as go
+
+pd.options.mode.chained_assignment = None
+
+shapely = optional_imports.get_module("shapely")
+shapefile = optional_imports.get_module("shapefile")
+gp = optional_imports.get_module("geopandas")
+_plotly_geo = optional_imports.get_module("_plotly_geo")
+
+
+def _create_us_counties_df(st_to_state_name_dict, state_to_st_dict):
+ # URLS
+ abs_dir_path = os.path.realpath(_plotly_geo.__file__)
+
+ abs_plotly_geo_path = os.path.dirname(abs_dir_path)
+
+ abs_package_data_dir_path = os.path.join(abs_plotly_geo_path, "package_data")
+
+ shape_pre2010 = "gz_2010_us_050_00_500k.shp"
+ shape_pre2010 = os.path.join(abs_package_data_dir_path, shape_pre2010)
+
+ df_shape_pre2010 = gp.read_file(shape_pre2010)
+ df_shape_pre2010["FIPS"] = df_shape_pre2010["STATE"] + df_shape_pre2010["COUNTY"]
+ df_shape_pre2010["FIPS"] = pd.to_numeric(df_shape_pre2010["FIPS"])
+
+ states_path = "cb_2016_us_state_500k.shp"
+ states_path = os.path.join(abs_package_data_dir_path, states_path)
+
+ df_state = gp.read_file(states_path)
+ df_state = df_state[["STATEFP", "NAME", "geometry"]]
+ df_state = df_state.rename(columns={"NAME": "STATE_NAME"})
+
+ filenames = [
+ "cb_2016_us_county_500k.dbf",
+ "cb_2016_us_county_500k.shp",
+ "cb_2016_us_county_500k.shx",
+ ]
+
+ for j in range(len(filenames)):
+ filenames[j] = os.path.join(abs_package_data_dir_path, filenames[j])
+
+ dbf = io.open(filenames[0], "rb")
+ shp = io.open(filenames[1], "rb")
+ shx = io.open(filenames[2], "rb")
+
+ r = shapefile.Reader(shp=shp, shx=shx, dbf=dbf)
+
+ attributes, geometry = [], []
+ field_names = [field[0] for field in r.fields[1:]]
+ for row in r.shapeRecords():
+ geometry.append(shapely.geometry.shape(row.shape.__geo_interface__))
+ attributes.append(dict(zip(field_names, row.record)))
+
+ gdf = gp.GeoDataFrame(data=attributes, geometry=geometry)
+
+ gdf["FIPS"] = gdf["STATEFP"] + gdf["COUNTYFP"]
+ gdf["FIPS"] = pd.to_numeric(gdf["FIPS"])
+
+ # add missing counties
+ f = 46113
+ singlerow = pd.DataFrame(
+ [
+ [
+ st_to_state_name_dict["SD"],
+ "SD",
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0],
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0],
+ "46",
+ "Shannon",
+ ]
+ ],
+ columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"],
+ index=[max(gdf.index) + 1],
+ )
+ gdf = pd.concat([gdf, singlerow], sort=True)
+
+ f = 51515
+ singlerow = pd.DataFrame(
+ [
+ [
+ st_to_state_name_dict["VA"],
+ "VA",
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0],
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0],
+ "51",
+ "Bedford City",
+ ]
+ ],
+ columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"],
+ index=[max(gdf.index) + 1],
+ )
+ gdf = pd.concat([gdf, singlerow], sort=True)
+
+ f = 2270
+ singlerow = pd.DataFrame(
+ [
+ [
+ st_to_state_name_dict["AK"],
+ "AK",
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0],
+ df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0],
+ "02",
+ "Wade Hampton",
+ ]
+ ],
+ columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"],
+ index=[max(gdf.index) + 1],
+ )
+ gdf = pd.concat([gdf, singlerow], sort=True)
+
+ row_2198 = gdf[gdf["FIPS"] == 2198]
+ row_2198.index = [max(gdf.index) + 1]
+ row_2198.loc[row_2198.index[0], "FIPS"] = 2201
+ row_2198.loc[row_2198.index[0], "STATEFP"] = "02"
+ gdf = pd.concat([gdf, row_2198], sort=True)
+
+ row_2105 = gdf[gdf["FIPS"] == 2105]
+ row_2105.index = [max(gdf.index) + 1]
+ row_2105.loc[row_2105.index[0], "FIPS"] = 2232
+ row_2105.loc[row_2105.index[0], "STATEFP"] = "02"
+ gdf = pd.concat([gdf, row_2105], sort=True)
+ gdf = gdf.rename(columns={"NAME": "COUNTY_NAME"})
+
+ gdf_reduced = gdf[["FIPS", "STATEFP", "COUNTY_NAME", "geometry"]]
+ gdf_statefp = gdf_reduced.merge(df_state[["STATEFP", "STATE_NAME"]], on="STATEFP")
+
+ ST = []
+ for n in gdf_statefp["STATE_NAME"]:
+ ST.append(state_to_st_dict[n])
+
+ gdf_statefp["ST"] = ST
+ return gdf_statefp, df_state
+
+
+st_to_state_name_dict = {
+ "AK": "Alaska",
+ "AL": "Alabama",
+ "AR": "Arkansas",
+ "AZ": "Arizona",
+ "CA": "California",
+ "CO": "Colorado",
+ "CT": "Connecticut",
+ "DC": "District of Columbia",
+ "DE": "Delaware",
+ "FL": "Florida",
+ "GA": "Georgia",
+ "HI": "Hawaii",
+ "IA": "Iowa",
+ "ID": "Idaho",
+ "IL": "Illinois",
+ "IN": "Indiana",
+ "KS": "Kansas",
+ "KY": "Kentucky",
+ "LA": "Louisiana",
+ "MA": "Massachusetts",
+ "MD": "Maryland",
+ "ME": "Maine",
+ "MI": "Michigan",
+ "MN": "Minnesota",
+ "MO": "Missouri",
+ "MS": "Mississippi",
+ "MT": "Montana",
+ "NC": "North Carolina",
+ "ND": "North Dakota",
+ "NE": "Nebraska",
+ "NH": "New Hampshire",
+ "NJ": "New Jersey",
+ "NM": "New Mexico",
+ "NV": "Nevada",
+ "NY": "New York",
+ "OH": "Ohio",
+ "OK": "Oklahoma",
+ "OR": "Oregon",
+ "PA": "Pennsylvania",
+ "RI": "Rhode Island",
+ "SC": "South Carolina",
+ "SD": "South Dakota",
+ "TN": "Tennessee",
+ "TX": "Texas",
+ "UT": "Utah",
+ "VA": "Virginia",
+ "VT": "Vermont",
+ "WA": "Washington",
+ "WI": "Wisconsin",
+ "WV": "West Virginia",
+ "WY": "Wyoming",
+}
+
+state_to_st_dict = {
+ "Alabama": "AL",
+ "Alaska": "AK",
+ "American Samoa": "AS",
+ "Arizona": "AZ",
+ "Arkansas": "AR",
+ "California": "CA",
+ "Colorado": "CO",
+ "Commonwealth of the Northern Mariana Islands": "MP",
+ "Connecticut": "CT",
+ "Delaware": "DE",
+ "District of Columbia": "DC",
+ "Florida": "FL",
+ "Georgia": "GA",
+ "Guam": "GU",
+ "Hawaii": "HI",
+ "Idaho": "ID",
+ "Illinois": "IL",
+ "Indiana": "IN",
+ "Iowa": "IA",
+ "Kansas": "KS",
+ "Kentucky": "KY",
+ "Louisiana": "LA",
+ "Maine": "ME",
+ "Maryland": "MD",
+ "Massachusetts": "MA",
+ "Michigan": "MI",
+ "Minnesota": "MN",
+ "Mississippi": "MS",
+ "Missouri": "MO",
+ "Montana": "MT",
+ "Nebraska": "NE",
+ "Nevada": "NV",
+ "New Hampshire": "NH",
+ "New Jersey": "NJ",
+ "New Mexico": "NM",
+ "New York": "NY",
+ "North Carolina": "NC",
+ "North Dakota": "ND",
+ "Ohio": "OH",
+ "Oklahoma": "OK",
+ "Oregon": "OR",
+ "Pennsylvania": "PA",
+ "Puerto Rico": "",
+ "Rhode Island": "RI",
+ "South Carolina": "SC",
+ "South Dakota": "SD",
+ "Tennessee": "TN",
+ "Texas": "TX",
+ "United States Virgin Islands": "VI",
+ "Utah": "UT",
+ "Vermont": "VT",
+ "Virginia": "VA",
+ "Washington": "WA",
+ "West Virginia": "WV",
+ "Wisconsin": "WI",
+ "Wyoming": "WY",
+}
+
+USA_XRANGE = [-125.0, -65.0]
+USA_YRANGE = [25.0, 49.0]
+
+
+def _human_format(number):
+ units = ["", "K", "M", "G", "T", "P"]
+ k = 1000.0
+ magnitude = int(floor(log(number, k)))
+ return "%.2f%s" % (number / k**magnitude, units[magnitude])
+
+
+def _intervals_as_labels(array_of_intervals, round_legend_values, exponent_format):
+ """
+ Transform an number interval to a clean string for legend
+
+ Example: [-inf, 30] to '< 30'
+ """
+ infs = [float("-inf"), float("inf")]
+ string_intervals = []
+ for interval in array_of_intervals:
+ # round to 2nd decimal place
+ if round_legend_values:
+ rnd_interval = [
+ (int(interval[i]) if interval[i] not in infs else interval[i])
+ for i in range(2)
+ ]
+ else:
+ rnd_interval = [round(interval[0], 2), round(interval[1], 2)]
+
+ num0 = rnd_interval[0]
+ num1 = rnd_interval[1]
+ if exponent_format:
+ if num0 not in infs:
+ num0 = _human_format(num0)
+ if num1 not in infs:
+ num1 = _human_format(num1)
+ else:
+ if num0 not in infs:
+ num0 = "{:,}".format(num0)
+ if num1 not in infs:
+ num1 = "{:,}".format(num1)
+
+ if num0 == float("-inf"):
+ as_str = "< {}".format(num1)
+ elif num1 == float("inf"):
+ as_str = "> {}".format(num0)
+ else:
+ as_str = "{} - {}".format(num0, num1)
+ string_intervals.append(as_str)
+ return string_intervals
+
+
+def _calculations(
+ df,
+ fips,
+ values,
+ index,
+ f,
+ simplify_county,
+ level,
+ x_centroids,
+ y_centroids,
+ centroid_text,
+ x_traces,
+ y_traces,
+ fips_polygon_map,
+):
+ # 0-pad FIPS code to ensure exactly 5 digits
+ padded_f = str(f).zfill(5)
+ if fips_polygon_map[f].type == "Polygon":
+ x = fips_polygon_map[f].simplify(simplify_county).exterior.xy[0].tolist()
+ y = fips_polygon_map[f].simplify(simplify_county).exterior.xy[1].tolist()
+
+ x_c, y_c = fips_polygon_map[f].centroid.xy
+ county_name_str = str(df[df["FIPS"] == f]["COUNTY_NAME"].iloc[0])
+ state_name_str = str(df[df["FIPS"] == f]["STATE_NAME"].iloc[0])
+
+ t_c = (
+ "County: "
+ + county_name_str
+ + "<br>"
+ + "State: "
+ + state_name_str
+ + "<br>"
+ + "FIPS: "
+ + padded_f
+ + "<br>Value: "
+ + str(values[index])
+ )
+
+ x_centroids.append(x_c[0])
+ y_centroids.append(y_c[0])
+ centroid_text.append(t_c)
+
+ x_traces[level] = x_traces[level] + x + [np.nan]
+ y_traces[level] = y_traces[level] + y + [np.nan]
+ elif fips_polygon_map[f].type == "MultiPolygon":
+ x = [
+ poly.simplify(simplify_county).exterior.xy[0].tolist()
+ for poly in fips_polygon_map[f].geoms
+ ]
+ y = [
+ poly.simplify(simplify_county).exterior.xy[1].tolist()
+ for poly in fips_polygon_map[f].geoms
+ ]
+
+ x_c = [poly.centroid.xy[0].tolist() for poly in fips_polygon_map[f].geoms]
+ y_c = [poly.centroid.xy[1].tolist() for poly in fips_polygon_map[f].geoms]
+
+ county_name_str = str(df[df["FIPS"] == f]["COUNTY_NAME"].iloc[0])
+ state_name_str = str(df[df["FIPS"] == f]["STATE_NAME"].iloc[0])
+ text = (
+ "County: "
+ + county_name_str
+ + "<br>"
+ + "State: "
+ + state_name_str
+ + "<br>"
+ + "FIPS: "
+ + padded_f
+ + "<br>Value: "
+ + str(values[index])
+ )
+ t_c = [text for poly in fips_polygon_map[f].geoms]
+ x_centroids = x_c + x_centroids
+ y_centroids = y_c + y_centroids
+ centroid_text = t_c + centroid_text
+ for x_y_idx in range(len(x)):
+ x_traces[level] = x_traces[level] + x[x_y_idx] + [np.nan]
+ y_traces[level] = y_traces[level] + y[x_y_idx] + [np.nan]
+
+ return x_traces, y_traces, x_centroids, y_centroids, centroid_text
+
+
+def create_choropleth(
+ fips,
+ values,
+ scope=["usa"],
+ binning_endpoints=None,
+ colorscale=None,
+ order=None,
+ simplify_county=0.02,
+ simplify_state=0.02,
+ asp=None,
+ show_hover=True,
+ show_state_data=True,
+ state_outline=None,
+ county_outline=None,
+ centroid_marker=None,
+ round_legend_values=False,
+ exponent_format=False,
+ legend_title="",
+ **layout_options,
+):
+ """
+ **deprecated**, use instead
+ :func:`plotly.express.choropleth` with custom GeoJSON.
+
+ This function also requires `shapely`, `geopandas` and `plotly-geo` to be installed.
+
+ Returns figure for county choropleth. Uses data from package_data.
+
+ :param (list) fips: list of FIPS values which correspond to the con
+ catination of state and county ids. An example is '01001'.
+ :param (list) values: list of numbers/strings which correspond to the
+ fips list. These are the values that will determine how the counties
+ are colored.
+ :param (list) scope: list of states and/or states abbreviations. Fits
+ all states in the camera tightly. Selecting ['usa'] is the equivalent
+ of appending all 50 states into your scope list. Selecting only 'usa'
+ does not include 'Alaska', 'Puerto Rico', 'American Samoa',
+ 'Commonwealth of the Northern Mariana Islands', 'Guam',
+ 'United States Virgin Islands'. These must be added manually to the
+ list.
+ Default = ['usa']
+ :param (list) binning_endpoints: ascending numbers which implicitly define
+ real number intervals which are used as bins. The colorscale used must
+ have the same number of colors as the number of bins and this will
+ result in a categorical colormap.
+ :param (list) colorscale: a list of colors with length equal to the
+ number of categories of colors. The length must match either all
+ unique numbers in the 'values' list or if endpoints is being used, the
+ number of categories created by the endpoints.\n
+ For example, if binning_endpoints = [4, 6, 8], then there are 4 bins:
+ [-inf, 4), [4, 6), [6, 8), [8, inf)
+ :param (list) order: a list of the unique categories (numbers/bins) in any
+ desired order. This is helpful if you want to order string values to
+ a chosen colorscale.
+ :param (float) simplify_county: determines the simplification factor
+ for the counties. The larger the number, the fewer vertices and edges
+ each polygon has. See
+ http://toblerity.org/shapely/manual.html#object.simplify for more
+ information.
+ Default = 0.02
+ :param (float) simplify_state: simplifies the state outline polygon.
+ See http://toblerity.org/shapely/manual.html#object.simplify for more
+ information.
+ Default = 0.02
+ :param (float) asp: the width-to-height aspect ratio for the camera.
+ Default = 2.5
+ :param (bool) show_hover: show county hover and centroid info
+ :param (bool) show_state_data: reveals state boundary lines
+ :param (dict) state_outline: dict of attributes of the state outline
+ including width and color. See
+ https://plot.ly/python/reference/#scatter-marker-line for all valid
+ params
+ :param (dict) county_outline: dict of attributes of the county outline
+ including width and color. See
+ https://plot.ly/python/reference/#scatter-marker-line for all valid
+ params
+ :param (dict) centroid_marker: dict of attributes of the centroid marker.
+ The centroid markers are invisible by default and appear visible on
+ selection. See https://plot.ly/python/reference/#scatter-marker for
+ all valid params
+ :param (bool) round_legend_values: automatically round the numbers that
+ appear in the legend to the nearest integer.
+ Default = False
+ :param (bool) exponent_format: if set to True, puts numbers in the K, M,
+ B number format. For example 4000.0 becomes 4.0K
+ Default = False
+ :param (str) legend_title: title that appears above the legend
+ :param **layout_options: a **kwargs argument for all layout parameters
+
+
+ Example 1: Florida::
+
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import numpy as np
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'] == 'Florida']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+
+ binning_endpoints = list(np.mgrid[min(values):max(values):4j])
+ colorscale = ["#030512","#1d1d3b","#323268","#3d4b94","#3e6ab0",
+ "#4989bc","#60a7c7","#85c5d3","#b7e0e4","#eafcfd"]
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=['Florida'], show_state_data=True,
+ colorscale=colorscale, binning_endpoints=binning_endpoints,
+ round_legend_values=True, plot_bgcolor='rgb(229,229,229)',
+ paper_bgcolor='rgb(229,229,229)', legend_title='Florida Population',
+ county_outline={'color': 'rgb(255,255,255)', 'width': 0.5},
+ exponent_format=True,
+ )
+
+ Example 2: New England::
+
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ NE_states = ['Connecticut', 'Maine', 'Massachusetts',
+ 'New Hampshire', 'Rhode Island']
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'].isin(NE_states)]
+ colorscale = ['rgb(68.0, 1.0, 84.0)',
+ 'rgb(66.0, 64.0, 134.0)',
+ 'rgb(38.0, 130.0, 142.0)',
+ 'rgb(63.0, 188.0, 115.0)',
+ 'rgb(216.0, 226.0, 25.0)']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=NE_states, show_state_data=True
+ )
+ fig.show()
+
+ Example 3: California and Surrounding States::
+
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'] == 'California']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+
+ colorscale = [
+ 'rgb(193, 193, 193)',
+ 'rgb(239,239,239)',
+ 'rgb(195, 196, 222)',
+ 'rgb(144,148,194)',
+ 'rgb(101,104,168)',
+ 'rgb(65, 53, 132)'
+ ]
+
+ fig = ff.create_choropleth(
+ fips=fips, values=values, colorscale=colorscale,
+ scope=['CA', 'AZ', 'Nevada', 'Oregon', ' Idaho'],
+ binning_endpoints=[14348, 63983, 134827, 426762, 2081313],
+ county_outline={'color': 'rgb(255,255,255)', 'width': 0.5},
+ legend_title='California Counties',
+ title='California and Nearby States'
+ )
+ fig.show()
+
+ Example 4: USA::
+
+ import plotly.figure_factory as ff
+
+ import numpy as np
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/laucnty16.csv'
+ )
+ df_sample['State FIPS Code'] = df_sample['State FIPS Code'].apply(
+ lambda x: str(x).zfill(2)
+ )
+ df_sample['County FIPS Code'] = df_sample['County FIPS Code'].apply(
+ lambda x: str(x).zfill(3)
+ )
+ df_sample['FIPS'] = (
+ df_sample['State FIPS Code'] + df_sample['County FIPS Code']
+ )
+
+ binning_endpoints = list(np.linspace(1, 12, len(colorscale) - 1))
+ colorscale = ["#f7fbff", "#ebf3fb", "#deebf7", "#d2e3f3", "#c6dbef",
+ "#b3d2e9", "#9ecae1", "#85bcdb", "#6baed6", "#57a0ce",
+ "#4292c6", "#3082be", "#2171b5", "#1361a9", "#08519c",
+ "#0b4083","#08306b"]
+ fips = df_sample['FIPS']
+ values = df_sample['Unemployment Rate (%)']
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=['usa'],
+ binning_endpoints=binning_endpoints, colorscale=colorscale,
+ show_hover=True, centroid_marker={'opacity': 0},
+ asp=2.9, title='USA by Unemployment %',
+ legend_title='Unemployment %'
+ )
+ fig.show()
+ """
+ # ensure optional modules imported
+ if not _plotly_geo:
+ raise ValueError(
+ """
+The create_choropleth figure factory requires the plotly-geo package.
+Install using pip with:
+
+$ pip install plotly-geo
+
+Or, install using conda with
+
+$ conda install -c plotly plotly-geo
+"""
+ )
+
+ if not gp or not shapefile or not shapely:
+ raise ImportError(
+ "geopandas, pyshp and shapely must be installed for this figure "
+ "factory.\n\nRun the following commands to install the correct "
+ "versions of the following modules:\n\n"
+ "```\n"
+ "$ pip install geopandas==0.3.0\n"
+ "$ pip install pyshp==1.2.10\n"
+ "$ pip install shapely==1.6.3\n"
+ "```\n"
+ "If you are using Windows, follow this post to properly "
+ "install geopandas and dependencies:"
+ "http://geoffboeing.com/2014/09/using-geopandas-windows/\n\n"
+ "If you are using Anaconda, do not use PIP to install the "
+ "packages above. Instead use conda to install them:\n\n"
+ "```\n"
+ "$ conda install plotly\n"
+ "$ conda install geopandas\n"
+ "```"
+ )
+
+ df, df_state = _create_us_counties_df(st_to_state_name_dict, state_to_st_dict)
+
+ fips_polygon_map = dict(zip(df["FIPS"].tolist(), df["geometry"].tolist()))
+
+ if not state_outline:
+ state_outline = {"color": "rgb(240, 240, 240)", "width": 1}
+ if not county_outline:
+ county_outline = {"color": "rgb(0, 0, 0)", "width": 0}
+ if not centroid_marker:
+ centroid_marker = {"size": 3, "color": "white", "opacity": 1}
+
+ # ensure centroid markers appear on selection
+ if "opacity" not in centroid_marker:
+ centroid_marker.update({"opacity": 1})
+
+ if len(fips) != len(values):
+ raise PlotlyError("fips and values must be the same length")
+
+ # make fips, values into lists
+ if isinstance(fips, pd.core.series.Series):
+ fips = fips.tolist()
+ if isinstance(values, pd.core.series.Series):
+ values = values.tolist()
+
+ # make fips numeric
+ fips = map(lambda x: int(x), fips)
+
+ if binning_endpoints:
+ intervals = utils.endpts_to_intervals(binning_endpoints)
+ LEVELS = _intervals_as_labels(intervals, round_legend_values, exponent_format)
+ else:
+ if not order:
+ LEVELS = sorted(list(set(values)))
+ else:
+ # check if order is permutation
+ # of unique color col values
+ same_sets = sorted(list(set(values))) == set(order)
+ no_duplicates = not any(order.count(x) > 1 for x in order)
+ if same_sets and no_duplicates:
+ LEVELS = order
+ else:
+ raise PlotlyError(
+ "if you are using a custom order of unique values from "
+ "your color column, you must: have all the unique values "
+ "in your order and have no duplicate items"
+ )
+
+ if not colorscale:
+ colorscale = []
+ viridis_colors = clrs.colorscale_to_colors(clrs.PLOTLY_SCALES["Viridis"])
+ viridis_colors = clrs.color_parser(viridis_colors, clrs.hex_to_rgb)
+ viridis_colors = clrs.color_parser(viridis_colors, clrs.label_rgb)
+ viri_len = len(viridis_colors) + 1
+ viri_intervals = utils.endpts_to_intervals(list(np.linspace(0, 1, viri_len)))[
+ 1:-1
+ ]
+
+ for L in np.linspace(0, 1, len(LEVELS)):
+ for idx, inter in enumerate(viri_intervals):
+ if L == 0:
+ break
+ elif inter[0] < L <= inter[1]:
+ break
+
+ intermed = (L - viri_intervals[idx][0]) / (
+ viri_intervals[idx][1] - viri_intervals[idx][0]
+ )
+
+ float_color = clrs.find_intermediate_color(
+ viridis_colors[idx], viridis_colors[idx], intermed, colortype="rgb"
+ )
+
+ # make R,G,B into int values
+ float_color = clrs.unlabel_rgb(float_color)
+ float_color = clrs.unconvert_from_RGB_255(float_color)
+ int_rgb = clrs.convert_to_RGB_255(float_color)
+ int_rgb = clrs.label_rgb(int_rgb)
+
+ colorscale.append(int_rgb)
+
+ if len(colorscale) < len(LEVELS):
+ raise PlotlyError(
+ "You have {} LEVELS. Your number of colors in 'colorscale' must "
+ "be at least the number of LEVELS: {}. If you are "
+ "using 'binning_endpoints' then 'colorscale' must have at "
+ "least len(binning_endpoints) + 2 colors".format(
+ len(LEVELS), min(LEVELS, LEVELS[:20])
+ )
+ )
+
+ color_lookup = dict(zip(LEVELS, colorscale))
+ x_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))]))
+ y_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))]))
+
+ # scope
+ if isinstance(scope, str):
+ raise PlotlyError("'scope' must be a list/tuple/sequence")
+
+ scope_names = []
+ extra_states = [
+ "Alaska",
+ "Commonwealth of the Northern Mariana Islands",
+ "Puerto Rico",
+ "Guam",
+ "United States Virgin Islands",
+ "American Samoa",
+ ]
+ for state in scope:
+ if state.lower() == "usa":
+ scope_names = df["STATE_NAME"].unique()
+ scope_names = list(scope_names)
+ for ex_st in extra_states:
+ try:
+ scope_names.remove(ex_st)
+ except ValueError:
+ pass
+ else:
+ if state in st_to_state_name_dict.keys():
+ state = st_to_state_name_dict[state]
+ scope_names.append(state)
+ df_state = df_state[df_state["STATE_NAME"].isin(scope_names)]
+
+ plot_data = []
+ x_centroids = []
+ y_centroids = []
+ centroid_text = []
+ fips_not_in_shapefile = []
+ if not binning_endpoints:
+ for index, f in enumerate(fips):
+ level = values[index]
+ try:
+ fips_polygon_map[f].type
+
+ (
+ x_traces,
+ y_traces,
+ x_centroids,
+ y_centroids,
+ centroid_text,
+ ) = _calculations(
+ df,
+ fips,
+ values,
+ index,
+ f,
+ simplify_county,
+ level,
+ x_centroids,
+ y_centroids,
+ centroid_text,
+ x_traces,
+ y_traces,
+ fips_polygon_map,
+ )
+ except KeyError:
+ fips_not_in_shapefile.append(f)
+
+ else:
+ for index, f in enumerate(fips):
+ for j, inter in enumerate(intervals):
+ if inter[0] < values[index] <= inter[1]:
+ break
+ level = LEVELS[j]
+
+ try:
+ fips_polygon_map[f].type
+
+ (
+ x_traces,
+ y_traces,
+ x_centroids,
+ y_centroids,
+ centroid_text,
+ ) = _calculations(
+ df,
+ fips,
+ values,
+ index,
+ f,
+ simplify_county,
+ level,
+ x_centroids,
+ y_centroids,
+ centroid_text,
+ x_traces,
+ y_traces,
+ fips_polygon_map,
+ )
+ except KeyError:
+ fips_not_in_shapefile.append(f)
+
+ if len(fips_not_in_shapefile) > 0:
+ msg = (
+ "Unrecognized FIPS Values\n\nWhoops! It looks like you are "
+ "trying to pass at least one FIPS value that is not in "
+ "our shapefile of FIPS and data for the counties. Your "
+ "choropleth will still show up but these counties cannot "
+ "be shown.\nUnrecognized FIPS are: {}".format(fips_not_in_shapefile)
+ )
+ warnings.warn(msg)
+
+ x_states = []
+ y_states = []
+ for index, row in df_state.iterrows():
+ if df_state["geometry"][index].type == "Polygon":
+ x = row.geometry.simplify(simplify_state).exterior.xy[0].tolist()
+ y = row.geometry.simplify(simplify_state).exterior.xy[1].tolist()
+ x_states = x_states + x
+ y_states = y_states + y
+ elif df_state["geometry"][index].type == "MultiPolygon":
+ x = [
+ poly.simplify(simplify_state).exterior.xy[0].tolist()
+ for poly in df_state["geometry"][index].geoms
+ ]
+ y = [
+ poly.simplify(simplify_state).exterior.xy[1].tolist()
+ for poly in df_state["geometry"][index].geoms
+ ]
+ for segment in range(len(x)):
+ x_states = x_states + x[segment]
+ y_states = y_states + y[segment]
+ x_states.append(np.nan)
+ y_states.append(np.nan)
+ x_states.append(np.nan)
+ y_states.append(np.nan)
+
+ for lev in LEVELS:
+ county_data = dict(
+ type="scatter",
+ mode="lines",
+ x=x_traces[lev],
+ y=y_traces[lev],
+ line=county_outline,
+ fill="toself",
+ fillcolor=color_lookup[lev],
+ name=lev,
+ hoverinfo="none",
+ )
+ plot_data.append(county_data)
+
+ if show_hover:
+ hover_points = dict(
+ type="scatter",
+ showlegend=False,
+ legendgroup="centroids",
+ x=x_centroids,
+ y=y_centroids,
+ text=centroid_text,
+ name="US Counties",
+ mode="markers",
+ marker={"color": "white", "opacity": 0},
+ hoverinfo="text",
+ )
+ centroids_on_select = dict(
+ selected=dict(marker=centroid_marker),
+ unselected=dict(marker=dict(opacity=0)),
+ )
+ hover_points.update(centroids_on_select)
+ plot_data.append(hover_points)
+
+ if show_state_data:
+ state_data = dict(
+ type="scatter",
+ legendgroup="States",
+ line=state_outline,
+ x=x_states,
+ y=y_states,
+ hoverinfo="text",
+ showlegend=False,
+ mode="lines",
+ )
+ plot_data.append(state_data)
+
+ DEFAULT_LAYOUT = dict(
+ hovermode="closest",
+ xaxis=dict(
+ autorange=False,
+ range=USA_XRANGE,
+ showgrid=False,
+ zeroline=False,
+ fixedrange=True,
+ showticklabels=False,
+ ),
+ yaxis=dict(
+ autorange=False,
+ range=USA_YRANGE,
+ showgrid=False,
+ zeroline=False,
+ fixedrange=True,
+ showticklabels=False,
+ ),
+ margin=dict(t=40, b=20, r=20, l=20),
+ width=900,
+ height=450,
+ dragmode="select",
+ legend=dict(traceorder="reversed", xanchor="right", yanchor="top", x=1, y=1),
+ annotations=[],
+ )
+ fig = dict(data=plot_data, layout=DEFAULT_LAYOUT)
+ fig["layout"].update(layout_options)
+ fig["layout"]["annotations"].append(
+ dict(
+ x=1,
+ y=1.05,
+ xref="paper",
+ yref="paper",
+ xanchor="right",
+ showarrow=False,
+ text="<b>" + legend_title + "</b>",
+ )
+ )
+
+ if len(scope) == 1 and scope[0].lower() == "usa":
+ xaxis_range_low = -125.0
+ xaxis_range_high = -55.0
+ yaxis_range_low = 25.0
+ yaxis_range_high = 49.0
+ else:
+ xaxis_range_low = float("inf")
+ xaxis_range_high = float("-inf")
+ yaxis_range_low = float("inf")
+ yaxis_range_high = float("-inf")
+ for trace in fig["data"]:
+ if all(isinstance(n, Number) for n in trace["x"]):
+ calc_x_min = min(trace["x"] or [float("inf")])
+ calc_x_max = max(trace["x"] or [float("-inf")])
+ if calc_x_min < xaxis_range_low:
+ xaxis_range_low = calc_x_min
+ if calc_x_max > xaxis_range_high:
+ xaxis_range_high = calc_x_max
+ if all(isinstance(n, Number) for n in trace["y"]):
+ calc_y_min = min(trace["y"] or [float("inf")])
+ calc_y_max = max(trace["y"] or [float("-inf")])
+ if calc_y_min < yaxis_range_low:
+ yaxis_range_low = calc_y_min
+ if calc_y_max > yaxis_range_high:
+ yaxis_range_high = calc_y_max
+
+ # camera zoom
+ fig["layout"]["xaxis"]["range"] = [xaxis_range_low, xaxis_range_high]
+ fig["layout"]["yaxis"]["range"] = [yaxis_range_low, yaxis_range_high]
+
+ # aspect ratio
+ if asp is None:
+ usa_x_range = USA_XRANGE[1] - USA_XRANGE[0]
+ usa_y_range = USA_YRANGE[1] - USA_YRANGE[0]
+ asp = usa_x_range / usa_y_range
+
+ # based on your figure
+ width = float(
+ fig["layout"]["xaxis"]["range"][1] - fig["layout"]["xaxis"]["range"][0]
+ )
+ height = float(
+ fig["layout"]["yaxis"]["range"][1] - fig["layout"]["yaxis"]["range"][0]
+ )
+
+ center = (
+ sum(fig["layout"]["xaxis"]["range"]) / 2.0,
+ sum(fig["layout"]["yaxis"]["range"]) / 2.0,
+ )
+
+ if height / width > (1 / asp):
+ new_width = asp * height
+ fig["layout"]["xaxis"]["range"][0] = center[0] - new_width * 0.5
+ fig["layout"]["xaxis"]["range"][1] = center[0] + new_width * 0.5
+ else:
+ new_height = (1 / asp) * width
+ fig["layout"]["yaxis"]["range"][0] = center[1] - new_height * 0.5
+ fig["layout"]["yaxis"]["range"][1] = center[1] + new_height * 0.5
+
+ return go.Figure(fig)
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_dendrogram.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_dendrogram.py
new file mode 100644
index 0000000..fd6d505
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_dendrogram.py
@@ -0,0 +1,395 @@
+from collections import OrderedDict
+
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module("numpy")
+scp = optional_imports.get_module("scipy")
+sch = optional_imports.get_module("scipy.cluster.hierarchy")
+scs = optional_imports.get_module("scipy.spatial")
+
+
+def create_dendrogram(
+ X,
+ orientation="bottom",
+ labels=None,
+ colorscale=None,
+ distfun=None,
+ linkagefun=lambda x: sch.linkage(x, "complete"),
+ hovertext=None,
+ color_threshold=None,
+):
+ """
+ Function that returns a dendrogram Plotly figure object. This is a thin
+ wrapper around scipy.cluster.hierarchy.dendrogram.
+
+ See also https://dash.plot.ly/dash-bio/clustergram.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (str) orientation: 'top', 'right', 'bottom', or 'left'
+ :param (list) labels: List of axis category labels(observation labels)
+ :param (list) colorscale: Optional colorscale for the dendrogram tree.
+ Requires 8 colors to be specified, the 7th of
+ which is ignored. With scipy>=1.5.0, the 2nd, 3rd
+ and 6th are used twice as often as the others.
+ Given a shorter list, the missing values are
+ replaced with defaults and with a longer list the
+ extra values are ignored.
+ :param (function) distfun: Function to compute the pairwise distance from
+ the observations
+ :param (function) linkagefun: Function to compute the linkage matrix from
+ the pairwise distances
+ :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
+ clusters
+ :param (double) color_threshold: Value at which the separation of clusters will be made
+
+ Example 1: Simple bottom oriented dendrogram
+
+ >>> from plotly.figure_factory import create_dendrogram
+
+ >>> import numpy as np
+
+ >>> X = np.random.rand(10,10)
+ >>> fig = create_dendrogram(X)
+ >>> fig.show()
+
+ Example 2: Dendrogram to put on the left of the heatmap
+
+ >>> from plotly.figure_factory import create_dendrogram
+
+ >>> import numpy as np
+
+ >>> X = np.random.rand(5,5)
+ >>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
+ >>> dendro = create_dendrogram(X, orientation='right', labels=names)
+ >>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
+ >>> dendro.show()
+
+ Example 3: Dendrogram with Pandas
+
+ >>> from plotly.figure_factory import create_dendrogram
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> Index= ['A','B','C','D','E','F','G','H','I','J']
+ >>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
+ >>> fig = create_dendrogram(df, labels=Index)
+ >>> fig.show()
+ """
+ if not scp or not scs or not sch:
+ raise ImportError(
+ "FigureFactory.create_dendrogram requires scipy, \
+ scipy.spatial and scipy.hierarchy"
+ )
+
+ s = X.shape
+ if len(s) != 2:
+ exceptions.PlotlyError("X should be 2-dimensional array.")
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ dendrogram = _Dendrogram(
+ X,
+ orientation,
+ labels,
+ colorscale,
+ distfun=distfun,
+ linkagefun=linkagefun,
+ hovertext=hovertext,
+ color_threshold=color_threshold,
+ )
+
+ return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
+
+
+class _Dendrogram(object):
+ """Refer to FigureFactory.create_dendrogram() for docstring."""
+
+ def __init__(
+ self,
+ X,
+ orientation="bottom",
+ labels=None,
+ colorscale=None,
+ width=np.inf,
+ height=np.inf,
+ xaxis="xaxis",
+ yaxis="yaxis",
+ distfun=None,
+ linkagefun=lambda x: sch.linkage(x, "complete"),
+ hovertext=None,
+ color_threshold=None,
+ ):
+ self.orientation = orientation
+ self.labels = labels
+ self.xaxis = xaxis
+ self.yaxis = yaxis
+ self.data = []
+ self.leaves = []
+ self.sign = {self.xaxis: 1, self.yaxis: 1}
+ self.layout = {self.xaxis: {}, self.yaxis: {}}
+
+ if self.orientation in ["left", "bottom"]:
+ self.sign[self.xaxis] = 1
+ else:
+ self.sign[self.xaxis] = -1
+
+ if self.orientation in ["right", "bottom"]:
+ self.sign[self.yaxis] = 1
+ else:
+ self.sign[self.yaxis] = -1
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ (dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
+ X, colorscale, distfun, linkagefun, hovertext, color_threshold
+ )
+
+ self.labels = ordered_labels
+ self.leaves = leaves
+ yvals_flat = yvals.flatten()
+ xvals_flat = xvals.flatten()
+
+ self.zero_vals = []
+
+ for i in range(len(yvals_flat)):
+ if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
+ self.zero_vals.append(xvals_flat[i])
+
+ if len(self.zero_vals) > len(yvals) + 1:
+ # If the length of zero_vals is larger than the length of yvals,
+ # it means that there are wrong vals because of the identicial samples.
+ # Three and more identicial samples will make the yvals of spliting
+ # center into 0 and it will accidentally take it as leaves.
+ l_border = int(min(self.zero_vals))
+ r_border = int(max(self.zero_vals))
+ correct_leaves_pos = range(
+ l_border, r_border + 1, int((r_border - l_border) / len(yvals))
+ )
+ # Regenerating the leaves pos from the self.zero_vals with equally intervals.
+ self.zero_vals = [v for v in correct_leaves_pos]
+
+ self.zero_vals.sort()
+ self.layout = self.set_figure_layout(width, height)
+ self.data = dd_traces
+
+ def get_color_dict(self, colorscale):
+ """
+ Returns colorscale used for dendrogram tree clusters.
+
+ :param (list) colorscale: Colors to use for the plot in rgb format.
+ :rtype (dict): A dict of default colors mapped to the user colorscale.
+
+ """
+
+ # These are the color codes returned for dendrograms
+ # We're replacing them with nicer colors
+ # This list is the colors that can be used by dendrogram, which were
+ # determined as the combination of the default above_threshold_color and
+ # the default color palette (see scipy/cluster/hierarchy.py)
+ d = {
+ "r": "red",
+ "g": "green",
+ "b": "blue",
+ "c": "cyan",
+ "m": "magenta",
+ "y": "yellow",
+ "k": "black",
+ # TODO: 'w' doesn't seem to be in the default color
+ # palette in scipy/cluster/hierarchy.py
+ "w": "white",
+ }
+ default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
+
+ if colorscale is None:
+ rgb_colorscale = [
+ "rgb(0,116,217)", # blue
+ "rgb(35,205,205)", # cyan
+ "rgb(61,153,112)", # green
+ "rgb(40,35,35)", # black
+ "rgb(133,20,75)", # magenta
+ "rgb(255,65,54)", # red
+ "rgb(255,255,255)", # white
+ "rgb(255,220,0)", # yellow
+ ]
+ else:
+ rgb_colorscale = colorscale
+
+ for i in range(len(default_colors.keys())):
+ k = list(default_colors.keys())[i] # PY3 won't index keys
+ if i < len(rgb_colorscale):
+ default_colors[k] = rgb_colorscale[i]
+
+ # add support for cyclic format colors as introduced in scipy===1.5.0
+ # before this, the colors were named 'r', 'b', 'y' etc., now they are
+ # named 'C0', 'C1', etc. To keep the colors consistent regardless of the
+ # scipy version, we try as much as possible to map the new colors to the
+ # old colors
+ # this mapping was found by inpecting scipy/cluster/hierarchy.py (see
+ # comment above).
+ new_old_color_map = [
+ ("C0", "b"),
+ ("C1", "g"),
+ ("C2", "r"),
+ ("C3", "c"),
+ ("C4", "m"),
+ ("C5", "y"),
+ ("C6", "k"),
+ ("C7", "g"),
+ ("C8", "r"),
+ ("C9", "c"),
+ ]
+ for nc, oc in new_old_color_map:
+ try:
+ default_colors[nc] = default_colors[oc]
+ except KeyError:
+ # it could happen that the old color isn't found (if a custom
+ # colorscale was specified), in this case we set it to an
+ # arbitrary default.
+ default_colors[nc] = "rgb(0,116,217)"
+
+ return default_colors
+
+ def set_axis_layout(self, axis_key):
+ """
+ Sets and returns default axis object for dendrogram figure.
+
+ :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
+ :rtype (dict): An axis_key dictionary with set parameters.
+
+ """
+ axis_defaults = {
+ "type": "linear",
+ "ticks": "outside",
+ "mirror": "allticks",
+ "rangemode": "tozero",
+ "showticklabels": True,
+ "zeroline": False,
+ "showgrid": False,
+ "showline": True,
+ }
+
+ if len(self.labels) != 0:
+ axis_key_labels = self.xaxis
+ if self.orientation in ["left", "right"]:
+ axis_key_labels = self.yaxis
+ if axis_key_labels not in self.layout:
+ self.layout[axis_key_labels] = {}
+ self.layout[axis_key_labels]["tickvals"] = [
+ zv * self.sign[axis_key] for zv in self.zero_vals
+ ]
+ self.layout[axis_key_labels]["ticktext"] = self.labels
+ self.layout[axis_key_labels]["tickmode"] = "array"
+
+ self.layout[axis_key].update(axis_defaults)
+
+ return self.layout[axis_key]
+
+ def set_figure_layout(self, width, height):
+ """
+ Sets and returns default layout object for dendrogram figure.
+
+ """
+ self.layout.update(
+ {
+ "showlegend": False,
+ "autosize": False,
+ "hovermode": "closest",
+ "width": width,
+ "height": height,
+ }
+ )
+
+ self.set_axis_layout(self.xaxis)
+ self.set_axis_layout(self.yaxis)
+
+ return self.layout
+
+ def get_dendrogram_traces(
+ self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
+ ):
+ """
+ Calculates all the elements needed for plotting a dendrogram.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (list) colorscale: Color scale for dendrogram tree clusters
+ :param (function) distfun: Function to compute the pairwise distance
+ from the observations
+ :param (function) linkagefun: Function to compute the linkage matrix
+ from the pairwise distances
+ :param (list) hovertext: List of hovertext for constituent traces of dendrogram
+ :rtype (tuple): Contains all the traces in the following order:
+ (a) trace_list: List of Plotly trace objects for dendrogram tree
+ (b) icoord: All X points of the dendrogram tree as array of arrays
+ with length 4
+ (c) dcoord: All Y points of the dendrogram tree as array of arrays
+ with length 4
+ (d) ordered_labels: leaf labels in the order they are going to
+ appear on the plot
+ (e) P['leaves']: left-to-right traversal of the leaves
+
+ """
+ d = distfun(X)
+ Z = linkagefun(d)
+ P = sch.dendrogram(
+ Z,
+ orientation=self.orientation,
+ labels=self.labels,
+ no_plot=True,
+ color_threshold=color_threshold,
+ )
+
+ icoord = np.array(P["icoord"])
+ dcoord = np.array(P["dcoord"])
+ ordered_labels = np.array(P["ivl"])
+ color_list = np.array(P["color_list"])
+ colors = self.get_color_dict(colorscale)
+
+ trace_list = []
+
+ for i in range(len(icoord)):
+ # xs and ys are arrays of 4 points that make up the '∩' shapes
+ # of the dendrogram tree
+ if self.orientation in ["top", "bottom"]:
+ xs = icoord[i]
+ else:
+ xs = dcoord[i]
+
+ if self.orientation in ["top", "bottom"]:
+ ys = dcoord[i]
+ else:
+ ys = icoord[i]
+ color_key = color_list[i]
+ hovertext_label = None
+ if hovertext:
+ hovertext_label = hovertext[i]
+ trace = dict(
+ type="scatter",
+ x=np.multiply(self.sign[self.xaxis], xs),
+ y=np.multiply(self.sign[self.yaxis], ys),
+ mode="lines",
+ marker=dict(color=colors[color_key]),
+ text=hovertext_label,
+ hoverinfo="text",
+ )
+
+ try:
+ x_index = int(self.xaxis[-1])
+ except ValueError:
+ x_index = ""
+
+ try:
+ y_index = int(self.yaxis[-1])
+ except ValueError:
+ y_index = ""
+
+ trace["xaxis"] = f"x{x_index}"
+ trace["yaxis"] = f"y{y_index}"
+
+ trace_list.append(trace)
+
+ return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_distplot.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_distplot.py
new file mode 100644
index 0000000..73f6609
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_distplot.py
@@ -0,0 +1,441 @@
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module("numpy")
+pd = optional_imports.get_module("pandas")
+scipy = optional_imports.get_module("scipy")
+scipy_stats = optional_imports.get_module("scipy.stats")
+
+
+DEFAULT_HISTNORM = "probability density"
+ALTERNATIVE_HISTNORM = "probability"
+
+
+def validate_distplot(hist_data, curve_type):
+ """
+ Distplot-specific validations
+
+ :raises: (PlotlyError) If hist_data is not a list of lists
+ :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
+ 'normal').
+ """
+ hist_data_types = (list,)
+ if np:
+ hist_data_types += (np.ndarray,)
+ if pd:
+ hist_data_types += (pd.core.series.Series,)
+
+ if not isinstance(hist_data[0], hist_data_types):
+ raise exceptions.PlotlyError(
+ "Oops, this function was written "
+ "to handle multiple datasets, if "
+ "you want to plot just one, make "
+ "sure your hist_data variable is "
+ "still a list of lists, i.e. x = "
+ "[1, 2, 3] -> x = [[1, 2, 3]]"
+ )
+
+ curve_opts = ("kde", "normal")
+ if curve_type not in curve_opts:
+ raise exceptions.PlotlyError("curve_type must be defined as 'kde' or 'normal'")
+
+ if not scipy:
+ raise ImportError("FigureFactory.create_distplot requires scipy")
+
+
+def create_distplot(
+ hist_data,
+ group_labels,
+ bin_size=1.0,
+ curve_type="kde",
+ colors=None,
+ rug_text=None,
+ histnorm=DEFAULT_HISTNORM,
+ show_hist=True,
+ show_curve=True,
+ show_rug=True,
+):
+ """
+ Function that creates a distplot similar to seaborn.distplot;
+ **this function is deprecated**, use instead :mod:`plotly.express`
+ functions, for example
+
+ >>> import plotly.express as px
+ >>> tips = px.data.tips()
+ >>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
+ ... hover_data=tips.columns)
+ >>> fig.show()
+
+
+ The distplot can be composed of all or any combination of the following
+ 3 components: (1) histogram, (2) curve: (a) kernel density estimation
+ or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
+ (from multiple datasets) can be created in the same plot.
+
+ :param (list[list]) hist_data: Use list of lists to plot multiple data
+ sets on the same plot.
+ :param (list[str]) group_labels: Names for each data set.
+ :param (list[float]|float) bin_size: Size of histogram bins.
+ Default = 1.
+ :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
+ :param (str) histnorm: 'probability density' or 'probability'
+ Default = 'probability density'
+ :param (bool) show_hist: Add histogram to distplot? Default = True
+ :param (bool) show_curve: Add curve to distplot? Default = True
+ :param (bool) show_rug: Add rug to distplot? Default = True
+ :param (list[str]) colors: Colors for traces.
+ :param (list[list]) rug_text: Hovertext values for rug_plot,
+ :return (dict): Representation of a distplot figure.
+
+ Example 1: Simple distplot of 1 data set
+
+ >>> from plotly.figure_factory import create_distplot
+
+ >>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
+ ... 3.5, 4.1, 4.4, 4.5, 4.5,
+ ... 5.0, 5.0, 5.2, 5.5, 5.5,
+ ... 5.5, 5.5, 5.5, 6.1, 7.0]]
+ >>> group_labels = ['distplot example']
+ >>> fig = create_distplot(hist_data, group_labels)
+ >>> fig.show()
+
+
+ Example 2: Two data sets and added rug text
+
+ >>> from plotly.figure_factory import create_distplot
+ >>> # Add histogram data
+ >>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
+ ... -0.9, -0.07, 1.95, 0.9, -0.2,
+ ... -0.5, 0.3, 0.4, -0.37, 0.6]
+ >>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
+ ... 1.0, 0.8, 1.7, 0.5, 0.8,
+ ... -0.3, 1.2, 0.56, 0.3, 2.2]
+
+ >>> # Group data together
+ >>> hist_data = [hist1_x, hist2_x]
+
+ >>> group_labels = ['2012', '2013']
+
+ >>> # Add text
+ >>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
+ ... 'f1', 'g1', 'h1', 'i1', 'j1',
+ ... 'k1', 'l1', 'm1', 'n1', 'o1']
+
+ >>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
+ ... 'f2', 'g2', 'h2', 'i2', 'j2',
+ ... 'k2', 'l2', 'm2', 'n2', 'o2']
+
+ >>> # Group text together
+ >>> rug_text_all = [rug_text_1, rug_text_2]
+
+ >>> # Create distplot
+ >>> fig = create_distplot(
+ ... hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
+
+ >>> # Add title
+ >>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
+ >>> fig.show()
+
+
+ Example 3: Plot with normal curve and hide rug plot
+
+ >>> from plotly.figure_factory import create_distplot
+ >>> import numpy as np
+
+ >>> x1 = np.random.randn(190)
+ >>> x2 = np.random.randn(200)+1
+ >>> x3 = np.random.randn(200)-1
+ >>> x4 = np.random.randn(210)+2
+
+ >>> hist_data = [x1, x2, x3, x4]
+ >>> group_labels = ['2012', '2013', '2014', '2015']
+
+ >>> fig = create_distplot(
+ ... hist_data, group_labels, curve_type='normal',
+ ... show_rug=False, bin_size=.4)
+
+
+ Example 4: Distplot with Pandas
+
+ >>> from plotly.figure_factory import create_distplot
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> df = pd.DataFrame({'2012': np.random.randn(200),
+ ... '2013': np.random.randn(200)+1})
+ >>> fig = create_distplot([df[c] for c in df.columns], df.columns)
+ >>> fig.show()
+ """
+ if colors is None:
+ colors = []
+ if rug_text is None:
+ rug_text = []
+
+ validate_distplot(hist_data, curve_type)
+ utils.validate_equal_length(hist_data, group_labels)
+
+ if isinstance(bin_size, (float, int)):
+ bin_size = [bin_size] * len(hist_data)
+
+ data = []
+ if show_hist:
+ hist = _Distplot(
+ hist_data,
+ histnorm,
+ group_labels,
+ bin_size,
+ curve_type,
+ colors,
+ rug_text,
+ show_hist,
+ show_curve,
+ ).make_hist()
+
+ data.append(hist)
+
+ if show_curve:
+ if curve_type == "normal":
+ curve = _Distplot(
+ hist_data,
+ histnorm,
+ group_labels,
+ bin_size,
+ curve_type,
+ colors,
+ rug_text,
+ show_hist,
+ show_curve,
+ ).make_normal()
+ else:
+ curve = _Distplot(
+ hist_data,
+ histnorm,
+ group_labels,
+ bin_size,
+ curve_type,
+ colors,
+ rug_text,
+ show_hist,
+ show_curve,
+ ).make_kde()
+
+ data.append(curve)
+
+ if show_rug:
+ rug = _Distplot(
+ hist_data,
+ histnorm,
+ group_labels,
+ bin_size,
+ curve_type,
+ colors,
+ rug_text,
+ show_hist,
+ show_curve,
+ ).make_rug()
+
+ data.append(rug)
+ layout = graph_objs.Layout(
+ barmode="overlay",
+ hovermode="closest",
+ legend=dict(traceorder="reversed"),
+ xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
+ yaxis1=dict(domain=[0.35, 1], anchor="free", position=0.0),
+ yaxis2=dict(domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False),
+ )
+ else:
+ layout = graph_objs.Layout(
+ barmode="overlay",
+ hovermode="closest",
+ legend=dict(traceorder="reversed"),
+ xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
+ yaxis1=dict(domain=[0.0, 1], anchor="free", position=0.0),
+ )
+
+ data = sum(data, [])
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Distplot(object):
+ """
+ Refer to TraceFactory.create_distplot() for docstring
+ """
+
+ def __init__(
+ self,
+ hist_data,
+ histnorm,
+ group_labels,
+ bin_size,
+ curve_type,
+ colors,
+ rug_text,
+ show_hist,
+ show_curve,
+ ):
+ self.hist_data = hist_data
+ self.histnorm = histnorm
+ self.group_labels = group_labels
+ self.bin_size = bin_size
+ self.show_hist = show_hist
+ self.show_curve = show_curve
+ self.trace_number = len(hist_data)
+ if rug_text:
+ self.rug_text = rug_text
+ else:
+ self.rug_text = [None] * self.trace_number
+
+ self.start = []
+ self.end = []
+ if colors:
+ self.colors = colors
+ else:
+ self.colors = [
+ "rgb(31, 119, 180)",
+ "rgb(255, 127, 14)",
+ "rgb(44, 160, 44)",
+ "rgb(214, 39, 40)",
+ "rgb(148, 103, 189)",
+ "rgb(140, 86, 75)",
+ "rgb(227, 119, 194)",
+ "rgb(127, 127, 127)",
+ "rgb(188, 189, 34)",
+ "rgb(23, 190, 207)",
+ ]
+ self.curve_x = [None] * self.trace_number
+ self.curve_y = [None] * self.trace_number
+
+ for trace in self.hist_data:
+ self.start.append(min(trace) * 1.0)
+ self.end.append(max(trace) * 1.0)
+
+ def make_hist(self):
+ """
+ Makes the histogram(s) for FigureFactory.create_distplot().
+
+ :rtype (list) hist: list of histogram representations
+ """
+ hist = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ hist[index] = dict(
+ type="histogram",
+ x=self.hist_data[index],
+ xaxis="x1",
+ yaxis="y1",
+ histnorm=self.histnorm,
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ marker=dict(color=self.colors[index % len(self.colors)]),
+ autobinx=False,
+ xbins=dict(
+ start=self.start[index],
+ end=self.end[index],
+ size=self.bin_size[index],
+ ),
+ opacity=0.7,
+ )
+ return hist
+
+ def make_kde(self):
+ """
+ Makes the kernel density estimation(s) for create_distplot().
+
+ This is called when curve_type = 'kde' in create_distplot().
+
+ :rtype (list) curve: list of kde representations
+ """
+ curve = [None] * self.trace_number
+ for index in range(self.trace_number):
+ self.curve_x[index] = [
+ self.start[index] + x * (self.end[index] - self.start[index]) / 500
+ for x in range(500)
+ ]
+ self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])(
+ self.curve_x[index]
+ )
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(
+ type="scatter",
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis="x1",
+ yaxis="y1",
+ mode="lines",
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index % len(self.colors)]),
+ )
+ return curve
+
+ def make_normal(self):
+ """
+ Makes the normal curve(s) for create_distplot().
+
+ This is called when curve_type = 'normal' in create_distplot().
+
+ :rtype (list) curve: list of normal curve representations
+ """
+ curve = [None] * self.trace_number
+ mean = [None] * self.trace_number
+ sd = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index])
+ self.curve_x[index] = [
+ self.start[index] + x * (self.end[index] - self.start[index]) / 500
+ for x in range(500)
+ ]
+ self.curve_y[index] = scipy_stats.norm.pdf(
+ self.curve_x[index], loc=mean[index], scale=sd[index]
+ )
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(
+ type="scatter",
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis="x1",
+ yaxis="y1",
+ mode="lines",
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index % len(self.colors)]),
+ )
+ return curve
+
+ def make_rug(self):
+ """
+ Makes the rug plot(s) for create_distplot().
+
+ :rtype (list) rug: list of rug plot representations
+ """
+ rug = [None] * self.trace_number
+ for index in range(self.trace_number):
+ rug[index] = dict(
+ type="scatter",
+ x=self.hist_data[index],
+ y=([self.group_labels[index]] * len(self.hist_data[index])),
+ xaxis="x1",
+ yaxis="y2",
+ mode="markers",
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=(False if self.show_hist or self.show_curve else True),
+ text=self.rug_text[index],
+ marker=dict(
+ color=self.colors[index % len(self.colors)], symbol="line-ns-open"
+ ),
+ )
+ return rug
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py
new file mode 100644
index 0000000..06dc71d
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py
@@ -0,0 +1,1195 @@
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.figure_factory import utils
+from plotly.subplots import make_subplots
+
+import math
+from numbers import Number
+
+pd = optional_imports.get_module("pandas")
+
+TICK_COLOR = "#969696"
+AXIS_TITLE_COLOR = "#0f0f0f"
+AXIS_TITLE_SIZE = 12
+GRID_COLOR = "#ffffff"
+LEGEND_COLOR = "#efefef"
+PLOT_BGCOLOR = "#ededed"
+ANNOT_RECT_COLOR = "#d0d0d0"
+LEGEND_BORDER_WIDTH = 1
+LEGEND_ANNOT_X = 1.05
+LEGEND_ANNOT_Y = 0.5
+MAX_TICKS_PER_AXIS = 5
+THRES_FOR_FLIPPED_FACET_TITLES = 10
+GRID_WIDTH = 1
+
+VALID_TRACE_TYPES = ["scatter", "scattergl", "histogram", "bar", "box"]
+
+CUSTOM_LABEL_ERROR = (
+ "If you are using a dictionary for custom labels for the facet row/col, "
+ "make sure each key in that column of the dataframe is in your facet "
+ "labels. The keys you need are {}"
+)
+
+
+def _is_flipped(num):
+ if num >= THRES_FOR_FLIPPED_FACET_TITLES:
+ flipped = True
+ else:
+ flipped = False
+ return flipped
+
+
+def _return_label(original_label, facet_labels, facet_var):
+ if isinstance(facet_labels, dict):
+ label = facet_labels[original_label]
+ elif isinstance(facet_labels, str):
+ label = "{}: {}".format(facet_var, original_label)
+ else:
+ label = original_label
+ return label
+
+
+def _legend_annotation(color_name):
+ legend_title = dict(
+ textangle=0,
+ xanchor="left",
+ yanchor="middle",
+ x=LEGEND_ANNOT_X,
+ y=1.03,
+ showarrow=False,
+ xref="paper",
+ yref="paper",
+ text="factor({})".format(color_name),
+ font=dict(size=13, color="#000000"),
+ )
+ return legend_title
+
+
+def _annotation_dict(
+ text, lane, num_of_lanes, SUBPLOT_SPACING, row_col="col", flipped=True
+):
+ temp = (1 - (num_of_lanes - 1) * SUBPLOT_SPACING) / (num_of_lanes)
+ if not flipped:
+ xanchor = "center"
+ yanchor = "middle"
+ if row_col == "col":
+ x = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp
+ y = 1.03
+ textangle = 0
+ elif row_col == "row":
+ y = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp
+ x = 1.03
+ textangle = 90
+ else:
+ if row_col == "col":
+ xanchor = "center"
+ yanchor = "bottom"
+ x = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp
+ y = 1.0
+ textangle = 270
+ elif row_col == "row":
+ xanchor = "left"
+ yanchor = "middle"
+ y = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp
+ x = 1.0
+ textangle = 0
+
+ annotation_dict = dict(
+ textangle=textangle,
+ xanchor=xanchor,
+ yanchor=yanchor,
+ x=x,
+ y=y,
+ showarrow=False,
+ xref="paper",
+ yref="paper",
+ text=str(text),
+ font=dict(size=13, color=AXIS_TITLE_COLOR),
+ )
+ return annotation_dict
+
+
+def _axis_title_annotation(text, x_or_y_axis):
+ if x_or_y_axis == "x":
+ x_pos = 0.5
+ y_pos = -0.1
+ textangle = 0
+ elif x_or_y_axis == "y":
+ x_pos = -0.1
+ y_pos = 0.5
+ textangle = 270
+
+ if not text:
+ text = ""
+
+ annot = {
+ "font": {"color": "#000000", "size": AXIS_TITLE_SIZE},
+ "showarrow": False,
+ "text": text,
+ "textangle": textangle,
+ "x": x_pos,
+ "xanchor": "center",
+ "xref": "paper",
+ "y": y_pos,
+ "yanchor": "middle",
+ "yref": "paper",
+ }
+ return annot
+
+
+def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False, flipped_cols=False):
+ shapes_list = []
+ for key in fig["layout"].to_plotly_json().keys():
+ if "axis" in key and fig["layout"][key]["domain"] != [0.0, 1.0]:
+ shape = {
+ "fillcolor": annot_rect_color,
+ "layer": "below",
+ "line": {"color": annot_rect_color, "width": 1},
+ "type": "rect",
+ "xref": "paper",
+ "yref": "paper",
+ }
+
+ if "xaxis" in key:
+ shape["x0"] = fig["layout"][key]["domain"][0]
+ shape["x1"] = fig["layout"][key]["domain"][1]
+ shape["y0"] = 1.005
+ shape["y1"] = 1.05
+
+ if flipped_cols:
+ shape["y1"] += 0.5
+ shapes_list.append(shape)
+
+ elif "yaxis" in key:
+ shape["x0"] = 1.005
+ shape["x1"] = 1.05
+ shape["y0"] = fig["layout"][key]["domain"][0]
+ shape["y1"] = fig["layout"][key]["domain"][1]
+
+ if flipped_rows:
+ shape["x1"] += 1
+ shapes_list.append(shape)
+
+ fig["layout"]["shapes"] = shapes_list
+
+
+def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker):
+ if trace_type in ["scatter", "scattergl"]:
+ trace["mode"] = "markers"
+ trace["marker"] = dict(color=color, **kwargs_marker)
+ return trace
+
+
+def _facet_grid_color_categorical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colormap,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+):
+ fig = make_subplots(
+ rows=num_of_rows,
+ cols=num_of_cols,
+ shared_xaxes=True,
+ shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING,
+ print_grid=False,
+ )
+
+ annotations = []
+ if not facet_row and not facet_col:
+ color_groups = list(df.groupby(color_name))
+ for group in color_groups:
+ trace = dict(
+ type=trace_type,
+ name=group[0],
+ marker=dict(color=colormap[group[0]]),
+ **kwargs_trace,
+ )
+ if x:
+ trace["x"] = group[1][x]
+ if y:
+ trace["y"] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[group[0]], **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ elif (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col))
+ for j, group in enumerate(groups_by_facet):
+ for color_val in df[color_name].unique():
+ data_by_color = group[1][group[1][color_name] == color_val]
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(color=colormap[color_val]),
+ **kwargs_trace,
+ )
+ if x:
+ trace["x"] = data_by_color[x]
+ if y:
+ trace["y"] = data_by_color[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[color_val], **kwargs_marker
+ )
+
+ fig.append_trace(
+ trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1
+ )
+
+ label = _return_label(
+ group[0],
+ facet_row_labels if facet_row else facet_col_labels,
+ facet_row if facet_row else facet_col,
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ "row" if facet_row else "col",
+ flipped_rows,
+ )
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ color_vals = df[color_name].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame(
+ [[None, None, None]], columns=[x, y, color_name]
+ )
+
+ for color_val in color_vals:
+ if group.values.tolist() != [[None, None, None]]:
+ group_filtered = group[group[color_name] == color_val]
+
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(color=colormap[color_val]),
+ **kwargs_trace,
+ )
+ new_x = group_filtered[x]
+ new_y = group_filtered[y]
+ else:
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(color=colormap[color_val]),
+ showlegend=False,
+ **kwargs_trace,
+ )
+ new_x = group[x]
+ new_y = group[y]
+
+ if x:
+ trace["x"] = new_x
+ if y:
+ trace["y"] = new_y
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[color_val], **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(
+ col_values[col_count], facet_col_labels, facet_col
+ )
+ annotations.append(
+ _annotation_dict(
+ label,
+ col_count + 1,
+ num_of_cols,
+ SUBPLOT_SPACING,
+ row_col="col",
+ flipped=flipped_cols,
+ )
+ )
+ label = _return_label(row_values[row_count], facet_row_labels, facet_row)
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - row_count,
+ num_of_rows,
+ SUBPLOT_SPACING,
+ row_col="row",
+ flipped=flipped_rows,
+ )
+ )
+
+ return fig, annotations
+
+
+def _facet_grid_color_numerical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colormap,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+):
+ fig = make_subplots(
+ rows=num_of_rows,
+ cols=num_of_cols,
+ shared_xaxes=True,
+ shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING,
+ print_grid=False,
+ )
+
+ annotations = []
+ if not facet_row and not facet_col:
+ trace = dict(
+ type=trace_type,
+ marker=dict(color=df[color_name], colorscale=colormap, showscale=True),
+ **kwargs_trace,
+ )
+ if x:
+ trace["x"] = df[x]
+ if y:
+ trace["y"] = df[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ if (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col))
+ for j, group in enumerate(groups_by_facet):
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=df[color_name],
+ colorscale=colormap,
+ showscale=True,
+ colorbar=dict(x=1.15),
+ ),
+ **kwargs_trace,
+ )
+ if x:
+ trace["x"] = group[1][x]
+ if y:
+ trace["y"] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(
+ trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1
+ )
+
+ labels = facet_row_labels if facet_row else facet_col_labels
+ label = _return_label(
+ group[0], labels, facet_row if facet_row else facet_col
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ "row" if facet_row else "col",
+ flipped=flipped_rows,
+ )
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame(
+ [[None, None, None]], columns=[x, y, color_name]
+ )
+
+ if group.values.tolist() != [[None, None, None]]:
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=df[color_name],
+ colorscale=colormap,
+ showscale=(row_count == 0),
+ colorbar=dict(x=1.15),
+ ),
+ **kwargs_trace,
+ )
+
+ else:
+ trace = dict(type=trace_type, showlegend=False, **kwargs_trace)
+
+ if x:
+ trace["x"] = group[x]
+ if y:
+ trace["y"] = group[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(
+ col_values[col_count], facet_col_labels, facet_col
+ )
+ annotations.append(
+ _annotation_dict(
+ label,
+ col_count + 1,
+ num_of_cols,
+ SUBPLOT_SPACING,
+ row_col="col",
+ flipped=flipped_cols,
+ )
+ )
+ label = _return_label(row_values[row_count], facet_row_labels, facet_row)
+ annotations.append(
+ _annotation_dict(
+ row_values[row_count],
+ num_of_rows - row_count,
+ num_of_rows,
+ SUBPLOT_SPACING,
+ row_col="row",
+ flipped=flipped_rows,
+ )
+ )
+
+ return fig, annotations
+
+
+def _facet_grid(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+):
+ fig = make_subplots(
+ rows=num_of_rows,
+ cols=num_of_cols,
+ shared_xaxes=True,
+ shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING,
+ print_grid=False,
+ )
+ annotations = []
+ if not facet_row and not facet_col:
+ trace = dict(
+ type=trace_type,
+ marker=dict(color=marker_color, line=kwargs_marker["line"]),
+ **kwargs_trace,
+ )
+
+ if x:
+ trace["x"] = df[x]
+ if y:
+ trace["y"] = df[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ elif (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col))
+ for j, group in enumerate(groups_by_facet):
+ trace = dict(
+ type=trace_type,
+ marker=dict(color=marker_color, line=kwargs_marker["line"]),
+ **kwargs_trace,
+ )
+
+ if x:
+ trace["x"] = group[1][x]
+ if y:
+ trace["y"] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(
+ trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1
+ )
+
+ label = _return_label(
+ group[0],
+ facet_row_labels if facet_row else facet_col_labels,
+ facet_row if facet_row else facet_col,
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ "row" if facet_row else "col",
+ flipped_rows,
+ )
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame([[None, None]], columns=[x, y])
+ trace = dict(
+ type=trace_type,
+ marker=dict(color=marker_color, line=kwargs_marker["line"]),
+ **kwargs_trace,
+ )
+ if x:
+ trace["x"] = group[x]
+ if y:
+ trace["y"] = group[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(
+ col_values[col_count], facet_col_labels, facet_col
+ )
+ annotations.append(
+ _annotation_dict(
+ label,
+ col_count + 1,
+ num_of_cols,
+ SUBPLOT_SPACING,
+ row_col="col",
+ flipped=flipped_cols,
+ )
+ )
+
+ label = _return_label(row_values[row_count], facet_row_labels, facet_row)
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - row_count,
+ num_of_rows,
+ SUBPLOT_SPACING,
+ row_col="row",
+ flipped=flipped_rows,
+ )
+ )
+
+ return fig, annotations
+
+
+def create_facet_grid(
+ df,
+ x=None,
+ y=None,
+ facet_row=None,
+ facet_col=None,
+ color_name=None,
+ colormap=None,
+ color_is_cat=False,
+ facet_row_labels=None,
+ facet_col_labels=None,
+ height=None,
+ width=None,
+ trace_type="scatter",
+ scales="fixed",
+ dtick_x=None,
+ dtick_y=None,
+ show_boxes=True,
+ ggplot2=False,
+ binsize=1,
+ **kwargs,
+):
+ """
+ Returns figure for facet grid; **this function is deprecated**, since
+ plotly.express functions should be used instead, for example
+
+ >>> import plotly.express as px
+ >>> tips = px.data.tips()
+ >>> fig = px.scatter(tips,
+ ... x='total_bill',
+ ... y='tip',
+ ... facet_row='sex',
+ ... facet_col='smoker',
+ ... color='size')
+
+
+ :param (pd.DataFrame) df: the dataframe of columns for the facet grid.
+ :param (str) x: the name of the dataframe column for the x axis data.
+ :param (str) y: the name of the dataframe column for the y axis data.
+ :param (str) facet_row: the name of the dataframe column that is used to
+ facet the grid into row panels.
+ :param (str) facet_col: the name of the dataframe column that is used to
+ facet the grid into column panels.
+ :param (str) color_name: the name of your dataframe column that will
+ function as the colormap variable.
+ :param (str|list|dict) colormap: the param that determines how the
+ color_name column colors the data. If the dataframe contains numeric
+ data, then a dictionary of colors will group the data categorically
+ while a Plotly Colorscale name or a custom colorscale will treat it
+ numerically. To learn more about colors and types of colormap, run
+ `help(plotly.colors)`.
+ :param (bool) color_is_cat: determines whether a numerical column for the
+ colormap will be treated as categorical (True) or sequential (False).
+ Default = False.
+ :param (str|dict) facet_row_labels: set to either 'name' or a dictionary
+ of all the unique values in the faceting row mapped to some text to
+ show up in the label annotations. If None, labeling works like usual.
+ :param (str|dict) facet_col_labels: set to either 'name' or a dictionary
+ of all the values in the faceting row mapped to some text to show up
+ in the label annotations. If None, labeling works like usual.
+ :param (int) height: the height of the facet grid figure.
+ :param (int) width: the width of the facet grid figure.
+ :param (str) trace_type: decides the type of plot to appear in the
+ facet grid. The options are 'scatter', 'scattergl', 'histogram',
+ 'bar', and 'box'.
+ Default = 'scatter'.
+ :param (str) scales: determines if axes have fixed ranges or not. Valid
+ settings are 'fixed' (all axes fixed), 'free_x' (x axis free only),
+ 'free_y' (y axis free only) or 'free' (both axes free).
+ :param (float) dtick_x: determines the distance between each tick on the
+ x-axis. Default is None which means dtick_x is set automatically.
+ :param (float) dtick_y: determines the distance between each tick on the
+ y-axis. Default is None which means dtick_y is set automatically.
+ :param (bool) show_boxes: draws grey boxes behind the facet titles.
+ :param (bool) ggplot2: draws the facet grid in the style of `ggplot2`. See
+ http://ggplot2.tidyverse.org/reference/facet_grid.html for reference.
+ Default = False
+ :param (int) binsize: groups all data into bins of a given length.
+ :param (dict) kwargs: a dictionary of scatterplot arguments.
+
+ Examples 1: One Way Faceting
+
+ >>> import plotly.figure_factory as ff
+ >>> import pandas as pd
+ >>> mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')
+
+ >>> fig = ff.create_facet_grid(
+ ... mpg,
+ ... x='displ',
+ ... y='cty',
+ ... facet_col='cyl',
+ ... )
+ >>> fig.show()
+
+ Example 2: Two Way Faceting
+
+ >>> import plotly.figure_factory as ff
+
+ >>> import pandas as pd
+
+ >>> mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')
+
+ >>> fig = ff.create_facet_grid(
+ ... mpg,
+ ... x='displ',
+ ... y='cty',
+ ... facet_row='drv',
+ ... facet_col='cyl',
+ ... )
+ >>> fig.show()
+
+ Example 3: Categorical Coloring
+
+ >>> import plotly.figure_factory as ff
+ >>> import pandas as pd
+ >>> mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')
+ >>> mtcars.cyl = mtcars.cyl.astype(str)
+ >>> fig = ff.create_facet_grid(
+ ... mtcars,
+ ... x='mpg',
+ ... y='wt',
+ ... facet_col='cyl',
+ ... color_name='cyl',
+ ... color_is_cat=True,
+ ... )
+ >>> fig.show()
+
+
+ """
+ if not pd:
+ raise ImportError("'pandas' must be installed for this figure_factory.")
+
+ if not isinstance(df, pd.DataFrame):
+ raise exceptions.PlotlyError("You must input a pandas DataFrame.")
+
+ # make sure all columns are of homogenous datatype
+ utils.validate_dataframe(df)
+
+ if trace_type in ["scatter", "scattergl"]:
+ if not x or not y:
+ raise exceptions.PlotlyError(
+ "You need to input 'x' and 'y' if you are you are using a "
+ "trace_type of 'scatter' or 'scattergl'."
+ )
+
+ for key in [x, y, facet_row, facet_col, color_name]:
+ if key is not None:
+ try:
+ df[key]
+ except KeyError:
+ raise exceptions.PlotlyError(
+ "x, y, facet_row, facet_col and color_name must be keys "
+ "in your dataframe."
+ )
+ # autoscale histogram bars
+ if trace_type not in ["scatter", "scattergl"]:
+ scales = "free"
+
+ # validate scales
+ if scales not in ["fixed", "free_x", "free_y", "free"]:
+ raise exceptions.PlotlyError(
+ "'scales' must be set to 'fixed', 'free_x', 'free_y' and 'free'."
+ )
+
+ if trace_type not in VALID_TRACE_TYPES:
+ raise exceptions.PlotlyError(
+ "'trace_type' must be in {}".format(VALID_TRACE_TYPES)
+ )
+
+ if trace_type == "histogram":
+ SUBPLOT_SPACING = 0.06
+ else:
+ SUBPLOT_SPACING = 0.015
+
+ # seperate kwargs for marker and else
+ if "marker" in kwargs:
+ kwargs_marker = kwargs["marker"]
+ else:
+ kwargs_marker = {}
+ marker_color = kwargs_marker.pop("color", None)
+ kwargs.pop("marker", None)
+ kwargs_trace = kwargs
+
+ if "size" not in kwargs_marker:
+ if ggplot2:
+ kwargs_marker["size"] = 5
+ else:
+ kwargs_marker["size"] = 8
+
+ if "opacity" not in kwargs_marker:
+ if not ggplot2:
+ kwargs_trace["opacity"] = 0.6
+
+ if "line" not in kwargs_marker:
+ if not ggplot2:
+ kwargs_marker["line"] = {"color": "darkgrey", "width": 1}
+ else:
+ kwargs_marker["line"] = {}
+
+ # default marker size
+ if not ggplot2:
+ if not marker_color:
+ marker_color = "rgb(31, 119, 180)"
+ else:
+ marker_color = "rgb(0, 0, 0)"
+
+ num_of_rows = 1
+ num_of_cols = 1
+ flipped_rows = False
+ flipped_cols = False
+ if facet_row:
+ num_of_rows = len(df[facet_row].unique())
+ flipped_rows = _is_flipped(num_of_rows)
+ if isinstance(facet_row_labels, dict):
+ for key in df[facet_row].unique():
+ if key not in facet_row_labels.keys():
+ unique_keys = df[facet_row].unique().tolist()
+ raise exceptions.PlotlyError(CUSTOM_LABEL_ERROR.format(unique_keys))
+ if facet_col:
+ num_of_cols = len(df[facet_col].unique())
+ flipped_cols = _is_flipped(num_of_cols)
+ if isinstance(facet_col_labels, dict):
+ for key in df[facet_col].unique():
+ if key not in facet_col_labels.keys():
+ unique_keys = df[facet_col].unique().tolist()
+ raise exceptions.PlotlyError(CUSTOM_LABEL_ERROR.format(unique_keys))
+ show_legend = False
+ if color_name:
+ if isinstance(df[color_name].iloc[0], str) or color_is_cat:
+ show_legend = True
+ if isinstance(colormap, dict):
+ clrs.validate_colors_dict(colormap, "rgb")
+
+ for val in df[color_name].unique():
+ if val not in colormap.keys():
+ raise exceptions.PlotlyError(
+ "If using 'colormap' as a dictionary, make sure "
+ "all the values of the colormap column are in "
+ "the keys of your dictionary."
+ )
+ else:
+ # use default plotly colors for dictionary
+ default_colors = clrs.DEFAULT_PLOTLY_COLORS
+ colormap = {}
+ j = 0
+ for val in df[color_name].unique():
+ if j >= len(default_colors):
+ j = 0
+ colormap[val] = default_colors[j]
+ j += 1
+ fig, annotations = _facet_grid_color_categorical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colormap,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+
+ elif isinstance(df[color_name].iloc[0], Number):
+ if isinstance(colormap, dict):
+ show_legend = True
+ clrs.validate_colors_dict(colormap, "rgb")
+
+ for val in df[color_name].unique():
+ if val not in colormap.keys():
+ raise exceptions.PlotlyError(
+ "If using 'colormap' as a dictionary, make sure "
+ "all the values of the colormap column are in "
+ "the keys of your dictionary."
+ )
+ fig, annotations = _facet_grid_color_categorical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colormap,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+
+ elif isinstance(colormap, list):
+ colorscale_list = colormap
+ clrs.validate_colorscale(colorscale_list)
+
+ fig, annotations = _facet_grid_color_numerical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colorscale_list,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+ elif isinstance(colormap, str):
+ if colormap in clrs.PLOTLY_SCALES.keys():
+ colorscale_list = clrs.PLOTLY_SCALES[colormap]
+ else:
+ raise exceptions.PlotlyError(
+ "If 'colormap' is a string, it must be the name "
+ "of a Plotly Colorscale. The available colorscale "
+ "names are {}".format(clrs.PLOTLY_SCALES.keys())
+ )
+ fig, annotations = _facet_grid_color_numerical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colorscale_list,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+ else:
+ colorscale_list = clrs.PLOTLY_SCALES["Reds"]
+ fig, annotations = _facet_grid_color_numerical(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ color_name,
+ colorscale_list,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+
+ else:
+ fig, annotations = _facet_grid(
+ df,
+ x,
+ y,
+ facet_row,
+ facet_col,
+ num_of_rows,
+ num_of_cols,
+ facet_row_labels,
+ facet_col_labels,
+ trace_type,
+ flipped_rows,
+ flipped_cols,
+ show_boxes,
+ SUBPLOT_SPACING,
+ marker_color,
+ kwargs_trace,
+ kwargs_marker,
+ )
+
+ if not height:
+ height = max(600, 100 * num_of_rows)
+ if not width:
+ width = max(600, 100 * num_of_cols)
+
+ fig["layout"].update(
+ height=height, width=width, title="", paper_bgcolor="rgb(251, 251, 251)"
+ )
+ if ggplot2:
+ fig["layout"].update(
+ plot_bgcolor=PLOT_BGCOLOR,
+ paper_bgcolor="rgb(255, 255, 255)",
+ hovermode="closest",
+ )
+
+ # axis titles
+ x_title_annot = _axis_title_annotation(x, "x")
+ y_title_annot = _axis_title_annotation(y, "y")
+
+ # annotations
+ annotations.append(x_title_annot)
+ annotations.append(y_title_annot)
+
+ # legend
+ fig["layout"]["showlegend"] = show_legend
+ fig["layout"]["legend"]["bgcolor"] = LEGEND_COLOR
+ fig["layout"]["legend"]["borderwidth"] = LEGEND_BORDER_WIDTH
+ fig["layout"]["legend"]["x"] = 1.05
+ fig["layout"]["legend"]["y"] = 1
+ fig["layout"]["legend"]["yanchor"] = "top"
+
+ if show_legend:
+ fig["layout"]["showlegend"] = show_legend
+ if ggplot2:
+ if color_name:
+ legend_annot = _legend_annotation(color_name)
+ annotations.append(legend_annot)
+ fig["layout"]["margin"]["r"] = 150
+
+ # assign annotations to figure
+ fig["layout"]["annotations"] = annotations
+
+ # add shaded boxes behind axis titles
+ if show_boxes and ggplot2:
+ _add_shapes_to_fig(fig, ANNOT_RECT_COLOR, flipped_rows, flipped_cols)
+
+ # all xaxis and yaxis labels
+ axis_labels = {"x": [], "y": []}
+ for key in fig["layout"]:
+ if "xaxis" in key:
+ axis_labels["x"].append(key)
+ elif "yaxis" in key:
+ axis_labels["y"].append(key)
+
+ string_number_in_data = False
+ for var in [v for v in [x, y] if v]:
+ if isinstance(df[var].tolist()[0], str):
+ for item in df[var]:
+ try:
+ int(item)
+ string_number_in_data = True
+ except ValueError:
+ pass
+
+ if string_number_in_data:
+ for x_y in axis_labels.keys():
+ for axis_name in axis_labels[x_y]:
+ fig["layout"][axis_name]["type"] = "category"
+
+ if scales == "fixed":
+ fixed_axes = ["x", "y"]
+ elif scales == "free_x":
+ fixed_axes = ["y"]
+ elif scales == "free_y":
+ fixed_axes = ["x"]
+ elif scales == "free":
+ fixed_axes = []
+
+ # fixed ranges
+ for x_y in fixed_axes:
+ min_ranges = []
+ max_ranges = []
+ for trace in fig["data"]:
+ if trace[x_y] is not None and len(trace[x_y]) > 0:
+ min_ranges.append(min(trace[x_y]))
+ max_ranges.append(max(trace[x_y]))
+ while None in min_ranges:
+ min_ranges.remove(None)
+ while None in max_ranges:
+ max_ranges.remove(None)
+
+ min_range = min(min_ranges)
+ max_range = max(max_ranges)
+
+ range_are_numbers = isinstance(min_range, Number) and isinstance(
+ max_range, Number
+ )
+
+ if range_are_numbers:
+ min_range = math.floor(min_range)
+ max_range = math.ceil(max_range)
+
+ # extend widen frame by 5% on each side
+ min_range -= 0.05 * (max_range - min_range)
+ max_range += 0.05 * (max_range - min_range)
+
+ if x_y == "x":
+ if dtick_x:
+ dtick = dtick_x
+ else:
+ dtick = math.floor((max_range - min_range) / MAX_TICKS_PER_AXIS)
+ elif x_y == "y":
+ if dtick_y:
+ dtick = dtick_y
+ else:
+ dtick = math.floor((max_range - min_range) / MAX_TICKS_PER_AXIS)
+ else:
+ dtick = 1
+
+ for axis_title in axis_labels[x_y]:
+ fig["layout"][axis_title]["dtick"] = dtick
+ fig["layout"][axis_title]["ticklen"] = 0
+ fig["layout"][axis_title]["zeroline"] = False
+ if ggplot2:
+ fig["layout"][axis_title]["tickwidth"] = 1
+ fig["layout"][axis_title]["ticklen"] = 4
+ fig["layout"][axis_title]["gridwidth"] = GRID_WIDTH
+
+ fig["layout"][axis_title]["gridcolor"] = GRID_COLOR
+ fig["layout"][axis_title]["gridwidth"] = 2
+ fig["layout"][axis_title]["tickfont"] = {
+ "color": TICK_COLOR,
+ "size": 10,
+ }
+
+ # insert ranges into fig
+ if x_y in fixed_axes:
+ for key in fig["layout"]:
+ if "{}axis".format(x_y) in key and range_are_numbers:
+ fig["layout"][key]["range"] = [min_range, max_range]
+
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py
new file mode 100644
index 0000000..2fe393f
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py
@@ -0,0 +1,1034 @@
+from numbers import Number
+
+import copy
+
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.figure_factory import utils
+import plotly.graph_objects as go
+
+pd = optional_imports.get_module("pandas")
+
+REQUIRED_GANTT_KEYS = ["Task", "Start", "Finish"]
+
+
+def _get_corner_points(x0, y0, x1, y1):
+ """
+ Returns the corner points of a scatter rectangle
+
+ :param x0: x-start
+ :param y0: y-lower
+ :param x1: x-end
+ :param y1: y-upper
+ :return: ([x], [y]), tuple of lists containing the x and y values
+ """
+
+ return ([x0, x1, x1, x0], [y0, y0, y1, y1])
+
+
+def validate_gantt(df):
+ """
+ Validates the inputted dataframe or list
+ """
+ if pd and isinstance(df, pd.core.frame.DataFrame):
+ # validate that df has all the required keys
+ for key in REQUIRED_GANTT_KEYS:
+ if key not in df:
+ raise exceptions.PlotlyError(
+ "The columns in your dataframe must include the "
+ "following keys: {0}".format(", ".join(REQUIRED_GANTT_KEYS))
+ )
+
+ num_of_rows = len(df.index)
+ chart = []
+ for index in range(num_of_rows):
+ task_dict = {}
+ for key in df:
+ task_dict[key] = df.iloc[index][key]
+ chart.append(task_dict)
+
+ return chart
+
+ # validate if df is a list
+ if not isinstance(df, list):
+ raise exceptions.PlotlyError(
+ "You must input either a dataframe or a list of dictionaries."
+ )
+
+ # validate if df is empty
+ if len(df) <= 0:
+ raise exceptions.PlotlyError(
+ "Your list is empty. It must contain at least one dictionary."
+ )
+ if not isinstance(df[0], dict):
+ raise exceptions.PlotlyError("Your list must only include dictionaries.")
+ return df
+
+
+def gantt(
+ chart,
+ colors,
+ title,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=False,
+ show_hover_fill=True,
+ show_colorbar=True,
+):
+ """
+ Refer to create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+
+ for index in range(len(chart)):
+ task = dict(
+ x0=chart[index]["Start"],
+ x1=chart[index]["Finish"],
+ name=chart[index]["Task"],
+ )
+ if "Description" in chart[index]:
+ task["description"] = chart[index]["Description"]
+ tasks.append(task)
+
+ # create a scatter trace for every task group
+ scatter_data_dict = dict()
+ marker_data_dict = dict()
+
+ if show_hover_fill:
+ hoverinfo = "name"
+ else:
+ hoverinfo = "skip"
+
+ scatter_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "none",
+ "fill": "toself",
+ "hoverinfo": hoverinfo,
+ }
+
+ marker_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "markers",
+ "text": [],
+ "marker": dict(color="", size=1, opacity=0),
+ "name": "",
+ "showlegend": False,
+ }
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ color_index = 0
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ del tasks[index]["name"]
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]["y0"] = groupID - bar_width
+ tasks[index]["y1"] = groupID + bar_width
+
+ # check if colors need to be looped
+ if color_index >= len(colors):
+ color_index = 0
+ tasks[index]["fillcolor"] = colors[color_index]
+ color_id = tasks[index]["fillcolor"]
+
+ if color_id not in scatter_data_dict:
+ scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template)
+
+ scatter_data_dict[color_id]["fillcolor"] = color_id
+ scatter_data_dict[color_id]["name"] = str(tn)
+ scatter_data_dict[color_id]["legendgroup"] = color_id
+
+ # if there are already values append the gap
+ if len(scatter_data_dict[color_id]["x"]) > 0:
+ # a gap on the scatterplot separates the rectangles from each other
+ scatter_data_dict[color_id]["x"].append(
+ scatter_data_dict[color_id]["x"][-1]
+ )
+ scatter_data_dict[color_id]["y"].append(None)
+
+ xs, ys = _get_corner_points(
+ tasks[index]["x0"],
+ tasks[index]["y0"],
+ tasks[index]["x1"],
+ tasks[index]["y1"],
+ )
+
+ scatter_data_dict[color_id]["x"] += xs
+ scatter_data_dict[color_id]["y"] += ys
+
+ # append dummy markers for showing start and end of interval
+ if color_id not in marker_data_dict:
+ marker_data_dict[color_id] = copy.deepcopy(marker_data_template)
+ marker_data_dict[color_id]["marker"]["color"] = color_id
+ marker_data_dict[color_id]["legendgroup"] = color_id
+
+ marker_data_dict[color_id]["x"].append(tasks[index]["x0"])
+ marker_data_dict[color_id]["x"].append(tasks[index]["x1"])
+ marker_data_dict[color_id]["y"].append(groupID)
+ marker_data_dict[color_id]["y"].append(groupID)
+
+ if "description" in tasks[index]:
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ del tasks[index]["description"]
+ else:
+ marker_data_dict[color_id]["text"].append(None)
+ marker_data_dict[color_id]["text"].append(None)
+
+ color_index += 1
+
+ showlegend = show_colorbar
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode="closest",
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list(
+ [
+ dict(count=7, label="1w", step="day", stepmode="backward"),
+ dict(count=1, label="1m", step="month", stepmode="backward"),
+ dict(count=6, label="6m", step="month", stepmode="backward"),
+ dict(count=1, label="YTD", step="year", stepmode="todate"),
+ dict(count=1, label="1y", step="year", stepmode="backward"),
+ dict(step="all"),
+ ]
+ )
+ ),
+ type="date",
+ ),
+ )
+
+ data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)]
+ data += [marker_data_dict[k] for k in sorted(marker_data_dict)]
+
+ # fig = dict(
+ # data=data, layout=layout
+ # )
+ fig = go.Figure(data=data, layout=layout)
+ return fig
+
+
+def gantt_colorscale(
+ chart,
+ colors,
+ title,
+ index_col,
+ show_colorbar,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=False,
+ show_hover_fill=True,
+):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(
+ x0=chart[index]["Start"],
+ x1=chart[index]["Finish"],
+ name=chart[index]["Task"],
+ )
+ if "Description" in chart[index]:
+ task["description"] = chart[index]["Description"]
+ tasks.append(task)
+
+ # create a scatter trace for every task group
+ scatter_data_dict = dict()
+ # create scatter traces for the start- and endpoints
+ marker_data_dict = dict()
+
+ if show_hover_fill:
+ hoverinfo = "name"
+ else:
+ hoverinfo = "skip"
+
+ scatter_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "none",
+ "fill": "toself",
+ "showlegend": False,
+ "hoverinfo": hoverinfo,
+ "legendgroup": "",
+ }
+
+ marker_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "markers",
+ "text": [],
+ "marker": dict(color="", size=1, opacity=0),
+ "name": "",
+ "showlegend": False,
+ "legendgroup": "",
+ }
+
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ # compute the color for task based on indexing column
+ if isinstance(chart[0][index_col], Number):
+ # check that colors has at least 2 colors
+ if len(colors) < 2:
+ raise exceptions.PlotlyError(
+ "You must use at least 2 colors in 'colors' if you "
+ "are using a colorscale. However only the first two "
+ "colors given will be used for the lower and upper "
+ "bounds on the colormap."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ del tasks[index]["name"]
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]["y0"] = groupID - bar_width
+ tasks[index]["y1"] = groupID + bar_width
+
+ # unlabel color
+ colors = clrs.color_parser(colors, clrs.unlabel_rgb)
+ lowcolor = colors[0]
+ highcolor = colors[1]
+
+ intermed = (chart[index][index_col]) / 100.0
+ intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)
+ intermed_color = clrs.color_parser(intermed_color, clrs.label_rgb)
+ tasks[index]["fillcolor"] = intermed_color
+ color_id = tasks[index]["fillcolor"]
+
+ if color_id not in scatter_data_dict:
+ scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template)
+
+ scatter_data_dict[color_id]["fillcolor"] = color_id
+ scatter_data_dict[color_id]["name"] = str(chart[index][index_col])
+ scatter_data_dict[color_id]["legendgroup"] = color_id
+
+ # relabel colors with 'rgb'
+ colors = clrs.color_parser(colors, clrs.label_rgb)
+
+ # if there are already values append the gap
+ if len(scatter_data_dict[color_id]["x"]) > 0:
+ # a gap on the scatterplot separates the rectangles from each other
+ scatter_data_dict[color_id]["x"].append(
+ scatter_data_dict[color_id]["x"][-1]
+ )
+ scatter_data_dict[color_id]["y"].append(None)
+
+ xs, ys = _get_corner_points(
+ tasks[index]["x0"],
+ tasks[index]["y0"],
+ tasks[index]["x1"],
+ tasks[index]["y1"],
+ )
+
+ scatter_data_dict[color_id]["x"] += xs
+ scatter_data_dict[color_id]["y"] += ys
+
+ # append dummy markers for showing start and end of interval
+ if color_id not in marker_data_dict:
+ marker_data_dict[color_id] = copy.deepcopy(marker_data_template)
+ marker_data_dict[color_id]["marker"]["color"] = color_id
+ marker_data_dict[color_id]["legendgroup"] = color_id
+
+ marker_data_dict[color_id]["x"].append(tasks[index]["x0"])
+ marker_data_dict[color_id]["x"].append(tasks[index]["x1"])
+ marker_data_dict[color_id]["y"].append(groupID)
+ marker_data_dict[color_id]["y"].append(groupID)
+
+ if "description" in tasks[index]:
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ del tasks[index]["description"]
+ else:
+ marker_data_dict[color_id]["text"].append(None)
+ marker_data_dict[color_id]["text"].append(None)
+
+ # add colorbar to one of the traces randomly just for display
+ if show_colorbar is True:
+ k = list(marker_data_dict.keys())[0]
+ marker_data_dict[k]["marker"].update(
+ dict(
+ colorscale=[[0, colors[0]], [1, colors[1]]],
+ showscale=True,
+ cmax=100,
+ cmin=0,
+ )
+ )
+
+ if isinstance(chart[0][index_col], str):
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ if len(colors) < len(index_vals):
+ raise exceptions.PlotlyError(
+ "Error. The number of colors in 'colors' must be no less "
+ "than the number of unique index values in your group "
+ "column."
+ )
+
+ # make a dictionary assignment to each index value
+ index_vals_dict = {}
+ # define color index
+ c_index = 0
+ for key in index_vals:
+ if c_index > len(colors) - 1:
+ c_index = 0
+ index_vals_dict[key] = colors[c_index]
+ c_index += 1
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ del tasks[index]["name"]
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]["y0"] = groupID - bar_width
+ tasks[index]["y1"] = groupID + bar_width
+
+ tasks[index]["fillcolor"] = index_vals_dict[chart[index][index_col]]
+ color_id = tasks[index]["fillcolor"]
+
+ if color_id not in scatter_data_dict:
+ scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template)
+
+ scatter_data_dict[color_id]["fillcolor"] = color_id
+ scatter_data_dict[color_id]["legendgroup"] = color_id
+ scatter_data_dict[color_id]["name"] = str(chart[index][index_col])
+
+ # relabel colors with 'rgb'
+ colors = clrs.color_parser(colors, clrs.label_rgb)
+
+ # if there are already values append the gap
+ if len(scatter_data_dict[color_id]["x"]) > 0:
+ # a gap on the scatterplot separates the rectangles from each other
+ scatter_data_dict[color_id]["x"].append(
+ scatter_data_dict[color_id]["x"][-1]
+ )
+ scatter_data_dict[color_id]["y"].append(None)
+
+ xs, ys = _get_corner_points(
+ tasks[index]["x0"],
+ tasks[index]["y0"],
+ tasks[index]["x1"],
+ tasks[index]["y1"],
+ )
+
+ scatter_data_dict[color_id]["x"] += xs
+ scatter_data_dict[color_id]["y"] += ys
+
+ # append dummy markers for showing start and end of interval
+ if color_id not in marker_data_dict:
+ marker_data_dict[color_id] = copy.deepcopy(marker_data_template)
+ marker_data_dict[color_id]["marker"]["color"] = color_id
+ marker_data_dict[color_id]["legendgroup"] = color_id
+
+ marker_data_dict[color_id]["x"].append(tasks[index]["x0"])
+ marker_data_dict[color_id]["x"].append(tasks[index]["x1"])
+ marker_data_dict[color_id]["y"].append(groupID)
+ marker_data_dict[color_id]["y"].append(groupID)
+
+ if "description" in tasks[index]:
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ del tasks[index]["description"]
+ else:
+ marker_data_dict[color_id]["text"].append(None)
+ marker_data_dict[color_id]["text"].append(None)
+
+ if show_colorbar is True:
+ showlegend = True
+ for k in scatter_data_dict:
+ scatter_data_dict[k]["showlegend"] = showlegend
+ # add colorbar to one of the traces randomly just for display
+ # if show_colorbar is True:
+ # k = list(marker_data_dict.keys())[0]
+ # marker_data_dict[k]["marker"].update(
+ # dict(
+ # colorscale=[[0, colors[0]], [1, colors[1]]],
+ # showscale=True,
+ # cmax=100,
+ # cmin=0,
+ # )
+ # )
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode="closest",
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list(
+ [
+ dict(count=7, label="1w", step="day", stepmode="backward"),
+ dict(count=1, label="1m", step="month", stepmode="backward"),
+ dict(count=6, label="6m", step="month", stepmode="backward"),
+ dict(count=1, label="YTD", step="year", stepmode="todate"),
+ dict(count=1, label="1y", step="year", stepmode="backward"),
+ dict(step="all"),
+ ]
+ )
+ ),
+ type="date",
+ ),
+ )
+
+ data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)]
+ data += [marker_data_dict[k] for k in sorted(marker_data_dict)]
+
+ # fig = dict(
+ # data=data, layout=layout
+ # )
+ fig = go.Figure(data=data, layout=layout)
+ return fig
+
+
+def gantt_dict(
+ chart,
+ colors,
+ title,
+ index_col,
+ show_colorbar,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=False,
+ show_hover_fill=True,
+):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(
+ x0=chart[index]["Start"],
+ x1=chart[index]["Finish"],
+ name=chart[index]["Task"],
+ )
+ if "Description" in chart[index]:
+ task["description"] = chart[index]["Description"]
+ tasks.append(task)
+
+ # create a scatter trace for every task group
+ scatter_data_dict = dict()
+ # create scatter traces for the start- and endpoints
+ marker_data_dict = dict()
+
+ if show_hover_fill:
+ hoverinfo = "name"
+ else:
+ hoverinfo = "skip"
+
+ scatter_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "none",
+ "fill": "toself",
+ "hoverinfo": hoverinfo,
+ "legendgroup": "",
+ }
+
+ marker_data_template = {
+ "x": [],
+ "y": [],
+ "mode": "markers",
+ "text": [],
+ "marker": dict(color="", size=1, opacity=0),
+ "name": "",
+ "showlegend": False,
+ }
+
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ # verify each value in index column appears in colors dictionary
+ for key in index_vals:
+ if key not in colors:
+ raise exceptions.PlotlyError(
+ "If you are using colors as a dictionary, all of its "
+ "keys must be all the values in the index column."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]["name"]
+ del tasks[index]["name"]
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]["y0"] = groupID - bar_width
+ tasks[index]["y1"] = groupID + bar_width
+
+ tasks[index]["fillcolor"] = colors[chart[index][index_col]]
+ color_id = tasks[index]["fillcolor"]
+
+ if color_id not in scatter_data_dict:
+ scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template)
+
+ scatter_data_dict[color_id]["legendgroup"] = color_id
+ scatter_data_dict[color_id]["fillcolor"] = color_id
+
+ # if there are already values append the gap
+ if len(scatter_data_dict[color_id]["x"]) > 0:
+ # a gap on the scatterplot separates the rectangles from each other
+ scatter_data_dict[color_id]["x"].append(
+ scatter_data_dict[color_id]["x"][-1]
+ )
+ scatter_data_dict[color_id]["y"].append(None)
+
+ xs, ys = _get_corner_points(
+ tasks[index]["x0"],
+ tasks[index]["y0"],
+ tasks[index]["x1"],
+ tasks[index]["y1"],
+ )
+
+ scatter_data_dict[color_id]["x"] += xs
+ scatter_data_dict[color_id]["y"] += ys
+
+ # append dummy markers for showing start and end of interval
+ if color_id not in marker_data_dict:
+ marker_data_dict[color_id] = copy.deepcopy(marker_data_template)
+ marker_data_dict[color_id]["marker"]["color"] = color_id
+ marker_data_dict[color_id]["legendgroup"] = color_id
+
+ marker_data_dict[color_id]["x"].append(tasks[index]["x0"])
+ marker_data_dict[color_id]["x"].append(tasks[index]["x1"])
+ marker_data_dict[color_id]["y"].append(groupID)
+ marker_data_dict[color_id]["y"].append(groupID)
+
+ if "description" in tasks[index]:
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ marker_data_dict[color_id]["text"].append(tasks[index]["description"])
+ del tasks[index]["description"]
+ else:
+ marker_data_dict[color_id]["text"].append(None)
+ marker_data_dict[color_id]["text"].append(None)
+
+ if show_colorbar is True:
+ showlegend = True
+
+ for index_value in index_vals:
+ scatter_data_dict[colors[index_value]]["name"] = str(index_value)
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode="closest",
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list(
+ [
+ dict(count=7, label="1w", step="day", stepmode="backward"),
+ dict(count=1, label="1m", step="month", stepmode="backward"),
+ dict(count=6, label="6m", step="month", stepmode="backward"),
+ dict(count=1, label="YTD", step="year", stepmode="todate"),
+ dict(count=1, label="1y", step="year", stepmode="backward"),
+ dict(step="all"),
+ ]
+ )
+ ),
+ type="date",
+ ),
+ )
+
+ data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)]
+ data += [marker_data_dict[k] for k in sorted(marker_data_dict)]
+
+ # fig = dict(
+ # data=data, layout=layout
+ # )
+ fig = go.Figure(data=data, layout=layout)
+ return fig
+
+
+def create_gantt(
+ df,
+ colors=None,
+ index_col=None,
+ show_colorbar=False,
+ reverse_colors=False,
+ title="Gantt Chart",
+ bar_width=0.2,
+ showgrid_x=False,
+ showgrid_y=False,
+ height=600,
+ width=None,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=False,
+ show_hover_fill=True,
+):
+ """
+ **deprecated**, use instead
+ :func:`plotly.express.timeline`.
+
+ Returns figure for a gantt chart
+
+ :param (array|list) df: input data for gantt chart. Must be either a
+ a dataframe or a list. If dataframe, the columns must include
+ 'Task', 'Start' and 'Finish'. Other columns can be included and
+ used for indexing. If a list, its elements must be dictionaries
+ with the same required column headers: 'Task', 'Start' and
+ 'Finish'.
+ :param (str|list|dict|tuple) colors: either a plotly scale name, an
+ rgb or hex color, a color tuple or a list of colors. An rgb color
+ is of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colors is a list, it must
+ contain the valid color types aforementioned as its members.
+ If a dictionary, all values of the indexing column must be keys in
+ colors.
+ :param (str|float) index_col: the column header (if df is a data
+ frame) that will function as the indexing column. If df is a list,
+ index_col must be one of the keys in all the items of df.
+ :param (bool) show_colorbar: determines if colorbar will be visible.
+ Only applies if values in the index column are numeric.
+ :param (bool) show_hover_fill: enables/disables the hovertext for the
+ filled area of the chart.
+ :param (bool) reverse_colors: reverses the order of selected colors
+ :param (str) title: the title of the chart
+ :param (float) bar_width: the width of the horizontal bars in the plot
+ :param (bool) showgrid_x: show/hide the x-axis grid
+ :param (bool) showgrid_y: show/hide the y-axis grid
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Example 1: Simple Gantt Chart
+
+ >>> from plotly.figure_factory import create_gantt
+
+ >>> # Make data for chart
+ >>> df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'),
+ ... dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'),
+ ... dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')]
+
+ >>> # Create a figure
+ >>> fig = create_gantt(df)
+ >>> fig.show()
+
+
+ Example 2: Index by Column with Numerical Entries
+
+ >>> from plotly.figure_factory import create_gantt
+
+ >>> # Make data for chart
+ >>> df = [dict(Task="Job A", Start='2009-01-01',
+ ... Finish='2009-02-30', Complete=10),
+ ... dict(Task="Job B", Start='2009-03-05',
+ ... Finish='2009-04-15', Complete=60),
+ ... dict(Task="Job C", Start='2009-02-20',
+ ... Finish='2009-05-30', Complete=95)]
+
+ >>> # Create a figure with Plotly colorscale
+ >>> fig = create_gantt(df, colors='Blues', index_col='Complete',
+ ... show_colorbar=True, bar_width=0.5,
+ ... showgrid_x=True, showgrid_y=True)
+ >>> fig.show()
+
+
+ Example 3: Index by Column with String Entries
+
+ >>> from plotly.figure_factory import create_gantt
+
+ >>> # Make data for chart
+ >>> df = [dict(Task="Job A", Start='2009-01-01',
+ ... Finish='2009-02-30', Resource='Apple'),
+ ... dict(Task="Job B", Start='2009-03-05',
+ ... Finish='2009-04-15', Resource='Grape'),
+ ... dict(Task="Job C", Start='2009-02-20',
+ ... Finish='2009-05-30', Resource='Banana')]
+
+ >>> # Create a figure with Plotly colorscale
+ >>> fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'],
+ ... index_col='Resource', reverse_colors=True,
+ ... show_colorbar=True)
+ >>> fig.show()
+
+
+ Example 4: Use a dictionary for colors
+
+ >>> from plotly.figure_factory import create_gantt
+ >>> # Make data for chart
+ >>> df = [dict(Task="Job A", Start='2009-01-01',
+ ... Finish='2009-02-30', Resource='Apple'),
+ ... dict(Task="Job B", Start='2009-03-05',
+ ... Finish='2009-04-15', Resource='Grape'),
+ ... dict(Task="Job C", Start='2009-02-20',
+ ... Finish='2009-05-30', Resource='Banana')]
+
+ >>> # Make a dictionary of colors
+ >>> colors = {'Apple': 'rgb(255, 0, 0)',
+ ... 'Grape': 'rgb(170, 14, 200)',
+ ... 'Banana': (1, 1, 0.2)}
+
+ >>> # Create a figure with Plotly colorscale
+ >>> fig = create_gantt(df, colors=colors, index_col='Resource',
+ ... show_colorbar=True)
+
+ >>> fig.show()
+
+ Example 5: Use a pandas dataframe
+
+ >>> from plotly.figure_factory import create_gantt
+ >>> import pandas as pd
+
+ >>> # Make data as a dataframe
+ >>> df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10],
+ ... ['Fast', '2011-01-01', '2012-06-05', 55],
+ ... ['Eat', '2012-01-05', '2013-07-05', 94]],
+ ... columns=['Task', 'Start', 'Finish', 'Complete'])
+
+ >>> # Create a figure with Plotly colorscale
+ >>> fig = create_gantt(df, colors='Blues', index_col='Complete',
+ ... show_colorbar=True, bar_width=0.5,
+ ... showgrid_x=True, showgrid_y=True)
+ >>> fig.show()
+ """
+ # validate gantt input data
+ chart = validate_gantt(df)
+
+ if index_col:
+ if index_col not in chart[0]:
+ raise exceptions.PlotlyError(
+ "In order to use an indexing column and assign colors to "
+ "the values of the index, you must choose an actual "
+ "column name in the dataframe or key if a list of "
+ "dictionaries is being used."
+ )
+
+ # validate gantt index column
+ index_list = []
+ for dictionary in chart:
+ index_list.append(dictionary[index_col])
+ utils.validate_index(index_list)
+
+ # Validate colors
+ if isinstance(colors, dict):
+ colors = clrs.validate_colors_dict(colors, "rgb")
+ else:
+ colors = clrs.validate_colors(colors, "rgb")
+
+ if reverse_colors is True:
+ colors.reverse()
+
+ if not index_col:
+ if isinstance(colors, dict):
+ raise exceptions.PlotlyError(
+ "Error. You have set colors to a dictionary but have not "
+ "picked an index. An index is required if you are "
+ "assigning colors to particular values in a dictionary."
+ )
+ fig = gantt(
+ chart,
+ colors,
+ title,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=group_tasks,
+ show_hover_fill=show_hover_fill,
+ show_colorbar=show_colorbar,
+ )
+ return fig
+ else:
+ if not isinstance(colors, dict):
+ fig = gantt_colorscale(
+ chart,
+ colors,
+ title,
+ index_col,
+ show_colorbar,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=group_tasks,
+ show_hover_fill=show_hover_fill,
+ )
+ return fig
+ else:
+ fig = gantt_dict(
+ chart,
+ colors,
+ title,
+ index_col,
+ show_colorbar,
+ bar_width,
+ showgrid_x,
+ showgrid_y,
+ height,
+ width,
+ tasks=None,
+ task_names=None,
+ data=None,
+ group_tasks=group_tasks,
+ show_hover_fill=show_hover_fill,
+ )
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py
new file mode 100644
index 0000000..c763522
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py
@@ -0,0 +1,526 @@
+from plotly.express._core import build_dataframe
+from plotly.express._doc import make_docstring
+from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
+import narwhals.stable.v1 as nw
+import numpy as np
+
+
+def _project_latlon_to_wgs84(lat, lon):
+ """
+ Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
+ """
+ x = lon * np.pi / 180
+ y = np.arctanh(np.sin(lat * np.pi / 180))
+ return x, y
+
+
+def _project_wgs84_to_latlon(x, y):
+ """
+ Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
+ """
+ lon = x * 180 / np.pi
+ lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
+ return lat, lon
+
+
+def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
+ """
+ Get the mapbox zoom level given bounds and a figure dimension
+ Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds
+ """
+
+ scale = (
+ 2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
+ )
+ WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
+ ZOOM_MAX = 18
+
+ def latRad(lat):
+ sin = np.sin(lat * np.pi / 180)
+ radX2 = np.log((1 + sin) / (1 - sin)) / 2
+ return max(min(radX2, np.pi), -np.pi) / 2
+
+ def zoom(mapPx, worldPx, fraction):
+ return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)
+
+ latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi
+
+ lngDiff = lon_max - lon_min
+ lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360
+
+ latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
+ lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)
+
+ return min(latZoom, lngZoom, ZOOM_MAX)
+
+
+def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
+ """
+ Computes the aggregation at hexagonal bin level.
+ Also defines the coordinates of the hexagons for plotting.
+ The binning is inspired by matplotlib's implementation.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Array of x values (shape N)
+ y : np.ndarray
+ Array of y values (shape N)
+ x_range : np.ndarray
+ Min and max x (shape 2)
+ y_range : np.ndarray
+ Min and max y (shape 2)
+ color : np.ndarray
+ Metric to aggregate at hexagon level (shape N)
+ nx : int
+ Number of hexagons horizontally
+ agg_func : function
+ Numpy compatible aggregator, this function must take a one-dimensional
+ np.ndarray as input and output a scalar
+ min_count : int
+ Minimum number of points in the hexagon for the hexagon to be displayed
+
+ Returns
+ -------
+ np.ndarray
+ X coordinates of each hexagon (shape M x 6)
+ np.ndarray
+ Y coordinates of each hexagon (shape M x 6)
+ np.ndarray
+ Centers of the hexagons (shape M x 2)
+ np.ndarray
+ Aggregated value in each hexagon (shape M)
+
+ """
+ xmin = x_range.min()
+ xmax = x_range.max()
+ ymin = y_range.min()
+ ymax = y_range.max()
+
+ # In the x-direction, the hexagons exactly cover the region from
+ # xmin to xmax. Need some padding to avoid roundoff errors.
+ padding = 1.0e-9 * (xmax - xmin)
+ xmin -= padding
+ xmax += padding
+
+ Dx = xmax - xmin
+ Dy = ymax - ymin
+ if Dx == 0 and Dy > 0:
+ dx = Dy / nx
+ elif Dx == 0 and Dy == 0:
+ dx, _ = _project_latlon_to_wgs84(1, 1)
+ else:
+ dx = Dx / nx
+ dy = dx * np.sqrt(3)
+ ny = np.ceil(Dy / dy).astype(int)
+
+ # Center the hexagons vertically since we only want regular hexagons
+ ymin -= (ymin + dy * ny - ymax) / 2
+
+ x = (x - xmin) / dx
+ y = (y - ymin) / dy
+ ix1 = np.round(x).astype(int)
+ iy1 = np.round(y).astype(int)
+ ix2 = np.floor(x).astype(int)
+ iy2 = np.floor(y).astype(int)
+
+ nx1 = nx + 1
+ ny1 = ny + 1
+ nx2 = nx
+ ny2 = ny
+ n = nx1 * ny1 + nx2 * ny2
+
+ d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
+ d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
+ bdist = d1 < d2
+
+ if color is None:
+ lattice1 = np.zeros((nx1, ny1))
+ lattice2 = np.zeros((nx2, ny2))
+ c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
+ c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
+ np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
+ np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
+ if min_count is not None:
+ lattice1[lattice1 < min_count] = np.nan
+ lattice2[lattice2 < min_count] = np.nan
+ accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
+ good_idxs = ~np.isnan(accum)
+ else:
+ if min_count is None:
+ min_count = 1
+
+ # create accumulation arrays
+ lattice1 = np.empty((nx1, ny1), dtype=object)
+ for i in range(nx1):
+ for j in range(ny1):
+ lattice1[i, j] = []
+ lattice2 = np.empty((nx2, ny2), dtype=object)
+ for i in range(nx2):
+ for j in range(ny2):
+ lattice2[i, j] = []
+
+ for i in range(len(x)):
+ if bdist[i]:
+ if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
+ lattice1[ix1[i], iy1[i]].append(color[i])
+ else:
+ if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
+ lattice2[ix2[i], iy2[i]].append(color[i])
+
+ for i in range(nx1):
+ for j in range(ny1):
+ vals = lattice1[i, j]
+ if len(vals) >= min_count:
+ lattice1[i, j] = agg_func(vals)
+ else:
+ lattice1[i, j] = np.nan
+ for i in range(nx2):
+ for j in range(ny2):
+ vals = lattice2[i, j]
+ if len(vals) >= min_count:
+ lattice2[i, j] = agg_func(vals)
+ else:
+ lattice2[i, j] = np.nan
+
+ accum = np.hstack(
+ (lattice1.astype(float).ravel(), lattice2.astype(float).ravel())
+ )
+ good_idxs = ~np.isnan(accum)
+
+ agreggated_value = accum[good_idxs]
+
+ centers = np.zeros((n, 2), float)
+ centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
+ centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
+ centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
+ centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
+ centers[:, 0] *= dx
+ centers[:, 1] *= dy
+ centers[:, 0] += xmin
+ centers[:, 1] += ymin
+ centers = centers[good_idxs]
+
+ # Define normalised regular hexagon coordinates
+ hx = [0, 0.5, 0.5, 0, -0.5, -0.5]
+ hy = [
+ -0.5 / np.cos(np.pi / 6),
+ -0.5 * np.tan(np.pi / 6),
+ 0.5 * np.tan(np.pi / 6),
+ 0.5 / np.cos(np.pi / 6),
+ 0.5 * np.tan(np.pi / 6),
+ -0.5 * np.tan(np.pi / 6),
+ ]
+
+ # Number of hexagons needed
+ m = len(centers)
+
+ # Coordinates for all hexagonal patches
+ hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0])
+ hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1])
+
+ return hxs, hys, centers, agreggated_value
+
+
+def _compute_wgs84_hexbin(
+ lat=None,
+ lon=None,
+ lat_range=None,
+ lon_range=None,
+ color=None,
+ nx=None,
+ agg_func=None,
+ min_count=None,
+ native_namespace=None,
+):
+ """
+ Computes the lat-lon aggregation at hexagonal bin level.
+ Latitude and longitude need to be projected to WGS84 before aggregating
+ in order to display regular hexagons on the map.
+
+ Parameters
+ ----------
+ lat : np.ndarray
+ Array of latitudes (shape N)
+ lon : np.ndarray
+ Array of longitudes (shape N)
+ lat_range : np.ndarray
+ Min and max latitudes (shape 2)
+ lon_range : np.ndarray
+ Min and max longitudes (shape 2)
+ color : np.ndarray
+ Metric to aggregate at hexagon level (shape N)
+ nx : int
+ Number of hexagons horizontally
+ agg_func : function
+ Numpy compatible aggregator, this function must take a one-dimensional
+ np.ndarray as input and output a scalar
+ min_count : int
+ Minimum number of points in the hexagon for the hexagon to be displayed
+
+ Returns
+ -------
+ np.ndarray
+ Lat coordinates of each hexagon (shape M x 6)
+ np.ndarray
+ Lon coordinates of each hexagon (shape M x 6)
+ nw.Series
+ Unique id for each hexagon, to be used in the geojson data (shape M)
+ np.ndarray
+ Aggregated value in each hexagon (shape M)
+
+ """
+ # Project to WGS 84
+ x, y = _project_latlon_to_wgs84(lat, lon)
+
+ if lat_range is None:
+ lat_range = np.array([lat.min(), lat.max()])
+ if lon_range is None:
+ lon_range = np.array([lon.min(), lon.max()])
+
+ x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
+
+ hxs, hys, centers, agreggated_value = _compute_hexbin(
+ x, y, x_range, y_range, color, nx, agg_func, min_count
+ )
+
+ # Convert back to lat-lon
+ hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)
+
+ # Create unique feature id based on hexagon center
+ centers = centers.astype(str)
+ hexagons_ids = (
+ nw.from_dict(
+ {"x1": centers[:, 0], "x2": centers[:, 1]},
+ native_namespace=native_namespace,
+ )
+ .select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=","))
+ .get_column("hexagons_ids")
+ )
+
+ return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value
+
+
+def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
+ """
+ Creates a geojson of hexagonal features based on the outputs of
+ _compute_wgs84_hexbin
+ """
+ features = []
+ if ids is None:
+ ids = np.arange(len(hexagons_lats))
+ for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
+ points = np.array([lon, lat]).T.tolist()
+ points.append(points[0])
+ features.append(
+ dict(
+ type="Feature",
+ id=idx,
+ geometry=dict(type="Polygon", coordinates=[points]),
+ )
+ )
+ return dict(type="FeatureCollection", features=features)
+
+
+def create_hexbin_mapbox(
+ data_frame=None,
+ lat=None,
+ lon=None,
+ color=None,
+ nx_hexagon=5,
+ agg_func=None,
+ animation_frame=None,
+ color_discrete_sequence=None,
+ color_discrete_map={},
+ labels={},
+ color_continuous_scale=None,
+ range_color=None,
+ color_continuous_midpoint=None,
+ opacity=None,
+ zoom=None,
+ center=None,
+ mapbox_style=None,
+ title=None,
+ template=None,
+ width=None,
+ height=None,
+ min_count=None,
+ show_original_data=False,
+ original_data_marker=None,
+):
+ """
+ Returns a figure aggregating scattered points into connected hexagons
+ """
+ args = build_dataframe(args=locals(), constructor=None)
+ native_namespace = nw.get_native_namespace(args["data_frame"])
+ if agg_func is None:
+ agg_func = np.mean
+
+ lat_range = (
+ args["data_frame"]
+ .select(
+ nw.min(args["lat"]).name.suffix("_min"),
+ nw.max(args["lat"]).name.suffix("_max"),
+ )
+ .to_numpy()
+ .squeeze()
+ )
+
+ lon_range = (
+ args["data_frame"]
+ .select(
+ nw.min(args["lon"]).name.suffix("_min"),
+ nw.max(args["lon"]).name.suffix("_max"),
+ )
+ .to_numpy()
+ .squeeze()
+ )
+
+ hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
+ lat=args["data_frame"].get_column(args["lat"]).to_numpy(),
+ lon=args["data_frame"].get_column(args["lon"]).to_numpy(),
+ lat_range=lat_range,
+ lon_range=lon_range,
+ color=None,
+ nx=nx_hexagon,
+ agg_func=agg_func,
+ min_count=min_count,
+ native_namespace=native_namespace,
+ )
+
+ geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)
+
+ if zoom is None:
+ if height is None and width is None:
+ mapDim = dict(height=450, width=450)
+ elif height is None and width is not None:
+ mapDim = dict(height=450, width=width)
+ elif height is not None and width is None:
+ mapDim = dict(height=height, width=height)
+ else:
+ mapDim = dict(height=height, width=width)
+ zoom = _getBoundsZoomLevel(
+ lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim
+ )
+
+ if center is None:
+ center = dict(lat=lat_range.mean(), lon=lon_range.mean())
+
+ if args["animation_frame"] is not None:
+ groups = dict(
+ args["data_frame"]
+ .group_by(args["animation_frame"], drop_null_keys=True)
+ .__iter__()
+ )
+ else:
+ groups = {(0,): args["data_frame"]}
+
+ agg_data_frame_list = []
+ for key, df in groups.items():
+ _, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
+ lat=df.get_column(args["lat"]).to_numpy(),
+ lon=df.get_column(args["lon"]).to_numpy(),
+ lat_range=lat_range,
+ lon_range=lon_range,
+ color=df.get_column(args["color"]).to_numpy() if args["color"] else None,
+ nx=nx_hexagon,
+ agg_func=agg_func,
+ min_count=min_count,
+ native_namespace=native_namespace,
+ )
+ agg_data_frame_list.append(
+ nw.from_dict(
+ {
+ "frame": [key[0]] * len(hexagons_ids),
+ "locations": hexagons_ids,
+ "color": aggregated_value,
+ },
+ native_namespace=native_namespace,
+ )
+ )
+
+ agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns(
+ color=nw.col("color").cast(nw.Int64)
+ )
+
+ if range_color is None:
+ range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]
+
+ fig = choropleth_mapbox(
+ data_frame=agg_data_frame.to_native(),
+ geojson=geojson,
+ locations="locations",
+ color="color",
+ hover_data={"color": True, "locations": False, "frame": False},
+ animation_frame=("frame" if args["animation_frame"] is not None else None),
+ color_discrete_sequence=color_discrete_sequence,
+ color_discrete_map=color_discrete_map,
+ labels=labels,
+ color_continuous_scale=color_continuous_scale,
+ range_color=range_color,
+ color_continuous_midpoint=color_continuous_midpoint,
+ opacity=opacity,
+ zoom=zoom,
+ center=center,
+ mapbox_style=mapbox_style,
+ title=title,
+ template=template,
+ width=width,
+ height=height,
+ )
+
+ if show_original_data:
+ original_fig = scatter_mapbox(
+ data_frame=(
+ args["data_frame"].sort(
+ by=args["animation_frame"], descending=False, nulls_last=True
+ )
+ if args["animation_frame"] is not None
+ else args["data_frame"]
+ ).to_native(),
+ lat=args["lat"],
+ lon=args["lon"],
+ animation_frame=args["animation_frame"],
+ )
+ original_fig.data[0].hoverinfo = "skip"
+ original_fig.data[0].hovertemplate = None
+ original_fig.data[0].marker = original_data_marker
+
+ fig.add_trace(original_fig.data[0])
+
+ if args["animation_frame"] is not None:
+ for i in range(len(original_fig.frames)):
+ original_fig.frames[i].data[0].hoverinfo = "skip"
+ original_fig.frames[i].data[0].hovertemplate = None
+ original_fig.frames[i].data[0].marker = original_data_marker
+
+ fig.frames[i].data = [
+ fig.frames[i].data[0],
+ original_fig.frames[i].data[0],
+ ]
+
+ return fig
+
+
+create_hexbin_mapbox.__doc__ = make_docstring(
+ create_hexbin_mapbox,
+ override_dict=dict(
+ nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
+ agg_func=[
+ "function",
+ "Numpy array aggregator, it must take as input a 1D array",
+ "and output a scalar value.",
+ ],
+ min_count=[
+ "int",
+ "Minimum number of points in a hexagon for it to be displayed.",
+ "If None and color is not set, display all hexagons.",
+ "If None and color is set, only display hexagons that contain points.",
+ ],
+ show_original_data=[
+ "bool",
+ "Whether to show the original data on top of the hexbin aggregation.",
+ ],
+ original_data_marker=["dict", "Scattermapbox marker options."],
+ ),
+)
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py
new file mode 100644
index 0000000..a0b1b48
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py
@@ -0,0 +1,295 @@
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+# Default colours for finance charts
+_DEFAULT_INCREASING_COLOR = "#3D9970" # http://clrs.cc
+_DEFAULT_DECREASING_COLOR = "#FF4136"
+
+
+def validate_ohlc(open, high, low, close, direction, **kwargs):
+ """
+ ohlc and candlestick specific validations
+
+ Specifically, this checks that the high value is the greatest value and
+ the low value is the lowest value in each unit.
+
+ See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
+ for params
+
+ :raises: (PlotlyError) If the high value is not the greatest value in
+ each unit.
+ :raises: (PlotlyError) If the low value is not the lowest value in each
+ unit.
+ :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
+ """
+ for lst in [open, low, close]:
+ for index in range(len(high)):
+ if high[index] < lst[index]:
+ raise exceptions.PlotlyError(
+ "Oops! Looks like some of "
+ "your high values are less "
+ "the corresponding open, "
+ "low, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order"
+ )
+
+ for lst in [open, high, close]:
+ for index in range(len(low)):
+ if low[index] > lst[index]:
+ raise exceptions.PlotlyError(
+ "Oops! Looks like some of "
+ "your low values are greater "
+ "than the corresponding high"
+ ", open, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order"
+ )
+
+ direction_opts = ("increasing", "decreasing", "both")
+ if direction not in direction_opts:
+ raise exceptions.PlotlyError(
+ "direction must be defined as 'increasing', 'decreasing', or 'both'"
+ )
+
+
+def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes increasing ohlc sticks
+
+ _make_increasing_ohlc() and _make_decreasing_ohlc separate the
+ increasing trace from the decreasing trace so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
+ sticks.
+ """
+ (flat_increase_x, flat_increase_y, text_increase) = _OHLC(
+ open, high, low, close, dates
+ ).get_increase()
+
+ if "name" in kwargs:
+ showlegend = True
+ else:
+ kwargs.setdefault("name", "Increasing")
+ showlegend = False
+
+ kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR, width=1))
+ kwargs.setdefault("text", text_increase)
+
+ ohlc_incr = dict(
+ type="scatter",
+ x=flat_increase_x,
+ y=flat_increase_y,
+ mode="lines",
+ showlegend=showlegend,
+ **kwargs,
+ )
+ return ohlc_incr
+
+
+def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes decreasing ohlc sticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
+ sticks.
+ """
+ (flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC(
+ open, high, low, close, dates
+ ).get_decrease()
+
+ kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR, width=1))
+ kwargs.setdefault("text", text_decrease)
+ kwargs.setdefault("showlegend", False)
+ kwargs.setdefault("name", "Decreasing")
+
+ ohlc_decr = dict(
+ type="scatter", x=flat_decrease_x, y=flat_decrease_y, mode="lines", **kwargs
+ )
+ return ohlc_decr
+
+
+def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs):
+ """
+ **deprecated**, use instead the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Ohlc`
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all units where the close value is greater than the
+ corresponding open value, and when the direction is 'decreasing',
+ the returned figure consists of all units where the close value is
+ less than or equal to the corresponding open value. When the
+ direction is 'both', both increasing and decreasing units are
+ returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of an ohlc chart figure.
+
+ Example 1: Simple OHLC chart from a Pandas DataFrame
+
+ >>> from plotly.figure_factory import create_ohlc
+ >>> from datetime import datetime
+
+ >>> import pandas as pd
+ >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
+ >>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index)
+ >>> fig.show()
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction == "increasing":
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
+ data = [ohlc_incr]
+ elif direction == "decreasing":
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
+ data = [ohlc_decr]
+ else:
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs)
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs)
+ data = [ohlc_incr, ohlc_decr]
+
+ layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest")
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _OHLC(object):
+ """
+ Refer to FigureFactory.create_ohlc_increase() for docstring.
+ """
+
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ self.empty = [None] * len(open)
+ self.dates = dates
+
+ self.all_x = []
+ self.all_y = []
+ self.increase_x = []
+ self.increase_y = []
+ self.decrease_x = []
+ self.decrease_y = []
+ self.get_all_xy()
+ self.separate_increase_decrease()
+
+ def get_all_xy(self):
+ """
+ Zip data to create OHLC shape
+
+ OHLC shape: low to high vertical bar with
+ horizontal branches for open and close values.
+ If dates were added, the smallest date difference is calculated and
+ multiplied by .2 to get the length of the open and close branches.
+ If no date data was provided, the x-axis is a list of integers and the
+ length of the open and close branches is .2.
+ """
+ self.all_y = list(
+ zip(
+ self.open,
+ self.open,
+ self.high,
+ self.low,
+ self.close,
+ self.close,
+ self.empty,
+ )
+ )
+ if self.dates is not None:
+ date_dif = []
+ for i in range(len(self.dates) - 1):
+ date_dif.append(self.dates[i + 1] - self.dates[i])
+ date_dif_min = (min(date_dif)) / 5
+ self.all_x = [
+ [x - date_dif_min, x, x, x, x, x + date_dif_min, None]
+ for x in self.dates
+ ]
+ else:
+ self.all_x = [
+ [x - 0.2, x, x, x, x, x + 0.2, None] for x in range(len(self.open))
+ ]
+
+ def separate_increase_decrease(self):
+ """
+ Separate data into two groups: increase and decrease
+
+ (1) Increase, where close > open and
+ (2) Decrease, where close <= open
+ """
+ for index in range(len(self.open)):
+ if self.close[index] is None:
+ pass
+ elif self.close[index] > self.open[index]:
+ self.increase_x.append(self.all_x[index])
+ self.increase_y.append(self.all_y[index])
+ else:
+ self.decrease_x.append(self.all_x[index])
+ self.decrease_y.append(self.all_y[index])
+
+ def get_increase(self):
+ """
+ Flatten increase data and get increase text
+
+ :rtype (list, list, list): flat_increase_x: x-values for the increasing
+ trace, flat_increase_y: y=values for the increasing trace and
+ text_increase: hovertext for the increasing trace
+ """
+ flat_increase_x = utils.flatten(self.increase_x)
+ flat_increase_y = utils.flatten(self.increase_y)
+ text_increase = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
+ len(self.increase_x)
+ )
+
+ return flat_increase_x, flat_increase_y, text_increase
+
+ def get_decrease(self):
+ """
+ Flatten decrease data and get decrease text
+
+ :rtype (list, list, list): flat_decrease_x: x-values for the decreasing
+ trace, flat_decrease_y: y=values for the decreasing trace and
+ text_decrease: hovertext for the decreasing trace
+ """
+ flat_decrease_x = utils.flatten(self.decrease_x)
+ flat_decrease_y = utils.flatten(self.decrease_y)
+ text_decrease = ("Open", "Open", "High", "Low", "Close", "Close", "") * (
+ len(self.decrease_x)
+ )
+
+ return flat_decrease_x, flat_decrease_y, text_decrease
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py
new file mode 100644
index 0000000..fa18222
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py
@@ -0,0 +1,265 @@
+import math
+
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+def create_quiver(
+ x, y, u, v, scale=0.1, arrow_scale=0.3, angle=math.pi / 9, scaleratio=None, **kwargs
+):
+ """
+ Returns data for a quiver plot.
+
+ :param (list|ndarray) x: x coordinates of the arrow locations
+ :param (list|ndarray) y: y coordinates of the arrow locations
+ :param (list|ndarray) u: x components of the arrow vectors
+ :param (list|ndarray) v: y components of the arrow vectors
+ :param (float in [0,1]) scale: scales size of the arrows(ideally to
+ avoid overlap). Default = .1
+ :param (float in [0,1]) arrow_scale: value multiplied to length of barb
+ to get length of arrowhead. Default = .3
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (positive float) scaleratio: the ratio between the scale of the y-axis
+ and the scale of the x-axis (scale_y / scale_x). Default = None, the
+ scale ratio is not fixed.
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of quiver figure.
+
+ Example 1: Trivial Quiver
+
+ >>> from plotly.figure_factory import create_quiver
+ >>> import math
+
+ >>> # 1 Arrow from (0,0) to (1,1)
+ >>> fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
+ >>> fig.show()
+
+
+ Example 2: Quiver plot using meshgrid
+
+ >>> from plotly.figure_factory import create_quiver
+
+ >>> import numpy as np
+ >>> import math
+
+ >>> # Add data
+ >>> x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
+ >>> u = np.cos(x)*y
+ >>> v = np.sin(x)*y
+
+ >>> #Create quiver
+ >>> fig = create_quiver(x, y, u, v)
+ >>> fig.show()
+
+
+ Example 3: Styling the quiver plot
+
+ >>> from plotly.figure_factory import create_quiver
+ >>> import numpy as np
+ >>> import math
+
+ >>> # Add data
+ >>> x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
+ ... np.arange(-math.pi, math.pi, .5))
+ >>> u = np.cos(x)*y
+ >>> v = np.sin(x)*y
+
+ >>> # Create quiver
+ >>> fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
+ ... name='Wind Velocity', line=dict(width=1))
+
+ >>> # Add title to layout
+ >>> fig.update_layout(title='Quiver Plot') # doctest: +SKIP
+ >>> fig.show()
+
+
+ Example 4: Forcing a fix scale ratio to maintain the arrow length
+
+ >>> from plotly.figure_factory import create_quiver
+ >>> import numpy as np
+
+ >>> # Add data
+ >>> x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
+ >>> u = x
+ >>> v = y
+ >>> angle = np.arctan(v / u)
+ >>> norm = 0.25
+ >>> u = norm * np.cos(angle)
+ >>> v = norm * np.sin(angle)
+
+ >>> # Create quiver with a fix scale ratio
+ >>> fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)
+ >>> fig.show()
+ """
+ utils.validate_equal_length(x, y, u, v)
+ utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
+
+ if scaleratio is None:
+ quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle)
+ else:
+ quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio)
+
+ barb_x, barb_y = quiver_obj.get_barbs()
+ arrow_x, arrow_y = quiver_obj.get_quiver_arrows()
+
+ quiver_plot = graph_objs.Scatter(
+ x=barb_x + arrow_x, y=barb_y + arrow_y, mode="lines", **kwargs
+ )
+
+ data = [quiver_plot]
+
+ if scaleratio is None:
+ layout = graph_objs.Layout(hovermode="closest")
+ else:
+ layout = graph_objs.Layout(
+ hovermode="closest", yaxis=dict(scaleratio=scaleratio, scaleanchor="x")
+ )
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Quiver(object):
+ """
+ Refer to FigureFactory.create_quiver() for docstring
+ """
+
+ def __init__(self, x, y, u, v, scale, arrow_scale, angle, scaleratio=1, **kwargs):
+ try:
+ x = utils.flatten(x)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ y = utils.flatten(y)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ u = utils.flatten(u)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ v = utils.flatten(v)
+ except exceptions.PlotlyError:
+ pass
+
+ self.x = x
+ self.y = y
+ self.u = u
+ self.v = v
+ self.scale = scale
+ self.scaleratio = scaleratio
+ self.arrow_scale = arrow_scale
+ self.angle = angle
+ self.end_x = []
+ self.end_y = []
+ self.scale_uv()
+ barb_x, barb_y = self.get_barbs()
+ arrow_x, arrow_y = self.get_quiver_arrows()
+
+ def scale_uv(self):
+ """
+ Scales u and v to avoid overlap of the arrows.
+
+ u and v are added to x and y to get the
+ endpoints of the arrows so a smaller scale value will
+ result in less overlap of arrows.
+ """
+ self.u = [i * self.scale * self.scaleratio for i in self.u]
+ self.v = [i * self.scale for i in self.v]
+
+ def get_barbs(self):
+ """
+ Creates x and y startpoint and endpoint pairs
+
+ After finding the endpoint of each barb this zips startpoint and
+ endpoint pairs to create 2 lists: x_values for barbs and y values
+ for barbs
+
+ :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
+ x_value pairs separated by a None to create the barb of the arrow,
+ and list of startpoint and endpoint y_value pairs separated by a
+ None to create the barb of the arrow.
+ """
+ self.end_x = [i + j for i, j in zip(self.x, self.u)]
+ self.end_y = [i + j for i, j in zip(self.y, self.v)]
+ empty = [None] * len(self.x)
+ barb_x = utils.flatten(zip(self.x, self.end_x, empty))
+ barb_y = utils.flatten(zip(self.y, self.end_y, empty))
+ return barb_x, barb_y
+
+ def get_quiver_arrows(self):
+ """
+ Creates lists of x and y values to plot the arrows
+
+ Gets length of each barb then calculates the length of each side of
+ the arrow. Gets angle of barb and applies angle to each side of the
+ arrowhead. Next uses arrow_scale to scale the length of arrowhead and
+ creates x and y values for arrowhead point1 and point2. Finally x and y
+ values for point1, endpoint and point2s for each arrowhead are
+ separated by a None and zipped to create lists of x and y values for
+ the arrows.
+
+ :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
+ x_values separated by a None to create the arrowhead and list of
+ point1, endpoint, point2 y_values separated by a None to create
+ the barb of the arrow.
+ """
+ dif_x = [i - j for i, j in zip(self.end_x, self.x)]
+ dif_y = [i - j for i, j in zip(self.end_y, self.y)]
+
+ # Get barb lengths(default arrow length = 30% barb length)
+ barb_len = [None] * len(self.x)
+ for index in range(len(barb_len)):
+ barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index])
+
+ # Make arrow lengths
+ arrow_len = [None] * len(self.x)
+ arrow_len = [i * self.arrow_scale for i in barb_len]
+
+ # Get barb angles
+ barb_ang = [None] * len(self.x)
+ for index in range(len(barb_ang)):
+ barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio)
+
+ # Set angles to create arrow
+ ang1 = [i + self.angle for i in barb_ang]
+ ang2 = [i - self.angle for i in barb_ang]
+
+ cos_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ cos_ang1[index] = math.cos(ang1[index])
+ seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
+
+ sin_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ sin_ang1[index] = math.sin(ang1[index])
+ seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
+
+ cos_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ cos_ang2[index] = math.cos(ang2[index])
+ seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
+
+ sin_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ sin_ang2[index] = math.sin(ang2[index])
+ seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
+
+ # Set coordinates to create arrow
+ for index in range(len(self.end_x)):
+ point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)]
+ point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
+ point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)]
+ point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
+
+ # Combine lists to create arrow
+ empty = [None] * len(self.end_x)
+ arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
+ arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
+ return arrow_x, arrow_y
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_scatterplot.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_scatterplot.py
new file mode 100644
index 0000000..7589527
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_scatterplot.py
@@ -0,0 +1,1135 @@
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.subplots import make_subplots
+
+pd = optional_imports.get_module("pandas")
+
+DIAG_CHOICES = ["scatter", "histogram", "box"]
+VALID_COLORMAP_TYPES = ["cat", "seq"]
+
+
+def endpts_to_intervals(endpts):
+ """
+ Returns a list of intervals for categorical colormaps
+
+ Accepts a list or tuple of sequentially increasing numbers and returns
+ a list representation of the mathematical intervals with these numbers
+ as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
+
+ :raises: (PlotlyError) If input is not a list or tuple
+ :raises: (PlotlyError) If the input contains a string
+ :raises: (PlotlyError) If any number does not increase after the
+ previous one in the sequence
+ """
+ length = len(endpts)
+ # Check if endpts is a list or tuple
+ if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument must "
+ "be a list or tuple of a sequence "
+ "of increasing numbers."
+ )
+ # Check if endpts contains only numbers
+ for item in endpts:
+ if isinstance(item, str):
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers."
+ )
+ # Check if numbers in endpts are increasing
+ for k in range(length - 1):
+ if endpts[k] >= endpts[k + 1]:
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers."
+ )
+ else:
+ intervals = []
+ # add -inf to intervals
+ intervals.append([float("-inf"), endpts[0]])
+ for k in range(length - 1):
+ interval = []
+ interval.append(endpts[k])
+ interval.append(endpts[k + 1])
+ intervals.append(interval)
+ # add +inf to intervals
+ intervals.append([endpts[length - 1], float("inf")])
+ return intervals
+
+
+def hide_tick_labels_from_box_subplots(fig):
+ """
+ Hides tick labels for box plots in scatterplotmatrix subplots.
+ """
+ boxplot_xaxes = []
+ for trace in fig["data"]:
+ if trace["type"] == "box":
+ # stores the xaxes which correspond to boxplot subplots
+ # since we use xaxis1, xaxis2, etc, in plotly.py
+ boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
+ for xaxis in boxplot_xaxes:
+ fig["layout"][xaxis]["showticklabels"] = False
+
+
+def validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs):
+ """
+ Validates basic inputs for FigureFactory.create_scatterplotmatrix()
+
+ :raises: (PlotlyError) If pandas is not imported
+ :raises: (PlotlyError) If pandas dataframe is not inputted
+ :raises: (PlotlyError) If pandas dataframe has <= 1 columns
+ :raises: (PlotlyError) If diagonal plot choice (diag) is not one of
+ the viable options
+ :raises: (PlotlyError) If colormap_type is not a valid choice
+ :raises: (PlotlyError) If kwargs contains 'size', 'color' or
+ 'colorscale'
+ """
+ if not pd:
+ raise ImportError(
+ "FigureFactory.scatterplotmatrix requires a pandas DataFrame."
+ )
+
+ # Check if pandas dataframe
+ if not isinstance(df, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError(
+ "Dataframe not inputed. Please "
+ "use a pandas dataframe to pro"
+ "duce a scatterplot matrix."
+ )
+
+ # Check if dataframe is 1 column or less
+ if len(df.columns) <= 1:
+ raise exceptions.PlotlyError(
+ "Dataframe has only one column. To "
+ "use the scatterplot matrix, use at "
+ "least 2 columns."
+ )
+
+ # Check that diag parameter is a valid selection
+ if diag not in DIAG_CHOICES:
+ raise exceptions.PlotlyError(
+ "Make sure diag is set to one of {}".format(DIAG_CHOICES)
+ )
+
+ # Check that colormap_types is a valid selection
+ if colormap_type not in VALID_COLORMAP_TYPES:
+ raise exceptions.PlotlyError(
+ "Must choose a valid colormap type. "
+ "Either 'cat' or 'seq' for a cate"
+ "gorical and sequential colormap "
+ "respectively."
+ )
+
+ # Check for not 'size' or 'color' in 'marker' of **kwargs
+ if "marker" in kwargs:
+ FORBIDDEN_PARAMS = ["size", "color", "colorscale"]
+ if any(param in kwargs["marker"] for param in FORBIDDEN_PARAMS):
+ raise exceptions.PlotlyError(
+ "Your kwargs dictionary cannot "
+ "include the 'size', 'color' or "
+ "'colorscale' key words inside "
+ "the marker dict since 'size' is "
+ "already an argument of the "
+ "scatterplot matrix function and "
+ "both 'color' and 'colorscale "
+ "are set internally."
+ )
+
+
+def scatterplot(dataframe, headers, diag, size, height, width, title, **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix without index
+
+ """
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ # Insert traces into trace_list
+ for listy in dataframe:
+ for listx in dataframe:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(x=listx, showlegend=False)
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(y=listx, name=None, showlegend=False)
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ trace = graph_objs.Scatter(
+ x=listx, y=listy, mode="markers", showlegend=False, **kwargs
+ )
+ trace_list.append(trace)
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode="markers",
+ marker=dict(size=size),
+ showlegend=False,
+ **kwargs,
+ )
+ trace_list.append(trace)
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index], y_index, x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
+ fig["layout"][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = "yaxis{}".format(1 + (dim * j))
+ fig["layout"][yaxis_key].update(title=headers[j])
+
+ fig["layout"].update(height=height, width=width, title=title, showlegend=True)
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ return fig
+
+
+def scatterplot_dict(
+ dataframe,
+ headers,
+ diag,
+ size,
+ height,
+ width,
+ title,
+ index,
+ index_vals,
+ endpts,
+ colormap,
+ colormap_type,
+ **kwargs,
+):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked.
+ Used if colormap is a dictionary with index values as keys pointing to
+ colors. Forces colormap_type to behave categorically because it would
+ not make sense colors are assigned to each index value and thus
+ implies that a categorical approach should be taken
+
+ """
+
+ theme = colormap
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx, marker=dict(color=theme[name]), showlegend=True
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[name]),
+ showlegend=True,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ showlegend=True,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ marker=dict(size=size, color=theme[name]),
+ showlegend=True,
+ **kwargs,
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(color=theme[name]),
+ showlegend=False,
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[name]),
+ showlegend=False,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ showlegend=False,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ marker=dict(size=size, color=theme[name]),
+ showlegend=False,
+ **kwargs,
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(trace_list[trace_index][name], y_index, x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
+ fig["layout"][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = "yaxis{}".format(1 + (dim * j))
+ fig["layout"][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == "histogram":
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True, barmode="stack"
+ )
+ return fig
+
+ else:
+ fig["layout"].update(height=height, width=width, title=title, showlegend=True)
+ return fig
+
+
+def scatterplot_theme(
+ dataframe,
+ headers,
+ diag,
+ size,
+ height,
+ width,
+ title,
+ index,
+ index_vals,
+ endpts,
+ colormap,
+ colormap_type,
+ **kwargs,
+):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked
+
+ """
+
+ # Check if index is made of string values
+ if isinstance(index_vals[0], str):
+ unique_index_vals = []
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals.append(name)
+ n_colors_len = len(unique_index_vals)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == "seq":
+ foo = clrs.color_parser(colormap, clrs.unlabel_rgb)
+ foo = clrs.n_colors(foo[0], foo[1], n_colors_len)
+ theme = clrs.color_parser(foo, clrs.label_rgb)
+
+ if colormap_type == "cat":
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(color=theme[c_indx]),
+ showlegend=True,
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[c_indx]),
+ showlegend=True,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ showlegend=True,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ marker=dict(size=size, color=theme[c_indx]),
+ showlegend=True,
+ **kwargs,
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(color=theme[c_indx]),
+ showlegend=False,
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[c_indx]),
+ showlegend=False,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ showlegend=False,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=name,
+ marker=dict(size=size, color=theme[c_indx]),
+ showlegend=False,
+ **kwargs,
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(trace_list[trace_index][name], y_index, x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
+ fig["layout"][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = "yaxis{}".format(1 + (dim * j))
+ fig["layout"][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == "histogram":
+ fig["layout"].update(
+ height=height,
+ width=width,
+ title=title,
+ showlegend=True,
+ barmode="stack",
+ )
+ return fig
+
+ elif diag == "box":
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+ else:
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+ else:
+ if endpts:
+ intervals = utils.endpts_to_intervals(endpts)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == "seq":
+ foo = clrs.color_parser(colormap, clrs.unlabel_rgb)
+ foo = clrs.n_colors(foo[0], foo[1], len(intervals))
+ theme = clrs.color_parser(foo, clrs.label_rgb)
+
+ if colormap_type == "cat":
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ interval_labels = {}
+ for interval in intervals:
+ interval_labels[str(interval)] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for interval in intervals:
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if interval[0] < index_vals[j] <= interval[1]:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(color=theme[c_indx]),
+ showlegend=True,
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[c_indx]),
+ showlegend=True,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ (kwargs["marker"]["color"]) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=str(interval),
+ showlegend=True,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=str(interval),
+ marker=dict(size=size, color=theme[c_indx]),
+ showlegend=True,
+ **kwargs,
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(color=theme[c_indx]),
+ showlegend=False,
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(color=theme[c_indx]),
+ showlegend=False,
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ (kwargs["marker"]["color"]) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=str(interval),
+ showlegend=False,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode="markers",
+ name=str(interval),
+ marker=dict(size=size, color=theme[c_indx]),
+ showlegend=False,
+ **kwargs,
+ )
+ # Push the trace into dictionary
+ interval_labels[str(interval)] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(interval_labels)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for interval in intervals:
+ fig.append_trace(
+ trace_list[trace_index][str(interval)], y_index, x_index
+ )
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
+ fig["layout"][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = "yaxis{}".format(1 + (dim * j))
+ fig["layout"][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == "histogram":
+ fig["layout"].update(
+ height=height,
+ width=width,
+ title=title,
+ showlegend=True,
+ barmode="stack",
+ )
+ return fig
+
+ elif diag == "box":
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+ else:
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+ else:
+ theme = colormap
+
+ # add a copy of rgb color to theme if it contains one color
+ if len(theme) <= 1:
+ theme.append(theme[0])
+
+ color = []
+ for incr in range(len(theme)):
+ color.append([1.0 / (len(theme) - 1) * incr, theme[incr]])
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Run through all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=listx, marker=dict(color=theme[0]), showlegend=False
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=listx, marker=dict(color=theme[0]), showlegend=False
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = index_vals
+ kwargs["marker"]["colorscale"] = color
+ kwargs["marker"]["showscale"] = True
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode="markers",
+ showlegend=False,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=True,
+ ),
+ showlegend=False,
+ **kwargs,
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == "histogram"):
+ trace = graph_objs.Histogram(
+ x=listx, marker=dict(color=theme[0]), showlegend=False
+ )
+ elif (listx == listy) and (diag == "box"):
+ trace = graph_objs.Box(
+ y=listx, marker=dict(color=theme[0]), showlegend=False
+ )
+ else:
+ if "marker" in kwargs:
+ kwargs["marker"]["size"] = size
+ kwargs["marker"]["color"] = index_vals
+ kwargs["marker"]["colorscale"] = color
+ kwargs["marker"]["showscale"] = False
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode="markers",
+ showlegend=False,
+ **kwargs,
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=False,
+ ),
+ showlegend=False,
+ **kwargs,
+ )
+ # Push the trace into list
+ trace_list.append(trace)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index], y_index, x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
+ fig["layout"][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = "yaxis{}".format(1 + (dim * j))
+ fig["layout"][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == "histogram":
+ fig["layout"].update(
+ height=height,
+ width=width,
+ title=title,
+ showlegend=True,
+ barmode="stack",
+ )
+ return fig
+
+ elif diag == "box":
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+ else:
+ fig["layout"].update(
+ height=height, width=width, title=title, showlegend=True
+ )
+ return fig
+
+
+def create_scatterplotmatrix(
+ df,
+ index=None,
+ endpts=None,
+ diag="scatter",
+ height=500,
+ width=500,
+ size=6,
+ title="Scatterplot Matrix",
+ colormap=None,
+ colormap_type="cat",
+ dataframe=None,
+ headers=None,
+ index_vals=None,
+ **kwargs,
+):
+ """
+ Returns data for a scatterplot matrix;
+ **deprecated**,
+ use instead the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Splom`.
+
+ :param (array) df: array of the data with column headers
+ :param (str) index: name of the index column in data array
+ :param (list|tuple) endpts: takes an increasing sequece of numbers
+ that defines intervals on the real line. They are used to group
+ the entries in an index of numbers into their corresponding
+ interval and therefore can be treated as categorical data
+ :param (str) diag: sets the chart type for the main diagonal plots.
+ The options are 'scatter', 'histogram' and 'box'.
+ :param (int|float) height: sets the height of the chart
+ :param (int|float) width: sets the width of the chart
+ :param (float) size: sets the marker size (in px)
+ :param (str) title: the title label of the scatterplot matrix
+ :param (str|tuple|list|dict) colormap: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colormap is a list, it must contain valid color types as its
+ members.
+ If colormap is a dictionary, all the string entries in
+ the index column must be a key in colormap. In this case, the
+ colormap_type is forced to 'cat' or categorical
+ :param (str) colormap_type: determines how colormap is interpreted.
+ Valid choices are 'seq' (sequential) and 'cat' (categorical). If
+ 'seq' is selected, only the first two colors in colormap will be
+ considered (when colormap is a list) and the index values will be
+ linearly interpolated between those two colors. This option is
+ forced if all index values are numeric.
+ If 'cat' is selected, a color from colormap will be assigned to
+ each category from index, including the intervals if endpts is
+ being used
+ :param (dict) **kwargs: a dictionary of scatterplot arguments
+ The only forbidden parameters are 'size', 'color' and
+ 'colorscale' in 'marker'
+
+ Example 1: Vanilla Scatterplot Matrix
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> # Create dataframe
+ >>> df = pd.DataFrame(np.random.randn(10, 2),
+ ... columns=['Column 1', 'Column 2'])
+
+ >>> # Create scatterplot matrix
+ >>> fig = create_scatterplotmatrix(df)
+ >>> fig.show()
+
+
+ Example 2: Indexing a Column
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> # Create dataframe with index
+ >>> df = pd.DataFrame(np.random.randn(10, 2),
+ ... columns=['A', 'B'])
+
+ >>> # Add another column of strings to the dataframe
+ >>> df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ ... 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ >>> # Create scatterplot matrix
+ >>> fig = create_scatterplotmatrix(df, index='Fruit', size=10)
+ >>> fig.show()
+
+
+ Example 3: Styling the Diagonal Subplots
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> # Create dataframe with index
+ >>> df = pd.DataFrame(np.random.randn(10, 4),
+ ... columns=['A', 'B', 'C', 'D'])
+
+ >>> # Add another column of strings to the dataframe
+ >>> df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ ... 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ >>> # Create scatterplot matrix
+ >>> fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000,
+ ... width=1000)
+ >>> fig.show()
+
+
+ Example 4: Use a Theme to Style the Subplots
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> # Create dataframe with random data
+ >>> df = pd.DataFrame(np.random.randn(100, 3),
+ ... columns=['A', 'B', 'C'])
+
+ >>> # Create scatterplot matrix using a built-in
+ >>> # Plotly palette scale and indexing column 'A'
+ >>> fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ ... colormap='Blues', height=800, width=800)
+ >>> fig.show()
+
+
+ Example 5: Example 4 with Interval Factoring
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+
+ >>> # Create dataframe with random data
+ >>> df = pd.DataFrame(np.random.randn(100, 3),
+ ... columns=['A', 'B', 'C'])
+
+ >>> # Create scatterplot matrix using a list of 2 rgb tuples
+ >>> # and endpoints at -1, 0 and 1
+ >>> fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ ... colormap=['rgb(140, 255, 50)',
+ ... 'rgb(170, 60, 115)', '#6c4774',
+ ... (0.5, 0.1, 0.8)],
+ ... endpts=[-1, 0, 1], height=800, width=800)
+ >>> fig.show()
+
+
+ Example 6: Using the colormap as a Dictionary
+
+ >>> from plotly.graph_objs import graph_objs
+ >>> from plotly.figure_factory import create_scatterplotmatrix
+
+ >>> import numpy as np
+ >>> import pandas as pd
+ >>> import random
+
+ >>> # Create dataframe with random data
+ >>> df = pd.DataFrame(np.random.randn(100, 3),
+ ... columns=['Column A',
+ ... 'Column B',
+ ... 'Column C'])
+
+ >>> # Add new color column to dataframe
+ >>> new_column = []
+ >>> strange_colors = ['turquoise', 'limegreen', 'goldenrod']
+
+ >>> for j in range(100):
+ ... new_column.append(random.choice(strange_colors))
+ >>> df['Colors'] = pd.Series(new_column, index=df.index)
+
+ >>> # Create scatterplot matrix using a dictionary of hex color values
+ >>> # which correspond to actual color names in 'Colors' column
+ >>> fig = create_scatterplotmatrix(
+ ... df, diag='box', index='Colors',
+ ... colormap= dict(
+ ... turquoise = '#00F5FF',
+ ... limegreen = '#32CD32',
+ ... goldenrod = '#DAA520'
+ ... ),
+ ... colormap_type='cat',
+ ... height=800, width=800
+ ... )
+ >>> fig.show()
+ """
+ # TODO: protected until #282
+ if dataframe is None:
+ dataframe = []
+ if headers is None:
+ headers = []
+ if index_vals is None:
+ index_vals = []
+
+ validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs)
+
+ # Validate colormap
+ if isinstance(colormap, dict):
+ colormap = clrs.validate_colors_dict(colormap, "rgb")
+ elif isinstance(colormap, str) and "rgb" not in colormap and "#" not in colormap:
+ if colormap not in clrs.PLOTLY_SCALES.keys():
+ raise exceptions.PlotlyError(
+ "If 'colormap' is a string, it must be the name "
+ "of a Plotly Colorscale. The available colorscale "
+ "names are {}".format(clrs.PLOTLY_SCALES.keys())
+ )
+ else:
+ # TODO change below to allow the correct Plotly colorscale
+ colormap = clrs.colorscale_to_colors(clrs.PLOTLY_SCALES[colormap])
+ # keep only first and last item - fix later
+ colormap = [colormap[0]] + [colormap[-1]]
+ colormap = clrs.validate_colors(colormap, "rgb")
+ else:
+ colormap = clrs.validate_colors(colormap, "rgb")
+
+ if not index:
+ for name in df:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+ # Check for same data-type in df columns
+ utils.validate_dataframe(dataframe)
+ figure = scatterplot(
+ dataframe, headers, diag, size, height, width, title, **kwargs
+ )
+ return figure
+ else:
+ # Validate index selection
+ if index not in df:
+ raise exceptions.PlotlyError(
+ "Make sure you set the index "
+ "input variable to one of the "
+ "column names of your "
+ "dataframe."
+ )
+ index_vals = df[index].values.tolist()
+ for name in df:
+ if name != index:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+
+ # check for same data-type in each df column
+ utils.validate_dataframe(dataframe)
+ utils.validate_index(index_vals)
+
+ # check if all colormap keys are in the index
+ # if colormap is a dictionary
+ if isinstance(colormap, dict):
+ for key in colormap:
+ if not all(index in colormap for index in index_vals):
+ raise exceptions.PlotlyError(
+ "If colormap is a "
+ "dictionary, all the "
+ "names in the index "
+ "must be keys."
+ )
+ figure = scatterplot_dict(
+ dataframe,
+ headers,
+ diag,
+ size,
+ height,
+ width,
+ title,
+ index,
+ index_vals,
+ endpts,
+ colormap,
+ colormap_type,
+ **kwargs,
+ )
+ return figure
+
+ else:
+ figure = scatterplot_theme(
+ dataframe,
+ headers,
+ diag,
+ size,
+ height,
+ width,
+ title,
+ index,
+ index_vals,
+ endpts,
+ colormap,
+ colormap_type,
+ **kwargs,
+ )
+ return figure
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_streamline.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_streamline.py
new file mode 100644
index 0000000..55b74c3
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_streamline.py
@@ -0,0 +1,406 @@
+import math
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module("numpy")
+
+
+def validate_streamline(x, y):
+ """
+ Streamline-specific validations
+
+ Specifically, this checks that x and y are both evenly spaced,
+ and that the package numpy is available.
+
+ See FigureFactory.create_streamline() for params
+
+ :raises: (ImportError) If numpy is not available.
+ :raises: (PlotlyError) If x is not evenly spaced.
+ :raises: (PlotlyError) If y is not evenly spaced.
+ """
+ if np is False:
+ raise ImportError("FigureFactory.create_streamline requires numpy")
+ for index in range(len(x) - 1):
+ if ((x[index + 1] - x[index]) - (x[1] - x[0])) > 0.0001:
+ raise exceptions.PlotlyError(
+ "x must be a 1 dimensional, evenly spaced array"
+ )
+ for index in range(len(y) - 1):
+ if ((y[index + 1] - y[index]) - (y[1] - y[0])) > 0.0001:
+ raise exceptions.PlotlyError(
+ "y must be a 1 dimensional, evenly spaced array"
+ )
+
+
+def create_streamline(
+ x, y, u, v, density=1, angle=math.pi / 9, arrow_scale=0.09, **kwargs
+):
+ """
+ Returns data for a streamline plot.
+
+ :param (list|ndarray) x: 1 dimensional, evenly spaced list or array
+ :param (list|ndarray) y: 1 dimensional, evenly spaced list or array
+ :param (ndarray) u: 2 dimensional array
+ :param (ndarray) v: 2 dimensional array
+ :param (float|int) density: controls the density of streamlines in
+ plot. This is multiplied by 30 to scale similiarly to other
+ available streamline functions such as matplotlib.
+ Default = 1
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of streamline figure.
+
+ Example 1: Plot simple streamline and increase arrow size
+
+ >>> from plotly.figure_factory import create_streamline
+ >>> import plotly.graph_objects as go
+ >>> import numpy as np
+ >>> import math
+
+ >>> # Add data
+ >>> x = np.linspace(-3, 3, 100)
+ >>> y = np.linspace(-3, 3, 100)
+ >>> Y, X = np.meshgrid(x, y)
+ >>> u = -1 - X**2 + Y
+ >>> v = 1 + X - Y**2
+ >>> u = u.T # Transpose
+ >>> v = v.T # Transpose
+
+ >>> # Create streamline
+ >>> fig = create_streamline(x, y, u, v, arrow_scale=.1)
+ >>> fig.show()
+
+ Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
+
+ >>> from plotly.figure_factory import create_streamline
+ >>> import numpy as np
+ >>> import math
+
+ >>> # Add data
+ >>> N = 50
+ >>> x_start, x_end = -2.0, 2.0
+ >>> y_start, y_end = -1.0, 1.0
+ >>> x = np.linspace(x_start, x_end, N)
+ >>> y = np.linspace(y_start, y_end, N)
+ >>> X, Y = np.meshgrid(x, y)
+ >>> ss = 5.0
+ >>> x_s, y_s = -1.0, 0.0
+
+ >>> # Compute the velocity field on the mesh grid
+ >>> u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
+ >>> v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
+
+ >>> # Create streamline
+ >>> fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
+
+ >>> # Add source point
+ >>> point = go.Scatter(x=[x_s], y=[y_s], mode='markers',
+ ... marker_size=14, name='source point')
+
+ >>> fig.add_trace(point) # doctest: +SKIP
+ >>> fig.show()
+ """
+ utils.validate_equal_length(x, y)
+ utils.validate_equal_length(u, v)
+ validate_streamline(x, y)
+ utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
+
+ streamline_x, streamline_y = _Streamline(
+ x, y, u, v, density, angle, arrow_scale
+ ).sum_streamlines()
+ arrow_x, arrow_y = _Streamline(
+ x, y, u, v, density, angle, arrow_scale
+ ).get_streamline_arrows()
+
+ streamline = graph_objs.Scatter(
+ x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode="lines", **kwargs
+ )
+
+ data = [streamline]
+ layout = graph_objs.Layout(hovermode="closest")
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Streamline(object):
+ """
+ Refer to FigureFactory.create_streamline() for docstring
+ """
+
+ def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs):
+ self.x = np.array(x)
+ self.y = np.array(y)
+ self.u = np.array(u)
+ self.v = np.array(v)
+ self.angle = angle
+ self.arrow_scale = arrow_scale
+ self.density = int(30 * density) # Scale similarly to other functions
+ self.delta_x = self.x[1] - self.x[0]
+ self.delta_y = self.y[1] - self.y[0]
+ self.val_x = self.x
+ self.val_y = self.y
+
+ # Set up spacing
+ self.blank = np.zeros((self.density, self.density))
+ self.spacing_x = len(self.x) / float(self.density - 1)
+ self.spacing_y = len(self.y) / float(self.density - 1)
+ self.trajectories = []
+
+ # Rescale speed onto axes-coordinates
+ self.u = self.u / (self.x[-1] - self.x[0])
+ self.v = self.v / (self.y[-1] - self.y[0])
+ self.speed = np.sqrt(self.u**2 + self.v**2)
+
+ # Rescale u and v for integrations.
+ self.u *= len(self.x)
+ self.v *= len(self.y)
+ self.st_x = []
+ self.st_y = []
+ self.get_streamlines()
+ streamline_x, streamline_y = self.sum_streamlines()
+ arrows_x, arrows_y = self.get_streamline_arrows()
+
+ def blank_pos(self, xi, yi):
+ """
+ Set up positions for trajectories to be used with rk4 function.
+ """
+ return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5))
+
+ def value_at(self, a, xi, yi):
+ """
+ Set up for RK4 function, based on Bokeh's streamline code
+ """
+ if isinstance(xi, np.ndarray):
+ self.x = xi.astype(int)
+ self.y = yi.astype(int)
+ else:
+ self.val_x = int(xi)
+ self.val_y = int(yi)
+ a00 = a[self.val_y, self.val_x]
+ a01 = a[self.val_y, self.val_x + 1]
+ a10 = a[self.val_y + 1, self.val_x]
+ a11 = a[self.val_y + 1, self.val_x + 1]
+ xt = xi - self.val_x
+ yt = yi - self.val_y
+ a0 = a00 * (1 - xt) + a01 * xt
+ a1 = a10 * (1 - xt) + a11 * xt
+ return a0 * (1 - yt) + a1 * yt
+
+ def rk4_integrate(self, x0, y0):
+ """
+ RK4 forward and back trajectories from the initial conditions.
+
+ Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
+ x and y trajectories then checks length of traj (s in units of axes)
+ """
+
+ def f(xi, yi):
+ dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return ui * dt_ds, vi * dt_ds
+
+ def g(xi, yi):
+ dt_ds = 1.0 / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return -ui * dt_ds, -vi * dt_ds
+
+ def check(xi, yi):
+ return (0 <= xi < len(self.x) - 1) and (0 <= yi < len(self.y) - 1)
+
+ xb_changes = []
+ yb_changes = []
+
+ def rk4(x0, y0, f):
+ ds = 0.01
+ stotal = 0
+ xi = x0
+ yi = y0
+ xb, yb = self.blank_pos(xi, yi)
+ xf_traj = []
+ yf_traj = []
+ while check(xi, yi):
+ xf_traj.append(xi)
+ yf_traj.append(yi)
+ try:
+ k1x, k1y = f(xi, yi)
+ k2x, k2y = f(xi + 0.5 * ds * k1x, yi + 0.5 * ds * k1y)
+ k3x, k3y = f(xi + 0.5 * ds * k2x, yi + 0.5 * ds * k2y)
+ k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
+ except IndexError:
+ break
+ xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0
+ yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0
+ if not check(xi, yi):
+ break
+ stotal += ds
+ new_xb, new_yb = self.blank_pos(xi, yi)
+ if new_xb != xb or new_yb != yb:
+ if self.blank[new_yb, new_xb] == 0:
+ self.blank[new_yb, new_xb] = 1
+ xb_changes.append(new_xb)
+ yb_changes.append(new_yb)
+ xb = new_xb
+ yb = new_yb
+ else:
+ break
+ if stotal > 2:
+ break
+ return stotal, xf_traj, yf_traj
+
+ sf, xf_traj, yf_traj = rk4(x0, y0, f)
+ sb, xb_traj, yb_traj = rk4(x0, y0, g)
+ stotal = sf + sb
+ x_traj = xb_traj[::-1] + xf_traj[1:]
+ y_traj = yb_traj[::-1] + yf_traj[1:]
+
+ if len(x_traj) < 1:
+ return None
+ if stotal > 0.2:
+ initxb, inityb = self.blank_pos(x0, y0)
+ self.blank[inityb, initxb] = 1
+ return x_traj, y_traj
+ else:
+ for xb, yb in zip(xb_changes, yb_changes):
+ self.blank[yb, xb] = 0
+ return None
+
+ def traj(self, xb, yb):
+ """
+ Integrate trajectories
+
+ :param (int) xb: results of passing xi through self.blank_pos
+ :param (int) xy: results of passing yi through self.blank_pos
+
+ Calculate each trajectory based on rk4 integrate method.
+ """
+
+ if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
+ return
+ if self.blank[yb, xb] == 0:
+ t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
+ if t is not None:
+ self.trajectories.append(t)
+
+ def get_streamlines(self):
+ """
+ Get streamlines by building trajectory set.
+ """
+ for indent in range(self.density // 2):
+ for xi in range(self.density - 2 * indent):
+ self.traj(xi + indent, indent)
+ self.traj(xi + indent, self.density - 1 - indent)
+ self.traj(indent, xi + indent)
+ self.traj(self.density - 1 - indent, xi + indent)
+
+ self.st_x = [
+ np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories
+ ]
+ self.st_y = [
+ np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories
+ ]
+
+ for index in range(len(self.st_x)):
+ self.st_x[index] = self.st_x[index].tolist()
+ self.st_x[index].append(np.nan)
+
+ for index in range(len(self.st_y)):
+ self.st_y[index] = self.st_y[index].tolist()
+ self.st_y[index].append(np.nan)
+
+ def get_streamline_arrows(self):
+ """
+ Makes an arrow for each streamline.
+
+ Gets angle of streamline at 1/3 mark and creates arrow coordinates
+ based off of user defined angle and arrow_scale.
+
+ :param (array) st_x: x-values for all streamlines
+ :param (array) st_y: y-values for all streamlines
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :rtype (list, list) arrows_x: x-values to create arrowhead and
+ arrows_y: y-values to create arrowhead
+ """
+ arrow_end_x = np.empty((len(self.st_x)))
+ arrow_end_y = np.empty((len(self.st_y)))
+ arrow_start_x = np.empty((len(self.st_x)))
+ arrow_start_y = np.empty((len(self.st_y)))
+ for index in range(len(self.st_x)):
+ arrow_end_x[index] = self.st_x[index][int(len(self.st_x[index]) / 3)]
+ arrow_start_x[index] = self.st_x[index][
+ (int(len(self.st_x[index]) / 3)) - 1
+ ]
+ arrow_end_y[index] = self.st_y[index][int(len(self.st_y[index]) / 3)]
+ arrow_start_y[index] = self.st_y[index][
+ (int(len(self.st_y[index]) / 3)) - 1
+ ]
+
+ dif_x = arrow_end_x - arrow_start_x
+ dif_y = arrow_end_y - arrow_start_y
+
+ orig_err = np.geterr()
+ np.seterr(divide="ignore", invalid="ignore")
+ streamline_ang = np.arctan(dif_y / dif_x)
+ np.seterr(**orig_err)
+
+ ang1 = streamline_ang + (self.angle)
+ ang2 = streamline_ang - (self.angle)
+
+ seg1_x = np.cos(ang1) * self.arrow_scale
+ seg1_y = np.sin(ang1) * self.arrow_scale
+ seg2_x = np.cos(ang2) * self.arrow_scale
+ seg2_y = np.sin(ang2) * self.arrow_scale
+
+ point1_x = np.empty((len(dif_x)))
+ point1_y = np.empty((len(dif_y)))
+ point2_x = np.empty((len(dif_x)))
+ point2_y = np.empty((len(dif_y)))
+
+ for index in range(len(dif_x)):
+ if dif_x[index] >= 0:
+ point1_x[index] = arrow_end_x[index] - seg1_x[index]
+ point1_y[index] = arrow_end_y[index] - seg1_y[index]
+ point2_x[index] = arrow_end_x[index] - seg2_x[index]
+ point2_y[index] = arrow_end_y[index] - seg2_y[index]
+ else:
+ point1_x[index] = arrow_end_x[index] + seg1_x[index]
+ point1_y[index] = arrow_end_y[index] + seg1_y[index]
+ point2_x[index] = arrow_end_x[index] + seg2_x[index]
+ point2_y[index] = arrow_end_y[index] + seg2_y[index]
+
+ space = np.empty((len(point1_x)))
+ space[:] = np.nan
+
+ # Combine arrays into array
+ arrows_x = np.array([point1_x, arrow_end_x, point2_x, space])
+ arrows_x = arrows_x.flatten("F")
+ arrows_x = arrows_x.tolist()
+
+ # Combine arrays into array
+ arrows_y = np.array([point1_y, arrow_end_y, point2_y, space])
+ arrows_y = arrows_y.flatten("F")
+ arrows_y = arrows_y.tolist()
+
+ return arrows_x, arrows_y
+
+ def sum_streamlines(self):
+ """
+ Makes all streamlines readable as a single trace.
+
+ :rtype (list, list): streamline_x: all x values for each streamline
+ combined into single list and streamline_y: all y values for each
+ streamline combined into single list
+ """
+ streamline_x = sum(self.st_x, [])
+ streamline_y = sum(self.st_y, [])
+ return streamline_x, streamline_y
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py
new file mode 100644
index 0000000..bda731f
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py
@@ -0,0 +1,280 @@
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+pd = optional_imports.get_module("pandas")
+
+
+def validate_table(table_text, font_colors):
+ """
+ Table-specific validations
+
+ Check that font_colors is supplied correctly (1, 3, or len(text)
+ colors).
+
+ :raises: (PlotlyError) If font_colors is supplied incorretly.
+
+ See FigureFactory.create_table() for params
+ """
+ font_colors_len_options = [1, 3, len(table_text)]
+ if len(font_colors) not in font_colors_len_options:
+ raise exceptions.PlotlyError(
+ "Oops, font_colors should be a list of length 1, 3 or len(text)"
+ )
+
+
+def create_table(
+ table_text,
+ colorscale=None,
+ font_colors=None,
+ index=False,
+ index_title="",
+ annotation_offset=0.45,
+ height_constant=30,
+ hoverinfo="none",
+ **kwargs,
+):
+ """
+ Function that creates data tables.
+
+ See also the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Table`
+
+ :param (pandas.Dataframe | list[list]) text: data for table.
+ :param (str|list[list]) colorscale: Colorscale for table where the
+ color at value 0 is the header color, .5 is the first table color
+ and 1 is the second table color. (Set .5 and 1 to avoid the striped
+ table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
+ [1, '#ffffff']]
+ :param (list) font_colors: Color for fonts in table. Can be a single
+ color, three colors, or a color for each row in the table.
+ Default=['#000000'] (black text for the entire table)
+ :param (int) height_constant: Constant multiplied by # of rows to
+ create table height. Default=30.
+ :param (bool) index: Create (header-colored) index column index from
+ Pandas dataframe or list[0] for each list in text. Default=False.
+ :param (string) index_title: Title for index column. Default=''.
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple Plotly Table
+
+ >>> from plotly.figure_factory import create_table
+
+ >>> text = [['Country', 'Year', 'Population'],
+ ... ['US', 2000, 282200000],
+ ... ['Canada', 2000, 27790000],
+ ... ['US', 2010, 309000000],
+ ... ['Canada', 2010, 34000000]]
+
+ >>> table = create_table(text)
+ >>> table.show()
+
+ Example 2: Table with Custom Coloring
+
+ >>> from plotly.figure_factory import create_table
+ >>> text = [['Country', 'Year', 'Population'],
+ ... ['US', 2000, 282200000],
+ ... ['Canada', 2000, 27790000],
+ ... ['US', 2010, 309000000],
+ ... ['Canada', 2010, 34000000]]
+ >>> table = create_table(text,
+ ... colorscale=[[0, '#000000'],
+ ... [.5, '#80beff'],
+ ... [1, '#cce5ff']],
+ ... font_colors=['#ffffff', '#000000',
+ ... '#000000'])
+ >>> table.show()
+
+ Example 3: Simple Plotly Table with Pandas
+
+ >>> from plotly.figure_factory import create_table
+ >>> import pandas as pd
+ >>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
+ >>> df_p = df[0:25]
+ >>> table_simple = create_table(df_p)
+ >>> table_simple.show()
+
+ """
+
+ # Avoiding mutables in the call signature
+ colorscale = (
+ colorscale
+ if colorscale is not None
+ else [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]]
+ )
+ font_colors = (
+ font_colors if font_colors is not None else ["#ffffff", "#000000", "#000000"]
+ )
+
+ validate_table(table_text, font_colors)
+ table_matrix = _Table(
+ table_text,
+ colorscale,
+ font_colors,
+ index,
+ index_title,
+ annotation_offset,
+ **kwargs,
+ ).get_table_matrix()
+ annotations = _Table(
+ table_text,
+ colorscale,
+ font_colors,
+ index,
+ index_title,
+ annotation_offset,
+ **kwargs,
+ ).make_table_annotations()
+
+ trace = dict(
+ type="heatmap",
+ z=table_matrix,
+ opacity=0.75,
+ colorscale=colorscale,
+ showscale=False,
+ hoverinfo=hoverinfo,
+ **kwargs,
+ )
+
+ data = [trace]
+ layout = dict(
+ annotations=annotations,
+ height=len(table_matrix) * height_constant + 50,
+ margin=dict(t=0, b=0, r=0, l=0),
+ yaxis=dict(
+ autorange="reversed",
+ zeroline=False,
+ gridwidth=2,
+ ticks="",
+ dtick=1,
+ tick0=0.5,
+ showticklabels=False,
+ ),
+ xaxis=dict(
+ zeroline=False,
+ gridwidth=2,
+ ticks="",
+ dtick=1,
+ tick0=-0.5,
+ showticklabels=False,
+ ),
+ )
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Table(object):
+ """
+ Refer to TraceFactory.create_table() for docstring
+ """
+
+ def __init__(
+ self,
+ table_text,
+ colorscale,
+ font_colors,
+ index,
+ index_title,
+ annotation_offset,
+ **kwargs,
+ ):
+ if pd and isinstance(table_text, pd.DataFrame):
+ headers = table_text.columns.tolist()
+ table_text_index = table_text.index.tolist()
+ table_text = table_text.values.tolist()
+ table_text.insert(0, headers)
+ if index:
+ table_text_index.insert(0, index_title)
+ for i in range(len(table_text)):
+ table_text[i].insert(0, table_text_index[i])
+ self.table_text = table_text
+ self.colorscale = colorscale
+ self.font_colors = font_colors
+ self.index = index
+ self.annotation_offset = annotation_offset
+ self.x = range(len(table_text[0]))
+ self.y = range(len(table_text))
+
+ def get_table_matrix(self):
+ """
+ Create z matrix to make heatmap with striped table coloring
+
+ :rtype (list[list]) table_matrix: z matrix to make heatmap with striped
+ table coloring.
+ """
+ header = [0] * len(self.table_text[0])
+ odd_row = [0.5] * len(self.table_text[0])
+ even_row = [1] * len(self.table_text[0])
+ table_matrix = [None] * len(self.table_text)
+ table_matrix[0] = header
+ for i in range(1, len(self.table_text), 2):
+ table_matrix[i] = odd_row
+ for i in range(2, len(self.table_text), 2):
+ table_matrix[i] = even_row
+ if self.index:
+ for array in table_matrix:
+ array[0] = 0
+ return table_matrix
+
+ def get_table_font_color(self):
+ """
+ Fill font-color array.
+
+ Table text color can vary by row so this extends a single color or
+ creates an array to set a header color and two alternating colors to
+ create the striped table pattern.
+
+ :rtype (list[list]) all_font_colors: list of font colors for each row
+ in table.
+ """
+ if len(self.font_colors) == 1:
+ all_font_colors = self.font_colors * len(self.table_text)
+ elif len(self.font_colors) == 3:
+ all_font_colors = list(range(len(self.table_text)))
+ all_font_colors[0] = self.font_colors[0]
+ for i in range(1, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[1]
+ for i in range(2, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[2]
+ elif len(self.font_colors) == len(self.table_text):
+ all_font_colors = self.font_colors
+ else:
+ all_font_colors = ["#000000"] * len(self.table_text)
+ return all_font_colors
+
+ def make_table_annotations(self):
+ """
+ Generate annotations to fill in table text
+
+ :rtype (list) annotations: list of annotations for each cell of the
+ table.
+ """
+ all_font_colors = _Table.get_table_font_color(self)
+ annotations = []
+ for n, row in enumerate(self.table_text):
+ for m, val in enumerate(row):
+ # Bold text in header and index
+ format_text = (
+ "<b>" + str(val) + "</b>"
+ if n == 0 or self.index and m < 1
+ else str(val)
+ )
+ # Match font color of index to font color of header
+ font_color = (
+ self.font_colors[0] if self.index and m == 0 else all_font_colors[n]
+ )
+ annotations.append(
+ graph_objs.layout.Annotation(
+ text=format_text,
+ x=self.x[m] - self.annotation_offset,
+ y=self.y[n],
+ xref="x1",
+ yref="y1",
+ align="left",
+ xanchor="left",
+ font=dict(color=font_color),
+ showarrow=False,
+ )
+ )
+ return annotations
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py
new file mode 100644
index 0000000..4cdcf17
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py
@@ -0,0 +1,692 @@
+import plotly.colors as clrs
+from plotly.graph_objs import graph_objs as go
+from plotly import exceptions
+from plotly import optional_imports
+
+from skimage import measure
+
+np = optional_imports.get_module("numpy")
+scipy_interp = optional_imports.get_module("scipy.interpolate")
+
+# -------------------------- Layout ------------------------------
+
+
+def _ternary_layout(
+ title="Ternary contour plot", width=550, height=525, pole_labels=["a", "b", "c"]
+):
+ """
+ Layout of ternary contour plot, to be passed to ``go.FigureWidget``
+ object.
+
+ Parameters
+ ==========
+ title : str or None
+ Title of ternary plot
+ width : int
+ Figure width.
+ height : int
+ Figure height.
+ pole_labels : str, default ['a', 'b', 'c']
+ Names of the three poles of the triangle.
+ """
+ return dict(
+ title=title,
+ width=width,
+ height=height,
+ ternary=dict(
+ sum=1,
+ aaxis=dict(
+ title=dict(text=pole_labels[0]), min=0.01, linewidth=2, ticks="outside"
+ ),
+ baxis=dict(
+ title=dict(text=pole_labels[1]), min=0.01, linewidth=2, ticks="outside"
+ ),
+ caxis=dict(
+ title=dict(text=pole_labels[2]), min=0.01, linewidth=2, ticks="outside"
+ ),
+ ),
+ showlegend=False,
+ )
+
+
+# ------------- Transformations of coordinates -------------------
+
+
+def _replace_zero_coords(ternary_data, delta=0.0005):
+ """
+ Replaces zero ternary coordinates with delta and normalize the new
+ triplets (a, b, c).
+
+ Parameters
+ ----------
+
+ ternary_data : ndarray of shape (N, 3)
+
+ delta : float
+ Small float to regularize logarithm.
+
+ Notes
+ -----
+ Implements a method
+ by J. A. Martin-Fernandez, C. Barcelo-Vidal, V. Pawlowsky-Glahn,
+ Dealing with zeros and missing values in compositional data sets
+ using nonparametric imputation, Mathematical Geology 35 (2003),
+ pp 253-278.
+ """
+ zero_mask = ternary_data == 0
+ is_any_coord_zero = np.any(zero_mask, axis=0)
+
+ unity_complement = 1 - delta * is_any_coord_zero
+ if np.any(unity_complement) < 0:
+ raise ValueError(
+ "The provided value of delta led to negative"
+ "ternary coords.Set a smaller delta"
+ )
+ ternary_data = np.where(zero_mask, delta, unity_complement * ternary_data)
+ return ternary_data
+
+
+def _ilr_transform(barycentric):
+ """
+ Perform Isometric Log-Ratio on barycentric (compositional) data.
+
+ Parameters
+ ----------
+ barycentric: ndarray of shape (3, N)
+ Barycentric coordinates.
+
+ References
+ ----------
+ "An algebraic method to compute isometric logratio transformation and
+ back transformation of compositional data", Jarauta-Bragulat, E.,
+ Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
+ Intl Assoc for Math Geology, 2003, pp 31-30.
+ """
+ barycentric = np.asarray(barycentric)
+ x_0 = np.log(barycentric[0] / barycentric[1]) / np.sqrt(2)
+ x_1 = (
+ 1.0 / np.sqrt(6) * np.log(barycentric[0] * barycentric[1] / barycentric[2] ** 2)
+ )
+ ilr_tdata = np.stack((x_0, x_1))
+ return ilr_tdata
+
+
+def _ilr_inverse(x):
+ """
+ Perform inverse Isometric Log-Ratio (ILR) transform to retrieve
+ barycentric (compositional) data.
+
+ Parameters
+ ----------
+ x : array of shape (2, N)
+ Coordinates in ILR space.
+
+ References
+ ----------
+ "An algebraic method to compute isometric logratio transformation and
+ back transformation of compositional data", Jarauta-Bragulat, E.,
+ Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the
+ Intl Assoc for Math Geology, 2003, pp 31-30.
+ """
+ x = np.array(x)
+ matrix = np.array([[0.5, 1, 1.0], [-0.5, 1, 1.0], [0.0, 0.0, 1.0]])
+ s = np.sqrt(2) / 2
+ t = np.sqrt(3 / 2)
+ Sk = np.einsum("ik, kj -> ij", np.array([[s, t], [-s, t]]), x)
+ Z = -np.log(1 + np.exp(Sk).sum(axis=0))
+ log_barycentric = np.einsum(
+ "ik, kj -> ij", matrix, np.stack((2 * s * x[0], t * x[1], Z))
+ )
+ iilr_tdata = np.exp(log_barycentric)
+ return iilr_tdata
+
+
+def _transform_barycentric_cartesian():
+ """
+ Returns the transformation matrix from barycentric to Cartesian
+ coordinates and conversely.
+ """
+ # reference triangle
+ tri_verts = np.array([[0.5, np.sqrt(3) / 2], [0, 0], [1, 0]])
+ M = np.array([tri_verts[:, 0], tri_verts[:, 1], np.ones(3)])
+ return M, np.linalg.inv(M)
+
+
+def _prepare_barycentric_coord(b_coords):
+ """
+ Check ternary coordinates and return the right barycentric coordinates.
+ """
+ if not isinstance(b_coords, (list, np.ndarray)):
+ raise ValueError(
+ "Data should be either an array of shape (n,m),"
+ "or a list of n m-lists, m=2 or 3"
+ )
+ b_coords = np.asarray(b_coords)
+ if b_coords.shape[0] not in (2, 3):
+ raise ValueError(
+ "A point should have 2 (a, b) or 3 (a, b, c)barycentric coordinates"
+ )
+ if (
+ (len(b_coords) == 3)
+ and not np.allclose(b_coords.sum(axis=0), 1, rtol=0.01)
+ and not np.allclose(b_coords.sum(axis=0), 100, rtol=0.01)
+ ):
+ msg = "The sum of coordinates should be 1 or 100 for all data points"
+ raise ValueError(msg)
+
+ if len(b_coords) == 2:
+ A, B = b_coords
+ C = 1 - (A + B)
+ else:
+ A, B, C = b_coords / b_coords.sum(axis=0)
+ if np.any(np.stack((A, B, C)) < 0):
+ raise ValueError("Barycentric coordinates should be positive.")
+ return np.stack((A, B, C))
+
+
+def _compute_grid(coordinates, values, interp_mode="ilr"):
+ """
+ Transform data points with Cartesian or ILR mapping, then Compute
+ interpolation on a regular grid.
+
+ Parameters
+ ==========
+
+ coordinates : array-like
+ Barycentric coordinates of data points.
+ values : 1-d array-like
+ Data points, field to be represented as contours.
+ interp_mode : 'ilr' (default) or 'cartesian'
+ Defines how data are interpolated to compute contours.
+ """
+ if interp_mode == "cartesian":
+ M, invM = _transform_barycentric_cartesian()
+ coord_points = np.einsum("ik, kj -> ij", M, coordinates)
+ elif interp_mode == "ilr":
+ coordinates = _replace_zero_coords(coordinates)
+ coord_points = _ilr_transform(coordinates)
+ else:
+ raise ValueError("interp_mode should be cartesian or ilr")
+ xx, yy = coord_points[:2]
+ x_min, x_max = xx.min(), xx.max()
+ y_min, y_max = yy.min(), yy.max()
+ n_interp = max(200, int(np.sqrt(len(values))))
+ gr_x = np.linspace(x_min, x_max, n_interp)
+ gr_y = np.linspace(y_min, y_max, n_interp)
+ grid_x, grid_y = np.meshgrid(gr_x, gr_y)
+ # We use cubic interpolation, except outside of the convex hull
+ # of data points where we use nearest neighbor values.
+ grid_z = scipy_interp.griddata(
+ coord_points[:2].T, values, (grid_x, grid_y), method="cubic"
+ )
+ return grid_z, gr_x, gr_y
+
+
+# ----------------------- Contour traces ----------------------
+
+
+def _polygon_area(x, y):
+ return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
+
+
+def _colors(ncontours, colormap=None):
+ """
+ Return a list of ``ncontours`` colors from the ``colormap`` colorscale.
+ """
+ if colormap in clrs.PLOTLY_SCALES.keys():
+ cmap = clrs.PLOTLY_SCALES[colormap]
+ else:
+ raise exceptions.PlotlyError(
+ "Colorscale must be a valid Plotly Colorscale."
+ "The available colorscale names are {}".format(clrs.PLOTLY_SCALES.keys())
+ )
+ values = np.linspace(0, 1, ncontours)
+ vals_cmap = np.array([pair[0] for pair in cmap])
+ cols = np.array([pair[1] for pair in cmap])
+ inds = np.searchsorted(vals_cmap, values)
+ if "#" in cols[0]: # for Viridis
+ cols = [clrs.label_rgb(clrs.hex_to_rgb(col)) for col in cols]
+
+ colors = [cols[0]]
+ for ind, val in zip(inds[1:], values[1:]):
+ val1, val2 = vals_cmap[ind - 1], vals_cmap[ind]
+ interm = (val - val1) / (val2 - val1)
+ col = clrs.find_intermediate_color(
+ cols[ind - 1], cols[ind], interm, colortype="rgb"
+ )
+ colors.append(col)
+ return colors
+
+
+def _is_invalid_contour(x, y):
+ """
+ Utility function for _contour_trace
+
+ Contours with an area of the order as 1 pixel are considered spurious.
+ """
+ too_small = np.all(np.abs(x - x[0]) < 2) and np.all(np.abs(y - y[0]) < 2)
+ return too_small
+
+
+def _extract_contours(im, values, colors):
+ """
+ Utility function for _contour_trace.
+
+ In ``im`` only one part of the domain has valid values (corresponding
+ to a subdomain where barycentric coordinates are well defined). When
+ computing contours, we need to assign values outside of this domain.
+ We can choose a value either smaller than all the values inside the
+ valid domain, or larger. This value must be chose with caution so that
+ no spurious contours are added. For example, if the boundary of the valid
+ domain has large values and the outer value is set to a small one, all
+ intermediate contours will be added at the boundary.
+
+ Therefore, we compute the two sets of contours (with an outer value
+ smaller of larger than all values in the valid domain), and choose
+ the value resulting in a smaller total number of contours. There might
+ be a faster way to do this, but it works...
+ """
+ mask_nan = np.isnan(im)
+ im_min, im_max = (
+ im[np.logical_not(mask_nan)].min(),
+ im[np.logical_not(mask_nan)].max(),
+ )
+ zz_min = np.copy(im)
+ zz_min[mask_nan] = 2 * im_min
+ zz_max = np.copy(im)
+ zz_max[mask_nan] = 2 * im_max
+ all_contours1, all_values1, all_areas1, all_colors1 = [], [], [], []
+ all_contours2, all_values2, all_areas2, all_colors2 = [], [], [], []
+ for i, val in enumerate(values):
+ contour_level1 = measure.find_contours(zz_min, val)
+ contour_level2 = measure.find_contours(zz_max, val)
+ all_contours1.extend(contour_level1)
+ all_contours2.extend(contour_level2)
+ all_values1.extend([val] * len(contour_level1))
+ all_values2.extend([val] * len(contour_level2))
+ all_areas1.extend(
+ [_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level1]
+ )
+ all_areas2.extend(
+ [_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level2]
+ )
+ all_colors1.extend([colors[i]] * len(contour_level1))
+ all_colors2.extend([colors[i]] * len(contour_level2))
+ if len(all_contours1) <= len(all_contours2):
+ return all_contours1, all_values1, all_areas1, all_colors1
+ else:
+ return all_contours2, all_values2, all_areas2, all_colors2
+
+
+def _add_outer_contour(
+ all_contours,
+ all_values,
+ all_areas,
+ all_colors,
+ values,
+ val_outer,
+ v_min,
+ v_max,
+ colors,
+ color_min,
+ color_max,
+):
+ """
+ Utility function for _contour_trace
+
+ Adds the background color to fill gaps outside of computed contours.
+
+ To compute the background color, the color of the contour with largest
+ area (``val_outer``) is used. As background color, we choose the next
+ color value in the direction of the extrema of the colormap.
+
+ Then we add information for the outer contour for the different lists
+ provided as arguments.
+
+ A discrete colormap with all used colors is also returned (to be used
+ by colorscale trace).
+ """
+ # The exact value of outer contour is not used when defining the trace
+ outer_contour = 20 * np.array([[0, 0, 1], [0, 1, 0.5]]).T
+ all_contours = [outer_contour] + all_contours
+ delta_values = np.diff(values)[0]
+ values = np.concatenate(
+ ([values[0] - delta_values], values, [values[-1] + delta_values])
+ )
+ colors = np.concatenate(([color_min], colors, [color_max]))
+ index = np.nonzero(values == val_outer)[0][0]
+ if index < len(values) / 2:
+ index -= 1
+ else:
+ index += 1
+ all_colors = [colors[index]] + all_colors
+ all_values = [values[index]] + all_values
+ all_areas = [0] + all_areas
+ used_colors = [color for color in colors if color in all_colors]
+ # Define discrete colorscale
+ color_number = len(used_colors)
+ scale = np.linspace(0, 1, color_number + 1)
+ discrete_cm = []
+ for i, color in enumerate(used_colors):
+ discrete_cm.append([scale[i], used_colors[i]])
+ discrete_cm.append([scale[i + 1], used_colors[i]])
+ discrete_cm.append([scale[color_number], used_colors[color_number - 1]])
+
+ return all_contours, all_values, all_areas, all_colors, discrete_cm
+
+
+def _contour_trace(
+ x,
+ y,
+ z,
+ ncontours=None,
+ colorscale="Electric",
+ linecolor="rgb(150,150,150)",
+ interp_mode="llr",
+ coloring=None,
+ v_min=0,
+ v_max=1,
+):
+ """
+ Contour trace in Cartesian coordinates.
+
+ Parameters
+ ==========
+
+ x, y : array-like
+ Cartesian coordinates
+ z : array-like
+ Field to be represented as contours.
+ ncontours : int or None
+ Number of contours to display (determined automatically if None).
+ colorscale : None or str (Plotly colormap)
+ colorscale of the contours.
+ linecolor : rgb color
+ Color used for lines. If ``colorscale`` is not None, line colors are
+ determined from ``colorscale`` instead.
+ interp_mode : 'ilr' (default) or 'cartesian'
+ Defines how data are interpolated to compute contours. If 'irl',
+ ILR (Isometric Log-Ratio) of compositional data is performed. If
+ 'cartesian', contours are determined in Cartesian space.
+ coloring : None or 'lines'
+ How to display contour. Filled contours if None, lines if ``lines``.
+ vmin, vmax : float
+ Bounds of interval of values used for the colorspace
+
+ Notes
+ =====
+ """
+ # Prepare colors
+ # We do not take extrema, for example for one single contour
+ # the color will be the middle point of the colormap
+ colors = _colors(ncontours + 2, colorscale)
+ # Values used for contours, extrema are not used
+ # For example for a binary array [0, 1], the value of
+ # the contour for ncontours=1 is 0.5.
+ values = np.linspace(v_min, v_max, ncontours + 2)
+ color_min, color_max = colors[0], colors[-1]
+ colors = colors[1:-1]
+ values = values[1:-1]
+
+ # Color of line contours
+ if linecolor is None:
+ linecolor = "rgb(150, 150, 150)"
+ else:
+ colors = [linecolor] * ncontours
+
+ # Retrieve all contours
+ all_contours, all_values, all_areas, all_colors = _extract_contours(
+ z, values, colors
+ )
+
+ # Now sort contours by decreasing area
+ order = np.argsort(all_areas)[::-1]
+
+ # Add outer contour
+ all_contours, all_values, all_areas, all_colors, discrete_cm = _add_outer_contour(
+ all_contours,
+ all_values,
+ all_areas,
+ all_colors,
+ values,
+ all_values[order[0]],
+ v_min,
+ v_max,
+ colors,
+ color_min,
+ color_max,
+ )
+ order = np.concatenate(([0], order + 1))
+
+ # Compute traces, in the order of decreasing area
+ traces = []
+ M, invM = _transform_barycentric_cartesian()
+ dx = (x.max() - x.min()) / x.size
+ dy = (y.max() - y.min()) / y.size
+ for index in order:
+ y_contour, x_contour = all_contours[index].T
+ val = all_values[index]
+ if interp_mode == "cartesian":
+ bar_coords = np.dot(
+ invM,
+ np.stack((dx * x_contour, dy * y_contour, np.ones(x_contour.shape))),
+ )
+ elif interp_mode == "ilr":
+ bar_coords = _ilr_inverse(
+ np.stack((dx * x_contour + x.min(), dy * y_contour + y.min()))
+ )
+ if index == 0: # outer triangle
+ a = np.array([1, 0, 0])
+ b = np.array([0, 1, 0])
+ c = np.array([0, 0, 1])
+ else:
+ a, b, c = bar_coords
+ if _is_invalid_contour(x_contour, y_contour):
+ continue
+
+ _col = all_colors[index] if coloring == "lines" else linecolor
+ trace = dict(
+ type="scatterternary",
+ a=a,
+ b=b,
+ c=c,
+ mode="lines",
+ line=dict(color=_col, shape="spline", width=1),
+ fill="toself",
+ fillcolor=all_colors[index],
+ showlegend=True,
+ hoverinfo="skip",
+ name="%.3f" % val,
+ )
+ if coloring == "lines":
+ trace["fill"] = None
+ traces.append(trace)
+
+ return traces, discrete_cm
+
+
+# -------------------- Figure Factory for ternary contour -------------
+
+
+def create_ternary_contour(
+ coordinates,
+ values,
+ pole_labels=["a", "b", "c"],
+ width=500,
+ height=500,
+ ncontours=None,
+ showscale=False,
+ coloring=None,
+ colorscale="Bluered",
+ linecolor=None,
+ title=None,
+ interp_mode="ilr",
+ showmarkers=False,
+):
+ """
+ Ternary contour plot.
+
+ Parameters
+ ----------
+
+ coordinates : list or ndarray
+ Barycentric coordinates of shape (2, N) or (3, N) where N is the
+ number of data points. The sum of the 3 coordinates is expected
+ to be 1 for all data points.
+ values : array-like
+ Data points of field to be represented as contours.
+ pole_labels : str, default ['a', 'b', 'c']
+ Names of the three poles of the triangle.
+ width : int
+ Figure width.
+ height : int
+ Figure height.
+ ncontours : int or None
+ Number of contours to display (determined automatically if None).
+ showscale : bool, default False
+ If True, a colorbar showing the color scale is displayed.
+ coloring : None or 'lines'
+ How to display contour. Filled contours if None, lines if ``lines``.
+ colorscale : None or str (Plotly colormap)
+ colorscale of the contours.
+ linecolor : None or rgb color
+ Color used for lines. ``colorscale`` has to be set to None, otherwise
+ line colors are determined from ``colorscale``.
+ title : str or None
+ Title of ternary plot
+ interp_mode : 'ilr' (default) or 'cartesian'
+ Defines how data are interpolated to compute contours. If 'irl',
+ ILR (Isometric Log-Ratio) of compositional data is performed. If
+ 'cartesian', contours are determined in Cartesian space.
+ showmarkers : bool, default False
+ If True, markers corresponding to input compositional points are
+ superimposed on contours, using the same colorscale.
+
+ Examples
+ ========
+
+ Example 1: ternary contour plot with filled contours
+
+ >>> import plotly.figure_factory as ff
+ >>> import numpy as np
+ >>> # Define coordinates
+ >>> a, b = np.mgrid[0:1:20j, 0:1:20j]
+ >>> mask = a + b <= 1
+ >>> a = a[mask].ravel()
+ >>> b = b[mask].ravel()
+ >>> c = 1 - a - b
+ >>> # Values to be displayed as contours
+ >>> z = a * b * c
+ >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z)
+ >>> fig.show()
+
+ It is also possible to give only two barycentric coordinates for each
+ point, since the sum of the three coordinates is one:
+
+ >>> fig = ff.create_ternary_contour(np.stack((a, b)), z)
+
+
+ Example 2: ternary contour plot with line contours
+
+ >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines')
+
+ Example 3: customize number of contours
+
+ >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8)
+
+ Example 4: superimpose contour plot and original data as markers
+
+ >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines',
+ ... showmarkers=True)
+
+ Example 5: customize title and pole labels
+
+ >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z,
+ ... title='Ternary plot',
+ ... pole_labels=['clay', 'quartz', 'fledspar'])
+ """
+ if scipy_interp is None:
+ raise ImportError(
+ """\
+ The create_ternary_contour figure factory requires the scipy package"""
+ )
+ sk_measure = optional_imports.get_module("skimage")
+ if sk_measure is None:
+ raise ImportError(
+ """\
+ The create_ternary_contour figure factory requires the scikit-image
+ package"""
+ )
+ if colorscale is None:
+ showscale = False
+ if ncontours is None:
+ ncontours = 5
+ coordinates = _prepare_barycentric_coord(coordinates)
+ v_min, v_max = values.min(), values.max()
+ grid_z, gr_x, gr_y = _compute_grid(coordinates, values, interp_mode=interp_mode)
+
+ layout = _ternary_layout(
+ pole_labels=pole_labels, width=width, height=height, title=title
+ )
+
+ contour_trace, discrete_cm = _contour_trace(
+ gr_x,
+ gr_y,
+ grid_z,
+ ncontours=ncontours,
+ colorscale=colorscale,
+ linecolor=linecolor,
+ interp_mode=interp_mode,
+ coloring=coloring,
+ v_min=v_min,
+ v_max=v_max,
+ )
+
+ fig = go.Figure(data=contour_trace, layout=layout)
+
+ opacity = 1 if showmarkers else 0
+ a, b, c = coordinates
+ hovertemplate = (
+ pole_labels[0]
+ + ": %{a:.3f}<br>"
+ + pole_labels[1]
+ + ": %{b:.3f}<br>"
+ + pole_labels[2]
+ + ": %{c:.3f}<br>"
+ "z: %{marker.color:.3f}<extra></extra>"
+ )
+
+ fig.add_scatterternary(
+ a=a,
+ b=b,
+ c=c,
+ mode="markers",
+ marker={
+ "color": values,
+ "colorscale": colorscale,
+ "line": {"color": "rgb(120, 120, 120)", "width": int(coloring != "lines")},
+ },
+ opacity=opacity,
+ hovertemplate=hovertemplate,
+ )
+ if showscale:
+ if not showmarkers:
+ colorscale = discrete_cm
+ colorbar = dict(
+ {
+ "type": "scatterternary",
+ "a": [None],
+ "b": [None],
+ "c": [None],
+ "marker": {
+ "cmin": values.min(),
+ "cmax": values.max(),
+ "colorscale": colorscale,
+ "showscale": True,
+ },
+ "mode": "markers",
+ }
+ )
+ fig.add_trace(colorbar)
+
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py
new file mode 100644
index 0000000..f935292
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py
@@ -0,0 +1,509 @@
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module("numpy")
+
+
+def map_face2color(face, colormap, scale, vmin, vmax):
+ """
+ Normalize facecolor values by vmin/vmax and return rgb-color strings
+
+ This function takes a tuple color along with a colormap and a minimum
+ (vmin) and maximum (vmax) range of possible mean distances for the
+ given parametrized surface. It returns an rgb color based on the mean
+ distance between vmin and vmax
+
+ """
+ if vmin >= vmax:
+ raise exceptions.PlotlyError(
+ "Incorrect relation between vmin "
+ "and vmax. The vmin value cannot be "
+ "bigger than or equal to the value "
+ "of vmax."
+ )
+ if len(colormap) == 1:
+ # color each triangle face with the same color in colormap
+ face_color = colormap[0]
+ face_color = clrs.convert_to_RGB_255(face_color)
+ face_color = clrs.label_rgb(face_color)
+ return face_color
+ if face == vmax:
+ # pick last color in colormap
+ face_color = colormap[-1]
+ face_color = clrs.convert_to_RGB_255(face_color)
+ face_color = clrs.label_rgb(face_color)
+ return face_color
+ else:
+ if scale is None:
+ # find the normalized distance t of a triangle face between
+ # vmin and vmax where the distance is between 0 and 1
+ t = (face - vmin) / float((vmax - vmin))
+ low_color_index = int(t / (1.0 / (len(colormap) - 1)))
+
+ face_color = clrs.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ t * (len(colormap) - 1) - low_color_index,
+ )
+
+ face_color = clrs.convert_to_RGB_255(face_color)
+ face_color = clrs.label_rgb(face_color)
+ else:
+ # find the face color for a non-linearly interpolated scale
+ t = (face - vmin) / float((vmax - vmin))
+
+ low_color_index = 0
+ for k in range(len(scale) - 1):
+ if scale[k] <= t < scale[k + 1]:
+ break
+ low_color_index += 1
+
+ low_scale_val = scale[low_color_index]
+ high_scale_val = scale[low_color_index + 1]
+
+ face_color = clrs.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ (t - low_scale_val) / (high_scale_val - low_scale_val),
+ )
+
+ face_color = clrs.convert_to_RGB_255(face_color)
+ face_color = clrs.label_rgb(face_color)
+ return face_color
+
+
+def trisurf(
+ x,
+ y,
+ z,
+ simplices,
+ show_colorbar,
+ edges_color,
+ scale,
+ colormap=None,
+ color_func=None,
+ plot_edges=False,
+ x_edge=None,
+ y_edge=None,
+ z_edge=None,
+ facecolor=None,
+):
+ """
+ Refer to FigureFactory.create_trisurf() for docstring
+ """
+ # numpy import check
+ if not np:
+ raise ImportError("FigureFactory._trisurf() requires numpy imported.")
+ points3D = np.vstack((x, y, z)).T
+ simplices = np.atleast_2d(simplices)
+
+ # vertices of the surface triangles
+ tri_vertices = points3D[simplices]
+
+ # Define colors for the triangle faces
+ if color_func is None:
+ # mean values of z-coordinates of triangle vertices
+ mean_dists = tri_vertices[:, :, 2].mean(-1)
+ elif isinstance(color_func, (list, np.ndarray)):
+ # Pre-computed list / array of values to map onto color
+ if len(color_func) != len(simplices):
+ raise ValueError(
+ "If color_func is a list/array, it must "
+ "be the same length as simplices."
+ )
+
+ # convert all colors in color_func to rgb
+ for index in range(len(color_func)):
+ if isinstance(color_func[index], str):
+ if "#" in color_func[index]:
+ foo = clrs.hex_to_rgb(color_func[index])
+ color_func[index] = clrs.label_rgb(foo)
+
+ if isinstance(color_func[index], tuple):
+ foo = clrs.convert_to_RGB_255(color_func[index])
+ color_func[index] = clrs.label_rgb(foo)
+
+ mean_dists = np.asarray(color_func)
+ else:
+ # apply user inputted function to calculate
+ # custom coloring for triangle vertices
+ mean_dists = []
+ for triangle in tri_vertices:
+ dists = []
+ for vertex in triangle:
+ dist = color_func(vertex[0], vertex[1], vertex[2])
+ dists.append(dist)
+ mean_dists.append(np.mean(dists))
+ mean_dists = np.asarray(mean_dists)
+
+ # Check if facecolors are already strings and can be skipped
+ if isinstance(mean_dists[0], str):
+ facecolor = mean_dists
+ else:
+ min_mean_dists = np.min(mean_dists)
+ max_mean_dists = np.max(mean_dists)
+
+ if facecolor is None:
+ facecolor = []
+ for index in range(len(mean_dists)):
+ color = map_face2color(
+ mean_dists[index], colormap, scale, min_mean_dists, max_mean_dists
+ )
+ facecolor.append(color)
+
+ # Make sure facecolor is a list so output is consistent across Pythons
+ facecolor = np.asarray(facecolor)
+ ii, jj, kk = simplices.T
+
+ triangles = graph_objs.Mesh3d(
+ x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name=""
+ )
+
+ mean_dists_are_numbers = not isinstance(mean_dists[0], str)
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ # make a colorscale from the colors
+ colorscale = clrs.make_colorscale(colormap, scale)
+ colorscale = clrs.convert_colorscale_to_rgb(colorscale)
+
+ colorbar = graph_objs.Scatter3d(
+ x=x[:1],
+ y=y[:1],
+ z=z[:1],
+ mode="markers",
+ marker=dict(
+ size=0.1,
+ color=[min_mean_dists, max_mean_dists],
+ colorscale=colorscale,
+ showscale=True,
+ ),
+ hoverinfo="none",
+ showlegend=False,
+ )
+
+ # the triangle sides are not plotted
+ if plot_edges is False:
+ if mean_dists_are_numbers and show_colorbar is True:
+ return [triangles, colorbar]
+ else:
+ return [triangles]
+
+ # define the lists x_edge, y_edge and z_edge, of x, y, resp z
+ # coordinates of edge end points for each triangle
+ # None separates data corresponding to two consecutive triangles
+ is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
+ if any(is_none):
+ if not all(is_none):
+ raise ValueError(
+ "If any (x_edge, y_edge, z_edge) is None, all must be None"
+ )
+ else:
+ x_edge = []
+ y_edge = []
+ z_edge = []
+
+ # Pull indices we care about, then add a None column to separate tris
+ ixs_triangles = [0, 1, 2, 0]
+ pull_edges = tri_vertices[:, ixs_triangles, :]
+ x_edge_pull = np.hstack(
+ [pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])]
+ )
+ y_edge_pull = np.hstack(
+ [pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])]
+ )
+ z_edge_pull = np.hstack(
+ [pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])]
+ )
+
+ # Now unravel the edges into a 1-d vector for plotting
+ x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
+ y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
+ z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
+
+ if not (len(x_edge) == len(y_edge) == len(z_edge)):
+ raise exceptions.PlotlyError(
+ "The lengths of x_edge, y_edge and z_edge are not the same."
+ )
+
+ # define the lines for plotting
+ lines = graph_objs.Scatter3d(
+ x=x_edge,
+ y=y_edge,
+ z=z_edge,
+ mode="lines",
+ line=graph_objs.scatter3d.Line(color=edges_color, width=1.5),
+ showlegend=False,
+ )
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ return [triangles, lines, colorbar]
+ else:
+ return [triangles, lines]
+
+
+def create_trisurf(
+ x,
+ y,
+ z,
+ simplices,
+ colormap=None,
+ show_colorbar=True,
+ scale=None,
+ color_func=None,
+ title="Trisurf Plot",
+ plot_edges=True,
+ showbackground=True,
+ backgroundcolor="rgb(230, 230, 230)",
+ gridcolor="rgb(255, 255, 255)",
+ zerolinecolor="rgb(255, 255, 255)",
+ edges_color="rgb(50, 50, 50)",
+ height=800,
+ width=800,
+ aspectratio=None,
+):
+ """
+ Returns figure for a triangulated surface plot
+
+ :param (array) x: data values of x in a 1D array
+ :param (array) y: data values of y in a 1D array
+ :param (array) z: data values of z in a 1D array
+ :param (array) simplices: an array of shape (ntri, 3) where ntri is
+ the number of triangles in the triangularization. Each row of the
+ array contains the indicies of the verticies of each triangle
+ :param (str|tuple|list) colormap: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list of colors. An rgb color is
+ of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colormap is a list, it must
+ contain the valid color types aforementioned as its members
+ :param (bool) show_colorbar: determines if colorbar is visible
+ :param (list|array) scale: sets the scale values to be used if a non-
+ linearly interpolated colormap is desired. If left as None, a
+ linear interpolation between the colors will be excecuted
+ :param (function|list) color_func: The parameter that determines the
+ coloring of the surface. Takes either a function with 3 arguments
+ x, y, z or a list/array of color values the same length as
+ simplices. If None, coloring will only depend on the z axis
+ :param (str) title: title of the plot
+ :param (bool) plot_edges: determines if the triangles on the trisurf
+ are visible
+ :param (bool) showbackground: makes background in plot visible
+ :param (str) backgroundcolor: color of background. Takes a string of
+ the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) gridcolor: color of the gridlines besides the axes. Takes
+ a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
+ inclusive
+ :param (str) zerolinecolor: color of the axes. Takes a string of the
+ form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) edges_color: color of the edges, if plot_edges is True
+ :param (int|float) height: the height of the plot (in pixels)
+ :param (int|float) width: the width of the plot (in pixels)
+ :param (dict) aspectratio: a dictionary of the aspect ratio values for
+ the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
+
+ Example 1: Sphere
+
+ >>> # Necessary Imports for Trisurf
+ >>> import numpy as np
+ >>> from scipy.spatial import Delaunay
+
+ >>> from plotly.figure_factory import create_trisurf
+ >>> from plotly.graph_objs import graph_objs
+
+ >>> # Make data for plot
+ >>> u = np.linspace(0, 2*np.pi, 20)
+ >>> v = np.linspace(0, np.pi, 20)
+ >>> u,v = np.meshgrid(u,v)
+ >>> u = u.flatten()
+ >>> v = v.flatten()
+
+ >>> x = np.sin(v)*np.cos(u)
+ >>> y = np.sin(v)*np.sin(u)
+ >>> z = np.cos(v)
+
+ >>> points2D = np.vstack([u,v]).T
+ >>> tri = Delaunay(points2D)
+ >>> simplices = tri.simplices
+
+ >>> # Create a figure
+ >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
+ ... simplices=simplices)
+
+ Example 2: Torus
+
+ >>> # Necessary Imports for Trisurf
+ >>> import numpy as np
+ >>> from scipy.spatial import Delaunay
+
+ >>> from plotly.figure_factory import create_trisurf
+ >>> from plotly.graph_objs import graph_objs
+
+ >>> # Make data for plot
+ >>> u = np.linspace(0, 2*np.pi, 20)
+ >>> v = np.linspace(0, 2*np.pi, 20)
+ >>> u,v = np.meshgrid(u,v)
+ >>> u = u.flatten()
+ >>> v = v.flatten()
+
+ >>> x = (3 + (np.cos(v)))*np.cos(u)
+ >>> y = (3 + (np.cos(v)))*np.sin(u)
+ >>> z = np.sin(v)
+
+ >>> points2D = np.vstack([u,v]).T
+ >>> tri = Delaunay(points2D)
+ >>> simplices = tri.simplices
+
+ >>> # Create a figure
+ >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
+ ... simplices=simplices)
+
+ Example 3: Mobius Band
+
+ >>> # Necessary Imports for Trisurf
+ >>> import numpy as np
+ >>> from scipy.spatial import Delaunay
+
+ >>> from plotly.figure_factory import create_trisurf
+ >>> from plotly.graph_objs import graph_objs
+
+ >>> # Make data for plot
+ >>> u = np.linspace(0, 2*np.pi, 24)
+ >>> v = np.linspace(-1, 1, 8)
+ >>> u,v = np.meshgrid(u,v)
+ >>> u = u.flatten()
+ >>> v = v.flatten()
+
+ >>> tp = 1 + 0.5*v*np.cos(u/2.)
+ >>> x = tp*np.cos(u)
+ >>> y = tp*np.sin(u)
+ >>> z = 0.5*v*np.sin(u/2.)
+
+ >>> points2D = np.vstack([u,v]).T
+ >>> tri = Delaunay(points2D)
+ >>> simplices = tri.simplices
+
+ >>> # Create a figure
+ >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
+ ... simplices=simplices)
+
+ Example 4: Using a Custom Colormap Function with Light Cone
+
+ >>> # Necessary Imports for Trisurf
+ >>> import numpy as np
+ >>> from scipy.spatial import Delaunay
+
+ >>> from plotly.figure_factory import create_trisurf
+ >>> from plotly.graph_objs import graph_objs
+
+ >>> # Make data for plot
+ >>> u=np.linspace(-np.pi, np.pi, 30)
+ >>> v=np.linspace(-np.pi, np.pi, 30)
+ >>> u,v=np.meshgrid(u,v)
+ >>> u=u.flatten()
+ >>> v=v.flatten()
+
+ >>> x = u
+ >>> y = u*np.cos(v)
+ >>> z = u*np.sin(v)
+
+ >>> points2D = np.vstack([u,v]).T
+ >>> tri = Delaunay(points2D)
+ >>> simplices = tri.simplices
+
+ >>> # Define distance function
+ >>> def dist_origin(x, y, z):
+ ... return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
+
+ >>> # Create a figure
+ >>> fig1 = create_trisurf(x=x, y=y, z=z,
+ ... colormap=['#FFFFFF', '#E4FFFE',
+ ... '#A4F6F9', '#FF99FE',
+ ... '#BA52ED'],
+ ... scale=[0, 0.6, 0.71, 0.89, 1],
+ ... simplices=simplices,
+ ... color_func=dist_origin)
+
+ Example 5: Enter color_func as a list of colors
+
+ >>> # Necessary Imports for Trisurf
+ >>> import numpy as np
+ >>> from scipy.spatial import Delaunay
+ >>> import random
+
+ >>> from plotly.figure_factory import create_trisurf
+ >>> from plotly.graph_objs import graph_objs
+
+ >>> # Make data for plot
+ >>> u=np.linspace(-np.pi, np.pi, 30)
+ >>> v=np.linspace(-np.pi, np.pi, 30)
+ >>> u,v=np.meshgrid(u,v)
+ >>> u=u.flatten()
+ >>> v=v.flatten()
+
+ >>> x = u
+ >>> y = u*np.cos(v)
+ >>> z = u*np.sin(v)
+
+ >>> points2D = np.vstack([u,v]).T
+ >>> tri = Delaunay(points2D)
+ >>> simplices = tri.simplices
+
+
+ >>> colors = []
+ >>> color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
+
+ >>> for index in range(len(simplices)):
+ ... colors.append(random.choice(color_choices))
+
+ >>> fig = create_trisurf(
+ ... x, y, z, simplices,
+ ... color_func=colors,
+ ... show_colorbar=True,
+ ... edges_color='rgb(2, 85, 180)',
+ ... title=' Modern Art'
+ ... )
+ """
+ if aspectratio is None:
+ aspectratio = {"x": 1, "y": 1, "z": 1}
+
+ # Validate colormap
+ clrs.validate_colors(colormap)
+ colormap, scale = clrs.convert_colors_to_same_type(
+ colormap, colortype="tuple", return_default_colors=True, scale=scale
+ )
+
+ data1 = trisurf(
+ x,
+ y,
+ z,
+ simplices,
+ show_colorbar=show_colorbar,
+ color_func=color_func,
+ colormap=colormap,
+ scale=scale,
+ edges_color=edges_color,
+ plot_edges=plot_edges,
+ )
+
+ axis = dict(
+ showbackground=showbackground,
+ backgroundcolor=backgroundcolor,
+ gridcolor=gridcolor,
+ zerolinecolor=zerolinecolor,
+ )
+ layout = graph_objs.Layout(
+ title=title,
+ width=width,
+ height=height,
+ scene=graph_objs.layout.Scene(
+ xaxis=graph_objs.layout.scene.XAxis(**axis),
+ yaxis=graph_objs.layout.scene.YAxis(**axis),
+ zaxis=graph_objs.layout.scene.ZAxis(**axis),
+ aspectratio=dict(
+ x=aspectratio["x"], y=aspectratio["y"], z=aspectratio["z"]
+ ),
+ ),
+ )
+
+ return graph_objs.Figure(data=data1, layout=layout)
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_violin.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_violin.py
new file mode 100644
index 0000000..55924e6
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_violin.py
@@ -0,0 +1,704 @@
+from numbers import Number
+
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.graph_objs import graph_objs
+from plotly.subplots import make_subplots
+
+pd = optional_imports.get_module("pandas")
+np = optional_imports.get_module("numpy")
+scipy_stats = optional_imports.get_module("scipy.stats")
+
+
+def calc_stats(data):
+ """
+ Calculate statistics for use in violin plot.
+ """
+ x = np.asarray(data, float)
+ vals_min = np.min(x)
+ vals_max = np.max(x)
+ q2 = np.percentile(x, 50, interpolation="linear")
+ q1 = np.percentile(x, 25, interpolation="lower")
+ q3 = np.percentile(x, 75, interpolation="higher")
+ iqr = q3 - q1
+ whisker_dist = 1.5 * iqr
+
+ # in order to prevent drawing whiskers outside the interval
+ # of data one defines the whisker positions as:
+ d1 = np.min(x[x >= (q1 - whisker_dist)])
+ d2 = np.max(x[x <= (q3 + whisker_dist)])
+ return {
+ "min": vals_min,
+ "max": vals_max,
+ "q1": q1,
+ "q2": q2,
+ "q3": q3,
+ "d1": d1,
+ "d2": d2,
+ }
+
+
+def make_half_violin(x, y, fillcolor="#1f77b4", linecolor="rgb(0, 0, 0)"):
+ """
+ Produces a sideways probability distribution fig violin plot.
+ """
+ text = [
+ "(pdf(y), y)=(" + "{:0.2f}".format(x[i]) + ", " + "{:0.2f}".format(y[i]) + ")"
+ for i in range(len(x))
+ ]
+
+ return graph_objs.Scatter(
+ x=x,
+ y=y,
+ mode="lines",
+ name="",
+ text=text,
+ fill="tonextx",
+ fillcolor=fillcolor,
+ line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape="spline"),
+ hoverinfo="text",
+ opacity=0.5,
+ )
+
+
+def make_violin_rugplot(vals, pdf_max, distance, color="#1f77b4"):
+ """
+ Returns a rugplot fig for a violin plot.
+ """
+ return graph_objs.Scatter(
+ y=vals,
+ x=[-pdf_max - distance] * len(vals),
+ marker=graph_objs.scatter.Marker(color=color, symbol="line-ew-open"),
+ mode="markers",
+ name="",
+ showlegend=False,
+ hoverinfo="y",
+ )
+
+
+def make_non_outlier_interval(d1, d2):
+ """
+ Returns the scatterplot fig of most of a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[d1, d2],
+ name="",
+ mode="lines",
+ line=graph_objs.scatter.Line(width=1.5, color="rgb(0,0,0)"),
+ )
+
+
+def make_quartiles(q1, q3):
+ """
+ Makes the upper and lower quartiles for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[q1, q3],
+ text=[
+ "lower-quartile: " + "{:0.2f}".format(q1),
+ "upper-quartile: " + "{:0.2f}".format(q3),
+ ],
+ mode="lines",
+ line=graph_objs.scatter.Line(width=4, color="rgb(0,0,0)"),
+ hoverinfo="text",
+ )
+
+
+def make_median(q2):
+ """
+ Formats the 'median' hovertext for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0],
+ y=[q2],
+ text=["median: " + "{:0.2f}".format(q2)],
+ mode="markers",
+ marker=dict(symbol="square", color="rgb(255,255,255)"),
+ hoverinfo="text",
+ )
+
+
+def make_XAxis(xaxis_title, xaxis_range):
+ """
+ Makes the x-axis for a violin plot.
+ """
+ xaxis = graph_objs.layout.XAxis(
+ title=xaxis_title,
+ range=xaxis_range,
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ mirror=False,
+ ticks="",
+ showticklabels=False,
+ )
+ return xaxis
+
+
+def make_YAxis(yaxis_title):
+ """
+ Makes the y-axis for a violin plot.
+ """
+ yaxis = graph_objs.layout.YAxis(
+ title=yaxis_title,
+ showticklabels=True,
+ autorange=True,
+ ticklen=4,
+ showline=True,
+ zeroline=False,
+ showgrid=False,
+ mirror=False,
+ )
+ return yaxis
+
+
+def violinplot(vals, fillcolor="#1f77b4", rugplot=True):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+ """
+ vals = np.asarray(vals, float)
+ # summary statistics
+ vals_min = calc_stats(vals)["min"]
+ vals_max = calc_stats(vals)["max"]
+ q1 = calc_stats(vals)["q1"]
+ q2 = calc_stats(vals)["q2"]
+ q3 = calc_stats(vals)["q3"]
+ d1 = calc_stats(vals)["d1"]
+ d2 = calc_stats(vals)["d2"]
+
+ # kernel density estimation of pdf
+ pdf = scipy_stats.gaussian_kde(vals)
+ # grid over the data interval
+ xx = np.linspace(vals_min, vals_max, 100)
+ # evaluate the pdf at the grid xx
+ yy = pdf(xx)
+ max_pdf = np.max(yy)
+ # distance from the violin plot to rugplot
+ distance = (2.0 * max_pdf) / 10 if rugplot else 0
+ # range for x values in the plot
+ plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
+ plot_data = [
+ make_half_violin(-yy, xx, fillcolor=fillcolor),
+ make_half_violin(yy, xx, fillcolor=fillcolor),
+ make_non_outlier_interval(d1, d2),
+ make_quartiles(q1, q3),
+ make_median(q2),
+ ]
+ if rugplot:
+ plot_data.append(
+ make_violin_rugplot(vals, max_pdf, distance=distance, color=fillcolor)
+ )
+ return plot_data, plot_xrange
+
+
+def violin_no_colorscale(
+ data,
+ data_header,
+ group_header,
+ colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ if sort:
+ group_name.sort()
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(
+ rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
+ )
+ color_index = 0
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], float)
+ if color_index >= len(colors):
+ color_index = 0
+ plot_data, plot_xrange = violinplot(
+ vals, fillcolor=colors[color_index], rugplot=rugplot
+ )
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ color_index += 1
+
+ # add violin plot labels
+ fig["layout"].update(
+ {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
+ fig["layout"].update(
+ title=title,
+ showlegend=False,
+ hovermode="closest",
+ autosize=False,
+ height=height,
+ width=width,
+ )
+
+ return fig
+
+
+def violin_colorscale(
+ data,
+ data_header,
+ group_header,
+ colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot with colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ if sort:
+ group_name.sort()
+
+ # make sure all group names are keys in group_stats
+ for group in group_name:
+ if group not in group_stats:
+ raise exceptions.PlotlyError(
+ "All values/groups in the index "
+ "column must be represented "
+ "as a key in group_stats."
+ )
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(
+ rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
+ )
+
+ # prepare low and high color for colorscale
+ lowcolor = clrs.color_parser(colors[0], clrs.unlabel_rgb)
+ highcolor = clrs.color_parser(colors[1], clrs.unlabel_rgb)
+
+ # find min and max values in group_stats
+ group_stats_values = []
+ for key in group_stats:
+ group_stats_values.append(group_stats[key])
+
+ max_value = max(group_stats_values)
+ min_value = min(group_stats_values)
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], float)
+
+ # find intermediate color from colorscale
+ intermed = (group_stats[gr] - min_value) / (max_value - min_value)
+ intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)
+
+ plot_data, plot_xrange = violinplot(
+ vals, fillcolor="rgb{}".format(intermed_color), rugplot=rugplot
+ )
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ fig["layout"].update(
+ {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+ # add colorbar to plot
+ trace_dummy = graph_objs.Scatter(
+ x=[0],
+ y=[0],
+ mode="markers",
+ marker=dict(
+ size=2,
+ cmin=min_value,
+ cmax=max_value,
+ colorscale=[[0, colors[0]], [1, colors[1]]],
+ showscale=True,
+ ),
+ showlegend=False,
+ )
+ fig.append_trace(trace_dummy, 1, L)
+
+ # set the sharey axis style
+ fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
+ fig["layout"].update(
+ title=title,
+ showlegend=False,
+ hovermode="closest",
+ autosize=False,
+ height=height,
+ width=width,
+ )
+
+ return fig
+
+
+def violin_dict(
+ data,
+ data_header,
+ group_header,
+ colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+
+ if sort:
+ group_name.sort()
+
+ # check if all group names appear in colors dict
+ for group in group_name:
+ if group not in colors:
+ raise exceptions.PlotlyError(
+ "If colors is a dictionary, all "
+ "the group names must appear as "
+ "keys in colors."
+ )
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(
+ rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
+ )
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], float)
+ plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], rugplot=rugplot)
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+
+ # add violin plot labels
+ fig["layout"].update(
+ {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
+ fig["layout"].update(
+ title=title,
+ showlegend=False,
+ hovermode="closest",
+ autosize=False,
+ height=height,
+ width=width,
+ )
+
+ return fig
+
+
+def create_violin(
+ data,
+ data_header=None,
+ group_header=None,
+ colors=None,
+ use_colorscale=False,
+ group_stats=None,
+ rugplot=True,
+ sort=False,
+ height=450,
+ width=600,
+ title="Violin and Rug Plot",
+):
+ """
+ **deprecated**, use instead the plotly.graph_objects trace
+ :class:`plotly.graph_objects.Violin`.
+
+ :param (list|array) data: accepts either a list of numerical values,
+ a list of dictionaries all with identical keys and at least one
+ column of numeric values, or a pandas dataframe with at least one
+ column of numbers.
+ :param (str) data_header: the header of the data column to be used
+ from an inputted pandas dataframe. Not applicable if 'data' is
+ a list of numeric values.
+ :param (str) group_header: applicable if grouping data by a variable.
+ 'group_header' must be set to the name of the grouping variable.
+ :param (str|tuple|list|dict) colors: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colors is a list, it must contain valid color types as its
+ members.
+ :param (bool) use_colorscale: only applicable if grouping by another
+ variable. Will implement a colorscale based on the first 2 colors
+ of param colors. This means colors must be a list with at least 2
+ colors in it (Plotly colorscales are accepted since they map to a
+ list of two rgb colors). Default = False
+ :param (dict) group_stats: a dictionary where each key is a unique
+ value from the group_header column in data. Each value must be a
+ number and will be used to color the violin plots if a colorscale
+ is being used.
+ :param (bool) rugplot: determines if a rugplot is draw on violin plot.
+ Default = True
+ :param (bool) sort: determines if violins are sorted
+ alphabetically (True) or by input order (False). Default = False
+ :param (float) height: the height of the violin plot.
+ :param (float) width: the width of the violin plot.
+ :param (str) title: the title of the violin plot.
+
+ Example 1: Single Violin Plot
+
+ >>> from plotly.figure_factory import create_violin
+ >>> import plotly.graph_objs as graph_objects
+
+ >>> import numpy as np
+ >>> from scipy import stats
+
+ >>> # create list of random values
+ >>> data_list = np.random.randn(100)
+
+ >>> # create violin fig
+ >>> fig = create_violin(data_list, colors='#604d9e')
+
+ >>> # plot
+ >>> fig.show()
+
+ Example 2: Multiple Violin Plots with Qualitative Coloring
+
+ >>> from plotly.figure_factory import create_violin
+ >>> import plotly.graph_objs as graph_objects
+
+ >>> import numpy as np
+ >>> import pandas as pd
+ >>> from scipy import stats
+
+ >>> # create dataframe
+ >>> np.random.seed(619517)
+ >>> Nr=250
+ >>> y = np.random.randn(Nr)
+ >>> gr = np.random.choice(list("ABCDE"), Nr)
+ >>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ >>> for i, letter in enumerate("ABCDE"):
+ ... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ >>> df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ >>> # create violin fig
+ >>> fig = create_violin(df, data_header='Score', group_header='Group',
+ ... sort=True, height=600, width=1000)
+
+ >>> # plot
+ >>> fig.show()
+
+ Example 3: Violin Plots with Colorscale
+
+ >>> from plotly.figure_factory import create_violin
+ >>> import plotly.graph_objs as graph_objects
+
+ >>> import numpy as np
+ >>> import pandas as pd
+ >>> from scipy import stats
+
+ >>> # create dataframe
+ >>> np.random.seed(619517)
+ >>> Nr=250
+ >>> y = np.random.randn(Nr)
+ >>> gr = np.random.choice(list("ABCDE"), Nr)
+ >>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ >>> for i, letter in enumerate("ABCDE"):
+ ... y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ >>> df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ >>> # define header params
+ >>> data_header = 'Score'
+ >>> group_header = 'Group'
+
+ >>> # make groupby object with pandas
+ >>> group_stats = {}
+ >>> groupby_data = df.groupby([group_header])
+
+ >>> for group in "ABCDE":
+ ... data_from_group = groupby_data.get_group(group)[data_header]
+ ... # take a stat of the grouped data
+ ... stat = np.median(data_from_group)
+ ... # add to dictionary
+ ... group_stats[group] = stat
+
+ >>> # create violin fig
+ >>> fig = create_violin(df, data_header='Score', group_header='Group',
+ ... height=600, width=1000, use_colorscale=True,
+ ... group_stats=group_stats)
+
+ >>> # plot
+ >>> fig.show()
+ """
+
+ # Validate colors
+ if isinstance(colors, dict):
+ valid_colors = clrs.validate_colors_dict(colors, "rgb")
+ else:
+ valid_colors = clrs.validate_colors(colors, "rgb")
+
+ # validate data and choose plot type
+ if group_header is None:
+ if isinstance(data, list):
+ if len(data) <= 0:
+ raise exceptions.PlotlyError(
+ "If data is a list, it must be "
+ "nonempty and contain either "
+ "numbers or dictionaries."
+ )
+
+ if not all(isinstance(element, Number) for element in data):
+ raise exceptions.PlotlyError(
+ "If data is a list, it must contain only numbers."
+ )
+
+ if pd and isinstance(data, pd.core.frame.DataFrame):
+ if data_header is None:
+ raise exceptions.PlotlyError(
+ "data_header must be the "
+ "column name with the "
+ "desired numeric data for "
+ "the violin plot."
+ )
+
+ data = data[data_header].values.tolist()
+
+ # call the plotting functions
+ plot_data, plot_xrange = violinplot(
+ data, fillcolor=valid_colors[0], rugplot=rugplot
+ )
+
+ layout = graph_objs.Layout(
+ title=title,
+ autosize=False,
+ font=graph_objs.layout.Font(size=11),
+ height=height,
+ showlegend=False,
+ width=width,
+ xaxis=make_XAxis("", plot_xrange),
+ yaxis=make_YAxis(""),
+ hovermode="closest",
+ )
+ layout["yaxis"].update(dict(showline=False, showticklabels=False, ticks=""))
+
+ fig = graph_objs.Figure(data=plot_data, layout=layout)
+
+ return fig
+
+ else:
+ if not isinstance(data, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError(
+ "Error. You must use a pandas "
+ "DataFrame if you are using a "
+ "group header."
+ )
+
+ if data_header is None:
+ raise exceptions.PlotlyError(
+ "data_header must be the column "
+ "name with the desired numeric "
+ "data for the violin plot."
+ )
+
+ if use_colorscale is False:
+ if isinstance(valid_colors, dict):
+ # validate colors dict choice below
+ fig = violin_dict(
+ data,
+ data_header,
+ group_header,
+ valid_colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+ )
+ return fig
+ else:
+ fig = violin_no_colorscale(
+ data,
+ data_header,
+ group_header,
+ valid_colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+ )
+ return fig
+ else:
+ if isinstance(valid_colors, dict):
+ raise exceptions.PlotlyError(
+ "The colors param cannot be "
+ "a dictionary if you are "
+ "using a colorscale."
+ )
+
+ if len(valid_colors) < 2:
+ raise exceptions.PlotlyError(
+ "colors must be a list with "
+ "at least 2 colors. A "
+ "Plotly scale is allowed."
+ )
+
+ if not isinstance(group_stats, dict):
+ raise exceptions.PlotlyError(
+ "Your group_stats param must be a dictionary."
+ )
+
+ fig = violin_colorscale(
+ data,
+ data_header,
+ group_header,
+ valid_colors,
+ use_colorscale,
+ group_stats,
+ rugplot,
+ sort,
+ height,
+ width,
+ title,
+ )
+ return fig
diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/utils.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/utils.py
new file mode 100644
index 0000000..e20a319
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/utils.py
@@ -0,0 +1,249 @@
+from collections.abc import Sequence
+
+from plotly import exceptions
+
+
+def is_sequence(obj):
+ return isinstance(obj, Sequence) and not isinstance(obj, str)
+
+
+def validate_index(index_vals):
+ """
+ Validates if a list contains all numbers or all strings
+
+ :raises: (PlotlyError) If there are any two items in the list whose
+ types differ
+ """
+ from numbers import Number
+
+ if isinstance(index_vals[0], Number):
+ if not all(isinstance(item, Number) for item in index_vals):
+ raise exceptions.PlotlyError(
+ "Error in indexing column. "
+ "Make sure all entries of each "
+ "column are all numbers or "
+ "all strings."
+ )
+
+ elif isinstance(index_vals[0], str):
+ if not all(isinstance(item, str) for item in index_vals):
+ raise exceptions.PlotlyError(
+ "Error in indexing column. "
+ "Make sure all entries of each "
+ "column are all numbers or "
+ "all strings."
+ )
+
+
+def validate_dataframe(array):
+ """
+ Validates all strings or numbers in each dataframe column
+
+ :raises: (PlotlyError) If there are any two items in any list whose
+ types differ
+ """
+ from numbers import Number
+
+ for vector in array:
+ if isinstance(vector[0], Number):
+ if not all(isinstance(item, Number) for item in vector):
+ raise exceptions.PlotlyError(
+ "Error in dataframe. "
+ "Make sure all entries of "
+ "each column are either "
+ "numbers or strings."
+ )
+ elif isinstance(vector[0], str):
+ if not all(isinstance(item, str) for item in vector):
+ raise exceptions.PlotlyError(
+ "Error in dataframe. "
+ "Make sure all entries of "
+ "each column are either "
+ "numbers or strings."
+ )
+
+
+def validate_equal_length(*args):
+ """
+ Validates that data lists or ndarrays are the same length.
+
+ :raises: (PlotlyError) If any data lists are not the same length.
+ """
+ length = len(args[0])
+ if any(len(lst) != length for lst in args):
+ raise exceptions.PlotlyError(
+ "Oops! Your data lists or ndarrays should be the same length."
+ )
+
+
+def validate_positive_scalars(**kwargs):
+ """
+ Validates that all values given in key/val pairs are positive.
+
+ Accepts kwargs to improve Exception messages.
+
+ :raises: (PlotlyError) If any value is < 0 or raises.
+ """
+ for key, val in kwargs.items():
+ try:
+ if val <= 0:
+ raise ValueError("{} must be > 0, got {}".format(key, val))
+ except TypeError:
+ raise exceptions.PlotlyError("{} must be a number, got {}".format(key, val))
+
+
+def flatten(array):
+ """
+ Uses list comprehension to flatten array
+
+ :param (array): An iterable to flatten
+ :raises (PlotlyError): If iterable is not nested.
+ :rtype (list): The flattened list.
+ """
+ try:
+ return [item for sublist in array for item in sublist]
+ except TypeError:
+ raise exceptions.PlotlyError(
+ "Your data array could not be "
+ "flattened! Make sure your data is "
+ "entered as lists or ndarrays!"
+ )
+
+
+def endpts_to_intervals(endpts):
+ """
+ Returns a list of intervals for categorical colormaps
+
+ Accepts a list or tuple of sequentially increasing numbers and returns
+ a list representation of the mathematical intervals with these numbers
+ as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]
+
+ :raises: (PlotlyError) If input is not a list or tuple
+ :raises: (PlotlyError) If the input contains a string
+ :raises: (PlotlyError) If any number does not increase after the
+ previous one in the sequence
+ """
+ length = len(endpts)
+ # Check if endpts is a list or tuple
+ if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))):
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument must "
+ "be a list or tuple of a sequence "
+ "of increasing numbers."
+ )
+ # Check if endpts contains only numbers
+ for item in endpts:
+ if isinstance(item, str):
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers."
+ )
+ # Check if numbers in endpts are increasing
+ for k in range(length - 1):
+ if endpts[k] >= endpts[k + 1]:
+ raise exceptions.PlotlyError(
+ "The intervals_endpts argument "
+ "must be a list or tuple of a "
+ "sequence of increasing "
+ "numbers."
+ )
+ else:
+ intervals = []
+ # add -inf to intervals
+ intervals.append([float("-inf"), endpts[0]])
+ for k in range(length - 1):
+ interval = []
+ interval.append(endpts[k])
+ interval.append(endpts[k + 1])
+ intervals.append(interval)
+ # add +inf to intervals
+ intervals.append([endpts[length - 1], float("inf")])
+ return intervals
+
+
+def annotation_dict_for_label(
+ text,
+ lane,
+ num_of_lanes,
+ subplot_spacing,
+ row_col="col",
+ flipped=True,
+ right_side=True,
+ text_color="#0f0f0f",
+):
+ """
+ Returns annotation dict for label of n labels of a 1xn or nx1 subplot.
+
+ :param (str) text: the text for a label.
+ :param (int) lane: the label number for text. From 1 to n inclusive.
+ :param (int) num_of_lanes: the number 'n' of rows or columns in subplot.
+ :param (float) subplot_spacing: the value for the horizontal_spacing and
+ vertical_spacing params in your plotly.tools.make_subplots() call.
+ :param (str) row_col: choose whether labels are placed along rows or
+ columns.
+ :param (bool) flipped: flips text by 90 degrees. Text is printed
+ horizontally if set to True and row_col='row', or if False and
+ row_col='col'.
+ :param (bool) right_side: only applicable if row_col is set to 'row'.
+ :param (str) text_color: color of the text.
+ """
+ temp = (1 - (num_of_lanes - 1) * subplot_spacing) / (num_of_lanes)
+ if not flipped:
+ xanchor = "center"
+ yanchor = "middle"
+ if row_col == "col":
+ x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
+ y = 1.03
+ textangle = 0
+ elif row_col == "row":
+ y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
+ x = 1.03
+ textangle = 90
+ else:
+ if row_col == "col":
+ xanchor = "center"
+ yanchor = "bottom"
+ x = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
+ y = 1.0
+ textangle = 270
+ elif row_col == "row":
+ yanchor = "middle"
+ y = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp
+ if right_side:
+ x = 1.0
+ xanchor = "left"
+ else:
+ x = -0.01
+ xanchor = "right"
+ textangle = 0
+
+ annotation_dict = dict(
+ textangle=textangle,
+ xanchor=xanchor,
+ yanchor=yanchor,
+ x=x,
+ y=y,
+ showarrow=False,
+ xref="paper",
+ yref="paper",
+ text=text,
+ font=dict(size=13, color=text_color),
+ )
+ return annotation_dict
+
+
+def list_of_options(iterable, conj="and", period=True):
+ """
+ Returns an English listing of objects seperated by commas ','
+
+ For example, ['foo', 'bar', 'baz'] becomes 'foo, bar and baz'
+ if the conjunction 'and' is selected.
+ """
+ if len(iterable) < 2:
+ raise exceptions.PlotlyError(
+ "Your list or tuple must contain at least 2 items."
+ )
+ template = (len(iterable) - 2) * "{}, " + "{} " + conj + " {}" + period * "."
+ return template.format(*iterable)