diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/figure_factory')
20 files changed, 10313 insertions, 0 deletions
def make_linear_colorscale(colors):
    """
    Makes a list of colors into a colorscale-acceptable form

    The colors are spread evenly over [0, 1]: the first color sits at 0 and
    the last one at 1. A single color produces a flat two-stop scale (the
    same color at 0 and at 1) instead of dividing by zero. An empty input
    produces an empty list.

    For documentation regarding to the form of the output, see
    https://plot.ly/python/reference/#mesh3d-colorscale

    :param (list|tuple) colors: the colors to place along the scale
    :rtype (list[list]): list of [position, color] pairs
    """
    colors = list(colors)
    if not colors:
        return []
    if len(colors) == 1:
        # create_2d_density documents a single rgb/hex color as a valid
        # colorscale; 1.0 / (len(colors) - 1) would raise ZeroDivisionError.
        return [[0.0, colors[0]], [1.0, colors[0]]]
    scale = 1.0 / (len(colors) - 1)
    return [[i * scale, color] for i, color in enumerate(colors)]


def create_2d_density(
    x,
    y,
    colorscale="Earth",
    ncontours=20,
    hist_color=(0, 0, 0.5),
    point_color=(0, 0, 0.5),
    point_size=2,
    title="2D Density Plot",
    height=600,
    width=600,
):
    """
    **deprecated**, use instead
    :func:`plotly.express.density_heatmap`.

    Builds a figure with four traces: a scatter of the points, a 2D
    histogram contour of their density, and one marginal histogram per axis.

    :param (list|array) x: x-axis data for plot generation
    :param (list|array) y: y-axis data for plot generation
    :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
        or hex color, a color tuple or a list or tuple of colors. An rgb
        color is of the form 'rgb(x, y, z)' where x, y, z belong to the
        interval [0, 255] and a color tuple is a tuple of the form
        (a, b, c) where a, b and c belong to [0, 1]. If colorscale is a
        list, it must contain the valid color types aforementioned as its
        members.
    :param (int) ncontours: the number of 2D contours to draw on the plot
    :param (str|tuple) hist_color: the color of the plotted histograms
    :param (str|tuple) point_color: the color of the scatter points
    :param (int|float) point_size: the size of the scatter points
    :param (str) title: set the title for the plot
    :param (float) height: the height of the chart
    :param (float) width: the width of the chart

    :raises: (PlotlyError) If x or y contain non-numeric elements, or if
        x and y do not have the same length.

    Examples
    --------

    Example 1: Simple 2D Density Plot

    >>> from plotly.figure_factory import create_2d_density
    >>> import numpy as np

    >>> # Make data points
    >>> t = np.linspace(-1,1.2,2000)
    >>> x = (t**3)+(0.3*np.random.randn(2000))
    >>> y = (t**6)+(0.3*np.random.randn(2000))

    >>> # Create a figure
    >>> fig = create_2d_density(x, y)

    >>> # Plot the data
    >>> fig.show()

    Example 2: Using Parameters

    >>> from plotly.figure_factory import create_2d_density

    >>> import numpy as np

    >>> # Make data points
    >>> t = np.linspace(-1,1.2,2000)
    >>> x = (t**3)+(0.3*np.random.randn(2000))
    >>> y = (t**6)+(0.3*np.random.randn(2000))

    >>> # Create custom colorscale
    >>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
    ...               (1, 1, 0.2), (0.98,0.98,0.98)]

    >>> # Create a figure
    >>> fig = create_2d_density(x, y, colorscale=colorscale,
    ...       hist_color='rgb(255, 237, 222)', point_size=3)

    >>> # Plot the data
    >>> fig.show()
    """

    # validate x and y are filled with numbers only
    for array in [x, y]:
        if not all(isinstance(element, Number) for element in array):
            raise plotly.exceptions.PlotlyError(
                "All elements of your 'x' and 'y' lists must be numbers."
            )

    # validate x and y are the same length
    if len(x) != len(y):
        raise plotly.exceptions.PlotlyError(
            "Both lists 'x' and 'y' must be the same length."
        )

    colorscale = clrs.validate_colors(colorscale, "rgb")
    colorscale = make_linear_colorscale(colorscale)

    # validate hist_color and point_color
    hist_color = clrs.validate_colors(hist_color, "rgb")
    point_color = clrs.validate_colors(point_color, "rgb")

    trace1 = graph_objs.Scatter(
        x=x,
        y=y,
        mode="markers",
        name="points",
        marker=dict(color=point_color[0], size=point_size, opacity=0.4),
    )
    trace2 = graph_objs.Histogram2dContour(
        x=x,
        y=y,
        name="density",
        ncontours=ncontours,
        colorscale=colorscale,
        reversescale=True,
        showscale=False,
    )
    # The marginal histograms live on secondary axes (x2 / y2) that occupy
    # the outer 15% of the figure, see the domains in the layout below.
    trace3 = graph_objs.Histogram(
        x=x, name="x density", marker=dict(color=hist_color[0]), yaxis="y2"
    )
    trace4 = graph_objs.Histogram(
        y=y, name="y density", marker=dict(color=hist_color[0]), xaxis="x2"
    )
    data = [trace1, trace2, trace3, trace4]

    layout = graph_objs.Layout(
        showlegend=False,
        autosize=False,
        title=title,
        height=height,
        width=width,
        xaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
        yaxis=dict(domain=[0, 0.85], showgrid=False, zeroline=False),
        margin=dict(t=50),
        hovermode="closest",
        bargap=0,
        xaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
        yaxis2=dict(domain=[0.85, 1], showgrid=False, zeroline=False),
    )

    fig = graph_objs.Figure(data=data, layout=layout)
    return fig
plotly.figure_factory._annotated_heatmap import create_annotated_heatmap +from plotly.figure_factory._bullet import create_bullet +from plotly.figure_factory._candlestick import create_candlestick +from plotly.figure_factory._dendrogram import create_dendrogram +from plotly.figure_factory._distplot import create_distplot +from plotly.figure_factory._facet_grid import create_facet_grid +from plotly.figure_factory._gantt import create_gantt +from plotly.figure_factory._ohlc import create_ohlc +from plotly.figure_factory._quiver import create_quiver +from plotly.figure_factory._scatterplot import create_scatterplotmatrix +from plotly.figure_factory._streamline import create_streamline +from plotly.figure_factory._table import create_table +from plotly.figure_factory._trisurf import create_trisurf +from plotly.figure_factory._violin import create_violin + +if optional_imports.get_module("pandas") is not None: + from plotly.figure_factory._county_choropleth import create_choropleth + from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox +else: + + def create_choropleth(*args, **kwargs): + raise ImportError("Please install pandas to use `create_choropleth`") + + def create_hexbin_mapbox(*args, **kwargs): + raise ImportError("Please install pandas to use `create_hexbin_mapbox`") + + +if optional_imports.get_module("skimage") is not None: + from plotly.figure_factory._ternary_contour import create_ternary_contour +else: + + def create_ternary_contour(*args, **kwargs): + raise ImportError("Please install scikit-image to use `create_ternary_contour`") + + +__all__ = [ + "create_2d_density", + "create_annotated_heatmap", + "create_bullet", + "create_candlestick", + "create_choropleth", + "create_dendrogram", + "create_distplot", + "create_facet_grid", + "create_gantt", + "create_hexbin_mapbox", + "create_ohlc", + "create_quiver", + "create_scatterplotmatrix", + "create_streamline", + "create_table", + "create_ternary_contour", + "create_trisurf", + 
def _has_labels(labels):
    """
    Tell whether an optional axis-label sequence was actually supplied.

    Plain truth-testing (``if labels:``) raises ValueError for numpy arrays
    holding more than one element, so the length is inspected instead.

    :param labels: None, or the x / y labels passed by the caller
    :rtype (bool): True when labels is a non-empty sized object
    """
    if labels is None:
        return False
    try:
        return len(labels) > 0
    except TypeError:
        # Unsized objects keep their ordinary truthiness.
        return bool(labels)


def validate_annotated_heatmap(z, x, y, annotation_text):
    """
    Annotated-heatmap-specific validations

    Check that if a text matrix is supplied, it has the same
    dimensions as the z matrix. Also check that x (if supplied) matches
    the width of z and that y (if supplied) matches its length.

    See FigureFactory.create_annotated_heatmap() for params

    :raises: (PlotlyError) If z and text matrices do not have the same
        dimensions, or if x / y do not match the shape of z.
    """
    if annotation_text is not None and isinstance(annotation_text, list):
        utils.validate_equal_length(z, annotation_text)
        for row_index, z_row in enumerate(z):
            if len(z_row) != len(annotation_text[row_index]):
                raise exceptions.PlotlyError(
                    "z and text should have the same dimensions"
                )

    if _has_labels(x):
        if len(x) != len(z[0]):
            raise exceptions.PlotlyError(
                "oops, the x list that you "
                "provided does not match the "
                "width of your z matrix "
            )

    if _has_labels(y):
        if len(y) != len(z):
            raise exceptions.PlotlyError(
                "oops, the y list that you "
                "provided does not match the "
                "length of your z matrix "
            )


def create_annotated_heatmap(
    z,
    x=None,
    y=None,
    annotation_text=None,
    colorscale="Plasma",
    font_colors=None,
    showscale=False,
    reversescale=False,
    **kwargs,
):
    """
    **deprecated**, use instead
    :func:`plotly.express.imshow`.

    Function that creates annotated heatmaps

    This function adds annotations to each cell of the heatmap.

    :param (list[list]|ndarray) z: z matrix to create heatmap.
    :param (list|ndarray) x: x axis labels.
    :param (list|ndarray) y: y axis labels.
    :param (list[list]|ndarray) annotation_text: Text strings for
        annotations. Should have the same dimensions as the z matrix. If no
        text is added, the values of the z matrix are annotated. Default =
        z matrix values.
    :param (list|str) colorscale: heatmap colorscale.
    :param (list) font_colors: List of two color strings: [min_text_color,
        max_text_color] where min_text_color is applied to annotations for
        heatmap values below the midpoint (zmin + zmax)/2 (or below the
        'zmid' kwarg when it is given) and max_text_color to the others.
        If font_colors is not defined, the colors are defined logically as
        black or white depending on the heatmap's colorscale.
    :param (bool) showscale: Display colorscale. Default = False
    :param (bool) reversescale: Reverse colorscale. Default = False
    :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
        These kwargs describe other attributes about the annotated Heatmap
        trace such as the colorscale. For more information on valid kwargs
        call help(plotly.graph_objs.Heatmap)

    :raises: (PlotlyError) see validate_annotated_heatmap()

    Example 1: Simple annotated heatmap with default configuration

    >>> import plotly.figure_factory as ff

    >>> z = [[0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.4300],
    ...      [0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.00000]]

    >>> fig = ff.create_annotated_heatmap(z)
    >>> fig.show()
    """

    # Avoiding mutables in the call signature
    font_colors = font_colors if font_colors is not None else []
    validate_annotated_heatmap(z, x, y, annotation_text)

    # validate colorscale
    colorscale_validator = ValidatorCache.get_validator("heatmap", "colorscale")
    colorscale = colorscale_validator.validate_coerce(colorscale)

    annotations = _AnnotatedHeatmap(
        z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
    ).make_annotations()

    if _has_labels(x) or _has_labels(y):
        # Labelled axes: forward the labels and show one tick per cell.
        trace = dict(
            type="heatmap",
            z=z,
            x=x,
            y=y,
            colorscale=colorscale,
            showscale=showscale,
            reversescale=reversescale,
            **kwargs,
        )
        layout = dict(
            annotations=annotations,
            xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
            yaxis=dict(ticks="", dtick=1, ticksuffix="  "),
        )
    else:
        # Unlabelled axes: cells are addressed by index and ticks are hidden.
        trace = dict(
            type="heatmap",
            z=z,
            colorscale=colorscale,
            showscale=showscale,
            reversescale=reversescale,
            **kwargs,
        )
        layout = dict(
            annotations=annotations,
            xaxis=dict(
                ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
            ),
            yaxis=dict(ticks="", ticksuffix="  ", showticklabels=False),
        )

    data = [trace]

    return graph_objs.Figure(data=data, layout=layout)


def to_rgb_color_list(color_str, default):
    """
    Convert a color string into its red, green and blue components.

    :param (str) color_str: 'rgb(r, g, b)', 'rgba(r, g, b, a)' or '#hex'
        color string; surrounding whitespace is ignored.
    :param default: value returned when color_str has none of these forms.
    :rtype: list of three ints for rgb/rgba strings (the alpha channel is
        dropped and fractional components are truncated), the result of
        plotly.colors.hex_to_rgb for hex strings, otherwise default.
    """
    color_str = color_str.strip()
    if color_str.startswith("rgb"):
        components = color_str.strip("rgba()").split(",")
        # Only r, g, b feed the luminance formula. Going through float()
        # keeps 'rgba(..., 0.5)' and 'rgb(12.5, ...)' from raising in int().
        return [int(float(v)) for v in components[:3]]
    elif color_str.startswith("#"):
        return clrs.hex_to_rgb(color_str)
    else:
        return default


def should_use_black_text(background_color):
    """
    Decide whether black text is readable on the given background.

    :param background_color: sequence whose first three items are the red,
        green and blue components of the background.
    :rtype (bool): True when the weighted sum
        0.299*r + 0.587*g + 0.114*b exceeds 186.
    """
    return (
        background_color[0] * 0.299
        + background_color[1] * 0.587
        + background_color[2] * 0.114
    ) > 186


class _AnnotatedHeatmap(object):
    """
    Refer to TraceFactory.create_annotated_heatmap() for docstring
    """

    def __init__(
        self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
    ):
        self.z = z
        # Missing labels fall back to the cell indices.
        if _has_labels(x):
            self.x = x
        else:
            self.x = range(len(z[0]))
        if _has_labels(y):
            self.y = y
        else:
            self.y = range(len(z))
        if annotation_text is not None:
            self.annotation_text = annotation_text
        else:
            self.annotation_text = self.z
        self.colorscale = colorscale
        self.reversescale = reversescale
        self.font_colors = font_colors

        if np and isinstance(self.z, np.ndarray):
            self.zmin = np.amin(self.z)
            self.zmax = np.amax(self.z)
        else:
            self.zmin = min(v for row in self.z for v in row)
            self.zmax = max(v for row in self.z for v in row)

        # Explicit zmin / zmax / zmid kwargs take precedence over the data.
        if kwargs.get("zmin", None) is not None:
            self.zmin = kwargs["zmin"]
        if kwargs.get("zmax", None) is not None:
            self.zmax = kwargs["zmax"]

        self.zmid = (self.zmax + self.zmin) / 2

        if kwargs.get("zmid", None) is not None:
            self.zmid = kwargs["zmid"]

    def get_text_color(self):
        """
        Get font color for annotations.

        The annotated heatmap can feature two text colors: min_text_color and
        max_text_color. The min_text_color is applied to annotations for
        heatmap values below self.zmid, max_text_color to the others (see
        make_annotations). The user can define these two colors. Otherwise
        the colors are defined logically as black or white depending on the
        heatmap's colorscale.

        :rtype (string, string) min_text_color, max_text_color: text
            color for annotations for heatmap values < zmid and text color
            for annotations for heatmap values >= zmid
        """
        # Plotly colorscales ranging from a lighter shade to a darker shade
        # NOTE(review): "YIGnBu" / "YIOrRd" are spelled with a capital I here,
        # whereas Plotly names these scales "YlGnBu" / "YlOrRd" - confirm
        # whether these two entries can ever match.
        colorscales = [
            "Greys",
            "Greens",
            "Blues",
            "YIGnBu",
            "YIOrRd",
            "RdBu",
            "Picnic",
            "Jet",
            "Hot",
            "Blackbody",
            "Earth",
            "Electric",
            "Viridis",
            "Cividis",
        ]
        # Plotly colorscales ranging from a darker shade to a lighter shade
        colorscales_reverse = ["Reds"]

        white = "#FFFFFF"
        black = "#000000"
        if self.font_colors:
            min_text_color = self.font_colors[0]
            max_text_color = self.font_colors[-1]
        elif self.colorscale in colorscales and self.reversescale:
            min_text_color = black
            max_text_color = white
        elif self.colorscale in colorscales:
            min_text_color = white
            max_text_color = black
        elif self.colorscale in colorscales_reverse and self.reversescale:
            min_text_color = white
            max_text_color = black
        elif self.colorscale in colorscales_reverse:
            min_text_color = black
            max_text_color = white
        elif isinstance(self.colorscale, list):
            # Explicit colorscale: pick the text color from the luminance of
            # its first and last colors.
            min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
            max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])

            # swap min/max colors if reverse scale
            if self.reversescale:
                min_col, max_col = max_col, min_col

            if should_use_black_text(min_col):
                min_text_color = black
            else:
                min_text_color = white

            if should_use_black_text(max_col):
                max_text_color = black
            else:
                max_text_color = white
        else:
            min_text_color = black
            max_text_color = black
        return min_text_color, max_text_color

    def make_annotations(self):
        """
        Get annotations for each cell of the heatmap with graph_objs.Annotation

        :rtype (list[dict]) annotations: list of annotations for each cell of
            the heatmap
        """
        min_text_color, max_text_color = self.get_text_color()
        annotations = []
        for n, row in enumerate(self.z):
            for m, val in enumerate(row):
                font_color = min_text_color if val < self.zmid else max_text_color
                annotations.append(
                    graph_objs.layout.Annotation(
                        text=str(self.annotation_text[n][m]),
                        x=self.x[m],
                        y=self.y[n],
                        xref="x1",
                        yref="y1",
                        font=dict(color=font_color),
                        showarrow=False,
                    )
                )
        return annotations
length_axis = "yaxis" + + for key in fig["layout"]: + if "xaxis" in key or "yaxis" in key: + fig["layout"][key]["showgrid"] = False + fig["layout"][key]["zeroline"] = False + if length_axis in key: + fig["layout"][key]["tickwidth"] = 1 + if width_axis in key: + fig["layout"][key]["showticklabels"] = False + fig["layout"][key]["range"] = [0, 1] + + # narrow domain if 1 bar + if num_of_lanes <= 1: + fig["layout"][width_axis + "1"]["domain"] = [0.4, 0.6] + + if not range_colors: + range_colors = ["rgb(200, 200, 200)", "rgb(245, 245, 245)"] + if not measure_colors: + measure_colors = ["rgb(31, 119, 180)", "rgb(176, 196, 221)"] + + for row in range(num_of_lanes): + # ranges bars + for idx in range(len(df.iloc[row]["ranges"])): + inter_colors = clrs.n_colors( + range_colors[0], range_colors[1], len(df.iloc[row]["ranges"]), "rgb" + ) + x = ( + [sorted(df.iloc[row]["ranges"])[-1 - idx]] + if orientation == "h" + else [0] + ) + y = ( + [0] + if orientation == "h" + else [sorted(df.iloc[row]["ranges"])[-1 - idx]] + ) + bar = go.Bar( + x=x, + y=y, + marker=dict(color=inter_colors[-1 - idx]), + name="ranges", + hoverinfo="x" if orientation == "h" else "y", + orientation=orientation, + width=2, + base=0, + xaxis="x{}".format(row + 1), + yaxis="y{}".format(row + 1), + ) + fig.add_trace(bar) + + # measures bars + for idx in range(len(df.iloc[row]["measures"])): + inter_colors = clrs.n_colors( + measure_colors[0], + measure_colors[1], + len(df.iloc[row]["measures"]), + "rgb", + ) + x = ( + [sorted(df.iloc[row]["measures"])[-1 - idx]] + if orientation == "h" + else [0.5] + ) + y = ( + [0.5] + if orientation == "h" + else [sorted(df.iloc[row]["measures"])[-1 - idx]] + ) + bar = go.Bar( + x=x, + y=y, + marker=dict(color=inter_colors[-1 - idx]), + name="measures", + hoverinfo="x" if orientation == "h" else "y", + orientation=orientation, + width=0.4, + base=0, + xaxis="x{}".format(row + 1), + yaxis="y{}".format(row + 1), + ) + fig.add_trace(bar) + + # markers + x = 
def _is_missing(value):
    """Return True when a cell holds None or a NaN number."""
    if value is None:
        return True
    try:
        return math.isnan(value)
    except (TypeError, ValueError):
        # Lists, strings and other non-numeric cells are real data.
        return False


def _blank_missing(values):
    """Return a list copy of values with every None / NaN cell replaced by []."""
    return [[] if _is_missing(value) else value for value in values]


def _resolve_scatter_options(scatter_options):
    """
    Merge user scatter options with the default marker, without mutating input.

    Returns a new dict whose "marker" entry contains every default marker
    attribute the caller did not set.
    """
    default_marker = {"size": 12, "symbol": "diamond-tall", "color": "rgb(0, 0, 0)"}

    resolved = dict(scatter_options) if scatter_options else {}
    marker = resolved.get("marker")
    if marker is None:
        marker = {}
    elif isinstance(marker, dict):
        # Copy so the caller's nested dict is left untouched.
        marker = dict(marker)
    # Any other marker object is filled in place, as the original code did.
    for key, value in default_marker.items():
        if key not in marker:
            marker[key] = value
    resolved["marker"] = marker
    return resolved


def create_bullet(
    data,
    markers=None,
    measures=None,
    ranges=None,
    subtitles=None,
    titles=None,
    orientation="h",
    range_colors=("rgb(200, 200, 200)", "rgb(245, 245, 245)"),
    measure_colors=("rgb(31, 119, 180)", "rgb(176, 196, 221)"),
    horizontal_spacing=None,
    vertical_spacing=None,
    scatter_options=None,
    **layout_options,
):
    """
    **deprecated**, use instead the plotly.graph_objects trace
    :class:`plotly.graph_objects.Indicator`.

    :param (pd.DataFrame | list | tuple) data: either a list/tuple of
        dictionaries or a pandas DataFrame.
    :param (str) markers: the column name or dictionary key for the markers in
        each subplot.
    :param (str) measures: the column name or dictionary key for the measure
        bars in each subplot. This bar usually represents the quantitative
        measure of performance, usually a list of two values [a, b] and are
        the blue bars in the foreground of each subplot by default.
    :param (str) ranges: the column name or dictionary key for the qualitative
        ranges of performance, usually a 3-item list [bad, okay, good]. They
        correspond to the grey bars in the background of each chart.
    :param (str) subtitles: the column name or dictionary key for the subtitle
        of each subplot chart. The subtitles are displayed right underneath
        each title.
    :param (str) titles: the column name or dictionary key for the main label
        of each subplot chart.
    :param (str) orientation: if 'h', the bars are placed horizontally as
        rows. If 'v' the bars are placed vertically in the chart.
    :param (list) range_colors: a tuple of two colors between which all
        the rectangles for the range are drawn. These rectangles are meant to
        be qualitative indicators against which the marker and measure bars
        are compared.
        Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
    :param (list) measure_colors: a tuple of two colors which is used to color
        the thin quantitative bars in the bullet chart.
        Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
    :param (float) horizontal_spacing: see the 'horizontal_spacing' param in
        plotly.subplots.make_subplots. Ranges between 0 and 1.
    :param (float) vertical_spacing: see the 'vertical_spacing' param in
        plotly.subplots.make_subplots. Ranges between 0 and 1.
    :param (dict) scatter_options: describes attributes for the scatter trace
        in each subplot such as name and marker size. Marker attributes that
        are not given default to size 12, symbol 'diamond-tall' and color
        'rgb(0, 0, 0)'. The dict passed in is not modified. Call
        help(plotly.graph_objs.Scatter) for more information on valid params.
    :param layout_options: describes attributes for the layout of the figure
        such as title, height and width. Call help(plotly.graph_objs.Layout)
        for more information on valid params.

    :raises: (ImportError) If pandas is not installed.
    :raises: (PlotlyError) If data is neither a DataFrame nor a sequence of
        dictionaries, or if a color argument does not hold exactly two colors.

    Example 1: Use a Dictionary

    >>> import plotly.figure_factory as ff

    >>> data = [
    ...   {"label": "revenue", "sublabel": "us$, in thousands",
    ...    "range": [150, 225, 300], "performance": [220,270], "point": [250]},
    ...   {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
    ...    "performance": [21, 23], "point": [26]},
    ...   {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
    ...    "performance": [100,320],"point": [550]},
    ...   {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
    ...    "performance": [1000, 1650],"point": [2100]},
    ...   {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
    ...    "performance": [3.2, 4.7], "point": [4.4]}
    ... ]

    >>> fig = ff.create_bullet(
    ...     data, titles='label', subtitles='sublabel', markers='point',
    ...     measures='performance', ranges='range', orientation='h',
    ...     title='my simple bullet chart'
    ... )
    >>> fig.show()

    Example 2: Use a DataFrame with Custom Colors

    >>> import plotly.figure_factory as ff
    >>> import pandas as pd
    >>> data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')

    >>> fig = ff.create_bullet(
    ...     data, titles='title', markers='markers', measures='measures',
    ...     orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
    ...     scatter_options={'marker': {'symbol': 'circle'}}, width=700)
    >>> fig.show()
    """
    # validate df
    if not pd:
        raise ImportError("'pandas' must be installed for this figure factory.")

    data_is_sequence = utils.is_sequence(data)
    if data_is_sequence:
        if not all(isinstance(item, dict) for item in data):
            raise exceptions.PlotlyError(
                "Every entry of the data argument list, tuple, etc must "
                "be a dictionary."
            )

    elif not isinstance(data, pd.DataFrame):
        raise exceptions.PlotlyError(
            "You must input a pandas DataFrame, or a list of dictionaries."
        )

    # make DataFrame from data with correct column headers
    num_of_lanes = len(data)

    def column(key):
        # One value per lane, read from the dicts or from the DataFrame column.
        if data_is_sequence:
            return [d[key] for d in data]
        return data[key].tolist()

    columns = {
        "titles": column(titles) if titles else [""] * num_of_lanes,
        "markers": column(markers) if markers else [[]] * num_of_lanes,
        "measures": column(measures) if measures else [[]] * num_of_lanes,
        "ranges": column(ranges) if ranges else [[]] * num_of_lanes,
    }
    if subtitles:
        # _bullet looks the subtitle up under the "subtitles" column. The
        # column used to be named "subtitle", so subtitles never showed.
        # It is only added on request so labels without a subtitle stay
        # exactly as before.
        columns["subtitles"] = column(subtitles)

    # Built row-wise and transposed so list-valued cells stay whole objects.
    df = pd.DataFrame(list(columns.values()), index=list(columns)).transpose()

    # make sure ranges, measures, 'markers' are not NAN or NONE
    for needed_key in ["ranges", "measures", "markers"]:
        # Replace the whole column instead of chained item assignment, which
        # may write to a temporary copy.
        df[needed_key] = pd.Series(
            _blank_missing(df[needed_key]), index=df.index, dtype=object
        )

    # validate custom colors
    for colors_list in [range_colors, measure_colors]:
        if colors_list:
            if len(colors_list) != 2:
                raise exceptions.PlotlyError(
                    "Both 'range_colors' or 'measure_colors' must be a list "
                    "of two valid colors."
                )
            clrs.validate_colors(colors_list)
            # NOTE(review): the original bound element [0] of this result to
            # the loop variable and discarded it, so `_bullet` receives the
            # colors exactly as the caller passed them. Presumably the intent
            # was to normalise them to 'rgb(...)' strings - confirm before
            # propagating the converted values.
            clrs.convert_colors_to_same_type(colors_list, "rgb")

    # default scatter options, merged into a fresh dict on every call
    scatter_options = _resolve_scatter_options(scatter_options)

    fig = _bullet(
        df,
        markers,
        measures,
        ranges,
        subtitles,
        titles,
        orientation,
        range_colors,
        measure_colors,
        horizontal_spacing,
        vertical_spacing,
        scatter_options,
        layout_options,
    )

    return fig
Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_incr_data: list of the box trace for + increasing candlesticks. + """ + increase_x, increase_y = _Candlestick( + open, high, low, close, dates, **kwargs + ).get_candle_increase() + + if "line" in kwargs: + kwargs.setdefault("fillcolor", kwargs["line"]["color"]) + else: + kwargs.setdefault("fillcolor", _DEFAULT_INCREASING_COLOR) + if "name" in kwargs: + kwargs.setdefault("showlegend", True) + else: + kwargs.setdefault("showlegend", False) + kwargs.setdefault("name", "Increasing") + kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR)) + + candle_incr_data = dict( + type="box", + x=increase_x, + y=increase_y, + whiskerwidth=0, + boxpoints=False, + **kwargs, + ) + + return [candle_incr_data] + + +def make_decreasing_candle(open, high, low, close, dates, **kwargs): + """ + Makes boxplot trace for decreasing candlesticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to decreasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_decr_data: list of the box trace for + decreasing candlesticks. 
+ """ + + decrease_x, decrease_y = _Candlestick( + open, high, low, close, dates, **kwargs + ).get_candle_decrease() + + if "line" in kwargs: + kwargs.setdefault("fillcolor", kwargs["line"]["color"]) + else: + kwargs.setdefault("fillcolor", _DEFAULT_DECREASING_COLOR) + kwargs.setdefault("showlegend", False) + kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR)) + kwargs.setdefault("name", "Decreasing") + + candle_decr_data = dict( + type="box", + x=decrease_x, + y=decrease_y, + whiskerwidth=0, + boxpoints=False, + **kwargs, + ) + + return [candle_decr_data] + + +def create_candlestick(open, high, low, close, dates=None, direction="both", **kwargs): + """ + **deprecated**, use instead the plotly.graph_objects trace + :class:`plotly.graph_objects.Candlestick` + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all candlesticks where the close value is greater than + the corresponding open value, and when the direction is + 'decreasing', the returned figure consists of all candlesticks + where the close value is less than or equal to the corresponding + open value. When the direction is 'both', both increasing and + decreasing candlesticks are returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of candlestick chart figure. 
+ + Example 1: Simple candlestick chart from a Pandas DataFrame + + >>> from plotly.figure_factory import create_candlestick + >>> from datetime import datetime + >>> import pandas as pd + + >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv') + >>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], + ... dates=df.index) + >>> fig.show() + + Example 2: Customize the candlestick colors + + >>> from plotly.figure_factory import create_candlestick + >>> from plotly.graph_objs import Line, Marker + >>> from datetime import datetime + + >>> import pandas as pd + >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv') + + >>> # Make increasing candlesticks and customize their color and name + >>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], + ... dates=df.index, + ... direction='increasing', name='AAPL', + ... marker=Marker(color='rgb(150, 200, 250)'), + ... line=Line(color='rgb(150, 200, 250)')) + + >>> # Make decreasing candlesticks and customize their color and name + >>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], + ... dates=df.index, + ... direction='decreasing', + ... marker=Marker(color='rgb(128, 128, 128)'), + ... 
class _Candlestick(object):
    """
    Split OHLC data into increasing and decreasing candlestick points.

    Refer to FigureFactory.create_candlestick() for docstring.

    Each selected candle contributes six y values
    (low, open, close, close, close, high) and six copies of its x value,
    which is the layout consumed by the box traces built from this data.
    """

    def __init__(self, open, high, low, close, dates, **kwargs):
        """
        Store the OHLC sequences and derive the x values.

        :param (list) open: opening values
        :param (list) high: high values
        :param (list) low: low values
        :param (list) close: closing values
        :param (list) dates: x values for the candles; when None the
            positions 0..len(open)-1 are used instead
        :param kwargs: accepted for signature compatibility and ignored here
        """
        self.open = open
        self.high = high
        self.low = low
        self.close = close
        if dates is not None:
            self.x = dates
        else:
            self.x = list(range(len(self.open)))
        # NOTE: the constructor used to call get_candle_increase() and throw
        # the result away; that extra pass over the data has been removed.

    def _candle_points(self, increasing):
        """
        Collect the x/y points for one direction.

        :param (bool) increasing: True selects candles with close > open,
            False selects candles with close <= open
        :rtype (tuple): (x_values, y_values), two flat lists of equal length
        """
        x_values = []
        y_values = []
        # Index-based access is kept on purpose: the inputs are subscripted
        # exactly as before, so whatever indexing semantics the caller's
        # containers have is preserved.
        for index in range(len(self.open)):
            open_value = self.open[index]
            close_value = self.close[index]
            # Two explicit comparisons (not a negation) so that values that
            # compare False both ways, such as NaN, stay in neither group.
            if increasing:
                selected = close_value > open_value
            else:
                selected = close_value <= open_value
            if not selected:
                continue
            y_values.extend(
                [
                    self.low[index],
                    open_value,
                    close_value,
                    close_value,
                    close_value,
                    self.high[index],
                ]
            )
            x_values.extend([self.x[index]] * 6)
        return x_values, y_values

    def get_candle_increase(self):
        """
        Separate increasing data from decreasing data.

        The data is increasing when close value > open value
        and decreasing when the close value <= open value.

        :rtype (tuple): (increase_x, increase_y) for the increasing candles
        """
        return self._candle_points(True)

    def get_candle_decrease(self):
        """
        Separate increasing data from decreasing data.

        The data is increasing when close value > open value
        and decreasing when the close value <= open value.

        :rtype (tuple): (decrease_x, decrease_y) for the decreasing candles
        """
        return self._candle_points(False)
pd.to_numeric(df_shape_pre2010["FIPS"]) + + states_path = "cb_2016_us_state_500k.shp" + states_path = os.path.join(abs_package_data_dir_path, states_path) + + df_state = gp.read_file(states_path) + df_state = df_state[["STATEFP", "NAME", "geometry"]] + df_state = df_state.rename(columns={"NAME": "STATE_NAME"}) + + filenames = [ + "cb_2016_us_county_500k.dbf", + "cb_2016_us_county_500k.shp", + "cb_2016_us_county_500k.shx", + ] + + for j in range(len(filenames)): + filenames[j] = os.path.join(abs_package_data_dir_path, filenames[j]) + + dbf = io.open(filenames[0], "rb") + shp = io.open(filenames[1], "rb") + shx = io.open(filenames[2], "rb") + + r = shapefile.Reader(shp=shp, shx=shx, dbf=dbf) + + attributes, geometry = [], [] + field_names = [field[0] for field in r.fields[1:]] + for row in r.shapeRecords(): + geometry.append(shapely.geometry.shape(row.shape.__geo_interface__)) + attributes.append(dict(zip(field_names, row.record))) + + gdf = gp.GeoDataFrame(data=attributes, geometry=geometry) + + gdf["FIPS"] = gdf["STATEFP"] + gdf["COUNTYFP"] + gdf["FIPS"] = pd.to_numeric(gdf["FIPS"]) + + # add missing counties + f = 46113 + singlerow = pd.DataFrame( + [ + [ + st_to_state_name_dict["SD"], + "SD", + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0], + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0], + "46", + "Shannon", + ] + ], + columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"], + index=[max(gdf.index) + 1], + ) + gdf = pd.concat([gdf, singlerow], sort=True) + + f = 51515 + singlerow = pd.DataFrame( + [ + [ + st_to_state_name_dict["VA"], + "VA", + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0], + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0], + "51", + "Bedford City", + ] + ], + columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"], + index=[max(gdf.index) + 1], + ) + gdf = pd.concat([gdf, singlerow], sort=True) + + f = 2270 + singlerow = pd.DataFrame( + [ + [ + 
st_to_state_name_dict["AK"], + "AK", + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["geometry"].iloc[0], + df_shape_pre2010[df_shape_pre2010["FIPS"] == f]["FIPS"].iloc[0], + "02", + "Wade Hampton", + ] + ], + columns=["State", "ST", "geometry", "FIPS", "STATEFP", "NAME"], + index=[max(gdf.index) + 1], + ) + gdf = pd.concat([gdf, singlerow], sort=True) + + row_2198 = gdf[gdf["FIPS"] == 2198] + row_2198.index = [max(gdf.index) + 1] + row_2198.loc[row_2198.index[0], "FIPS"] = 2201 + row_2198.loc[row_2198.index[0], "STATEFP"] = "02" + gdf = pd.concat([gdf, row_2198], sort=True) + + row_2105 = gdf[gdf["FIPS"] == 2105] + row_2105.index = [max(gdf.index) + 1] + row_2105.loc[row_2105.index[0], "FIPS"] = 2232 + row_2105.loc[row_2105.index[0], "STATEFP"] = "02" + gdf = pd.concat([gdf, row_2105], sort=True) + gdf = gdf.rename(columns={"NAME": "COUNTY_NAME"}) + + gdf_reduced = gdf[["FIPS", "STATEFP", "COUNTY_NAME", "geometry"]] + gdf_statefp = gdf_reduced.merge(df_state[["STATEFP", "STATE_NAME"]], on="STATEFP") + + ST = [] + for n in gdf_statefp["STATE_NAME"]: + ST.append(state_to_st_dict[n]) + + gdf_statefp["ST"] = ST + return gdf_statefp, df_state + + +st_to_state_name_dict = { + "AK": "Alaska", + "AL": "Alabama", + "AR": "Arkansas", + "AZ": "Arizona", + "CA": "California", + "CO": "Colorado", + "CT": "Connecticut", + "DC": "District of Columbia", + "DE": "Delaware", + "FL": "Florida", + "GA": "Georgia", + "HI": "Hawaii", + "IA": "Iowa", + "ID": "Idaho", + "IL": "Illinois", + "IN": "Indiana", + "KS": "Kansas", + "KY": "Kentucky", + "LA": "Louisiana", + "MA": "Massachusetts", + "MD": "Maryland", + "ME": "Maine", + "MI": "Michigan", + "MN": "Minnesota", + "MO": "Missouri", + "MS": "Mississippi", + "MT": "Montana", + "NC": "North Carolina", + "ND": "North Dakota", + "NE": "Nebraska", + "NH": "New Hampshire", + "NJ": "New Jersey", + "NM": "New Mexico", + "NV": "Nevada", + "NY": "New York", + "OH": "Ohio", + "OK": "Oklahoma", + "OR": "Oregon", + "PA": "Pennsylvania", + "RI": 
"Rhode Island", + "SC": "South Carolina", + "SD": "South Dakota", + "TN": "Tennessee", + "TX": "Texas", + "UT": "Utah", + "VA": "Virginia", + "VT": "Vermont", + "WA": "Washington", + "WI": "Wisconsin", + "WV": "West Virginia", + "WY": "Wyoming", +} + +state_to_st_dict = { + "Alabama": "AL", + "Alaska": "AK", + "American Samoa": "AS", + "Arizona": "AZ", + "Arkansas": "AR", + "California": "CA", + "Colorado": "CO", + "Commonwealth of the Northern Mariana Islands": "MP", + "Connecticut": "CT", + "Delaware": "DE", + "District of Columbia": "DC", + "Florida": "FL", + "Georgia": "GA", + "Guam": "GU", + "Hawaii": "HI", + "Idaho": "ID", + "Illinois": "IL", + "Indiana": "IN", + "Iowa": "IA", + "Kansas": "KS", + "Kentucky": "KY", + "Louisiana": "LA", + "Maine": "ME", + "Maryland": "MD", + "Massachusetts": "MA", + "Michigan": "MI", + "Minnesota": "MN", + "Mississippi": "MS", + "Missouri": "MO", + "Montana": "MT", + "Nebraska": "NE", + "Nevada": "NV", + "New Hampshire": "NH", + "New Jersey": "NJ", + "New Mexico": "NM", + "New York": "NY", + "North Carolina": "NC", + "North Dakota": "ND", + "Ohio": "OH", + "Oklahoma": "OK", + "Oregon": "OR", + "Pennsylvania": "PA", + "Puerto Rico": "", + "Rhode Island": "RI", + "South Carolina": "SC", + "South Dakota": "SD", + "Tennessee": "TN", + "Texas": "TX", + "United States Virgin Islands": "VI", + "Utah": "UT", + "Vermont": "VT", + "Virginia": "VA", + "Washington": "WA", + "West Virginia": "WV", + "Wisconsin": "WI", + "Wyoming": "WY", +} + +USA_XRANGE = [-125.0, -65.0] +USA_YRANGE = [25.0, 49.0] + + +def _human_format(number): + units = ["", "K", "M", "G", "T", "P"] + k = 1000.0 + magnitude = int(floor(log(number, k))) + return "%.2f%s" % (number / k**magnitude, units[magnitude]) + + +def _intervals_as_labels(array_of_intervals, round_legend_values, exponent_format): + """ + Transform an number interval to a clean string for legend + + Example: [-inf, 30] to '< 30' + """ + infs = [float("-inf"), float("inf")] + string_intervals = [] + for 
interval in array_of_intervals: + # round to 2nd decimal place + if round_legend_values: + rnd_interval = [ + (int(interval[i]) if interval[i] not in infs else interval[i]) + for i in range(2) + ] + else: + rnd_interval = [round(interval[0], 2), round(interval[1], 2)] + + num0 = rnd_interval[0] + num1 = rnd_interval[1] + if exponent_format: + if num0 not in infs: + num0 = _human_format(num0) + if num1 not in infs: + num1 = _human_format(num1) + else: + if num0 not in infs: + num0 = "{:,}".format(num0) + if num1 not in infs: + num1 = "{:,}".format(num1) + + if num0 == float("-inf"): + as_str = "< {}".format(num1) + elif num1 == float("inf"): + as_str = "> {}".format(num0) + else: + as_str = "{} - {}".format(num0, num1) + string_intervals.append(as_str) + return string_intervals + + +def _calculations( + df, + fips, + values, + index, + f, + simplify_county, + level, + x_centroids, + y_centroids, + centroid_text, + x_traces, + y_traces, + fips_polygon_map, +): + # 0-pad FIPS code to ensure exactly 5 digits + padded_f = str(f).zfill(5) + if fips_polygon_map[f].type == "Polygon": + x = fips_polygon_map[f].simplify(simplify_county).exterior.xy[0].tolist() + y = fips_polygon_map[f].simplify(simplify_county).exterior.xy[1].tolist() + + x_c, y_c = fips_polygon_map[f].centroid.xy + county_name_str = str(df[df["FIPS"] == f]["COUNTY_NAME"].iloc[0]) + state_name_str = str(df[df["FIPS"] == f]["STATE_NAME"].iloc[0]) + + t_c = ( + "County: " + + county_name_str + + "<br>" + + "State: " + + state_name_str + + "<br>" + + "FIPS: " + + padded_f + + "<br>Value: " + + str(values[index]) + ) + + x_centroids.append(x_c[0]) + y_centroids.append(y_c[0]) + centroid_text.append(t_c) + + x_traces[level] = x_traces[level] + x + [np.nan] + y_traces[level] = y_traces[level] + y + [np.nan] + elif fips_polygon_map[f].type == "MultiPolygon": + x = [ + poly.simplify(simplify_county).exterior.xy[0].tolist() + for poly in fips_polygon_map[f].geoms + ] + y = [ + 
poly.simplify(simplify_county).exterior.xy[1].tolist() + for poly in fips_polygon_map[f].geoms + ] + + x_c = [poly.centroid.xy[0].tolist() for poly in fips_polygon_map[f].geoms] + y_c = [poly.centroid.xy[1].tolist() for poly in fips_polygon_map[f].geoms] + + county_name_str = str(df[df["FIPS"] == f]["COUNTY_NAME"].iloc[0]) + state_name_str = str(df[df["FIPS"] == f]["STATE_NAME"].iloc[0]) + text = ( + "County: " + + county_name_str + + "<br>" + + "State: " + + state_name_str + + "<br>" + + "FIPS: " + + padded_f + + "<br>Value: " + + str(values[index]) + ) + t_c = [text for poly in fips_polygon_map[f].geoms] + x_centroids = x_c + x_centroids + y_centroids = y_c + y_centroids + centroid_text = t_c + centroid_text + for x_y_idx in range(len(x)): + x_traces[level] = x_traces[level] + x[x_y_idx] + [np.nan] + y_traces[level] = y_traces[level] + y[x_y_idx] + [np.nan] + + return x_traces, y_traces, x_centroids, y_centroids, centroid_text + + +def create_choropleth( + fips, + values, + scope=["usa"], + binning_endpoints=None, + colorscale=None, + order=None, + simplify_county=0.02, + simplify_state=0.02, + asp=None, + show_hover=True, + show_state_data=True, + state_outline=None, + county_outline=None, + centroid_marker=None, + round_legend_values=False, + exponent_format=False, + legend_title="", + **layout_options, +): + """ + **deprecated**, use instead + :func:`plotly.express.choropleth` with custom GeoJSON. + + This function also requires `shapely`, `geopandas` and `plotly-geo` to be installed. + + Returns figure for county choropleth. Uses data from package_data. + + :param (list) fips: list of FIPS values which correspond to the con + catination of state and county ids. An example is '01001'. + :param (list) values: list of numbers/strings which correspond to the + fips list. These are the values that will determine how the counties + are colored. + :param (list) scope: list of states and/or states abbreviations. Fits + all states in the camera tightly. 
Selecting ['usa'] is the equivalent + of appending all 50 states into your scope list. Selecting only 'usa' + does not include 'Alaska', 'Puerto Rico', 'American Samoa', + 'Commonwealth of the Northern Mariana Islands', 'Guam', + 'United States Virgin Islands'. These must be added manually to the + list. + Default = ['usa'] + :param (list) binning_endpoints: ascending numbers which implicitly define + real number intervals which are used as bins. The colorscale used must + have the same number of colors as the number of bins and this will + result in a categorical colormap. + :param (list) colorscale: a list of colors with length equal to the + number of categories of colors. The length must match either all + unique numbers in the 'values' list or if endpoints is being used, the + number of categories created by the endpoints.\n + For example, if binning_endpoints = [4, 6, 8], then there are 4 bins: + [-inf, 4), [4, 6), [6, 8), [8, inf) + :param (list) order: a list of the unique categories (numbers/bins) in any + desired order. This is helpful if you want to order string values to + a chosen colorscale. + :param (float) simplify_county: determines the simplification factor + for the counties. The larger the number, the fewer vertices and edges + each polygon has. See + http://toblerity.org/shapely/manual.html#object.simplify for more + information. + Default = 0.02 + :param (float) simplify_state: simplifies the state outline polygon. + See http://toblerity.org/shapely/manual.html#object.simplify for more + information. + Default = 0.02 + :param (float) asp: the width-to-height aspect ratio for the camera. + Default = 2.5 + :param (bool) show_hover: show county hover and centroid info + :param (bool) show_state_data: reveals state boundary lines + :param (dict) state_outline: dict of attributes of the state outline + including width and color. 
See + https://plot.ly/python/reference/#scatter-marker-line for all valid + params + :param (dict) county_outline: dict of attributes of the county outline + including width and color. See + https://plot.ly/python/reference/#scatter-marker-line for all valid + params + :param (dict) centroid_marker: dict of attributes of the centroid marker. + The centroid markers are invisible by default and appear visible on + selection. See https://plot.ly/python/reference/#scatter-marker for + all valid params + :param (bool) round_legend_values: automatically round the numbers that + appear in the legend to the nearest integer. + Default = False + :param (bool) exponent_format: if set to True, puts numbers in the K, M, + B number format. For example 4000.0 becomes 4.0K + Default = False + :param (str) legend_title: title that appears above the legend + :param **layout_options: a **kwargs argument for all layout parameters + + + Example 1: Florida:: + + import plotly.plotly as py + import plotly.figure_factory as ff + + import numpy as np + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'] == 'Florida'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + + binning_endpoints = list(np.mgrid[min(values):max(values):4j]) + colorscale = ["#030512","#1d1d3b","#323268","#3d4b94","#3e6ab0", + "#4989bc","#60a7c7","#85c5d3","#b7e0e4","#eafcfd"] + fig = ff.create_choropleth( + fips=fips, values=values, scope=['Florida'], show_state_data=True, + colorscale=colorscale, binning_endpoints=binning_endpoints, + round_legend_values=True, plot_bgcolor='rgb(229,229,229)', + paper_bgcolor='rgb(229,229,229)', legend_title='Florida Population', + county_outline={'color': 'rgb(255,255,255)', 'width': 0.5}, + exponent_format=True, + ) + + Example 2: New England:: + + import plotly.figure_factory as ff + + import pandas as pd + + NE_states = 
['Connecticut', 'Maine', 'Massachusetts', + 'New Hampshire', 'Rhode Island'] + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'].isin(NE_states)] + colorscale = ['rgb(68.0, 1.0, 84.0)', + 'rgb(66.0, 64.0, 134.0)', + 'rgb(38.0, 130.0, 142.0)', + 'rgb(63.0, 188.0, 115.0)', + 'rgb(216.0, 226.0, 25.0)'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + fig = ff.create_choropleth( + fips=fips, values=values, scope=NE_states, show_state_data=True + ) + fig.show() + + Example 3: California and Surrounding States:: + + import plotly.figure_factory as ff + + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'] == 'California'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + + colorscale = [ + 'rgb(193, 193, 193)', + 'rgb(239,239,239)', + 'rgb(195, 196, 222)', + 'rgb(144,148,194)', + 'rgb(101,104,168)', + 'rgb(65, 53, 132)' + ] + + fig = ff.create_choropleth( + fips=fips, values=values, colorscale=colorscale, + scope=['CA', 'AZ', 'Nevada', 'Oregon', ' Idaho'], + binning_endpoints=[14348, 63983, 134827, 426762, 2081313], + county_outline={'color': 'rgb(255,255,255)', 'width': 0.5}, + legend_title='California Counties', + title='California and Nearby States' + ) + fig.show() + + Example 4: USA:: + + import plotly.figure_factory as ff + + import numpy as np + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/laucnty16.csv' + ) + df_sample['State FIPS Code'] = df_sample['State FIPS Code'].apply( + lambda x: str(x).zfill(2) + ) + df_sample['County FIPS Code'] = df_sample['County FIPS Code'].apply( + lambda x: str(x).zfill(3) + ) + df_sample['FIPS'] = ( + df_sample['State FIPS Code'] + df_sample['County FIPS Code'] + ) + 
+ binning_endpoints = list(np.linspace(1, 12, len(colorscale) - 1)) + colorscale = ["#f7fbff", "#ebf3fb", "#deebf7", "#d2e3f3", "#c6dbef", + "#b3d2e9", "#9ecae1", "#85bcdb", "#6baed6", "#57a0ce", + "#4292c6", "#3082be", "#2171b5", "#1361a9", "#08519c", + "#0b4083","#08306b"] + fips = df_sample['FIPS'] + values = df_sample['Unemployment Rate (%)'] + fig = ff.create_choropleth( + fips=fips, values=values, scope=['usa'], + binning_endpoints=binning_endpoints, colorscale=colorscale, + show_hover=True, centroid_marker={'opacity': 0}, + asp=2.9, title='USA by Unemployment %', + legend_title='Unemployment %' + ) + fig.show() + """ + # ensure optional modules imported + if not _plotly_geo: + raise ValueError( + """ +The create_choropleth figure factory requires the plotly-geo package. +Install using pip with: + +$ pip install plotly-geo + +Or, install using conda with + +$ conda install -c plotly plotly-geo +""" + ) + + if not gp or not shapefile or not shapely: + raise ImportError( + "geopandas, pyshp and shapely must be installed for this figure " + "factory.\n\nRun the following commands to install the correct " + "versions of the following modules:\n\n" + "```\n" + "$ pip install geopandas==0.3.0\n" + "$ pip install pyshp==1.2.10\n" + "$ pip install shapely==1.6.3\n" + "```\n" + "If you are using Windows, follow this post to properly " + "install geopandas and dependencies:" + "http://geoffboeing.com/2014/09/using-geopandas-windows/\n\n" + "If you are using Anaconda, do not use PIP to install the " + "packages above. 
Instead use conda to install them:\n\n" + "```\n" + "$ conda install plotly\n" + "$ conda install geopandas\n" + "```" + ) + + df, df_state = _create_us_counties_df(st_to_state_name_dict, state_to_st_dict) + + fips_polygon_map = dict(zip(df["FIPS"].tolist(), df["geometry"].tolist())) + + if not state_outline: + state_outline = {"color": "rgb(240, 240, 240)", "width": 1} + if not county_outline: + county_outline = {"color": "rgb(0, 0, 0)", "width": 0} + if not centroid_marker: + centroid_marker = {"size": 3, "color": "white", "opacity": 1} + + # ensure centroid markers appear on selection + if "opacity" not in centroid_marker: + centroid_marker.update({"opacity": 1}) + + if len(fips) != len(values): + raise PlotlyError("fips and values must be the same length") + + # make fips, values into lists + if isinstance(fips, pd.core.series.Series): + fips = fips.tolist() + if isinstance(values, pd.core.series.Series): + values = values.tolist() + + # make fips numeric + fips = map(lambda x: int(x), fips) + + if binning_endpoints: + intervals = utils.endpts_to_intervals(binning_endpoints) + LEVELS = _intervals_as_labels(intervals, round_legend_values, exponent_format) + else: + if not order: + LEVELS = sorted(list(set(values))) + else: + # check if order is permutation + # of unique color col values + same_sets = sorted(list(set(values))) == set(order) + no_duplicates = not any(order.count(x) > 1 for x in order) + if same_sets and no_duplicates: + LEVELS = order + else: + raise PlotlyError( + "if you are using a custom order of unique values from " + "your color column, you must: have all the unique values " + "in your order and have no duplicate items" + ) + + if not colorscale: + colorscale = [] + viridis_colors = clrs.colorscale_to_colors(clrs.PLOTLY_SCALES["Viridis"]) + viridis_colors = clrs.color_parser(viridis_colors, clrs.hex_to_rgb) + viridis_colors = clrs.color_parser(viridis_colors, clrs.label_rgb) + viri_len = len(viridis_colors) + 1 + viri_intervals = 
utils.endpts_to_intervals(list(np.linspace(0, 1, viri_len)))[ + 1:-1 + ] + + for L in np.linspace(0, 1, len(LEVELS)): + for idx, inter in enumerate(viri_intervals): + if L == 0: + break + elif inter[0] < L <= inter[1]: + break + + intermed = (L - viri_intervals[idx][0]) / ( + viri_intervals[idx][1] - viri_intervals[idx][0] + ) + + float_color = clrs.find_intermediate_color( + viridis_colors[idx], viridis_colors[idx], intermed, colortype="rgb" + ) + + # make R,G,B into int values + float_color = clrs.unlabel_rgb(float_color) + float_color = clrs.unconvert_from_RGB_255(float_color) + int_rgb = clrs.convert_to_RGB_255(float_color) + int_rgb = clrs.label_rgb(int_rgb) + + colorscale.append(int_rgb) + + if len(colorscale) < len(LEVELS): + raise PlotlyError( + "You have {} LEVELS. Your number of colors in 'colorscale' must " + "be at least the number of LEVELS: {}. If you are " + "using 'binning_endpoints' then 'colorscale' must have at " + "least len(binning_endpoints) + 2 colors".format( + len(LEVELS), min(LEVELS, LEVELS[:20]) + ) + ) + + color_lookup = dict(zip(LEVELS, colorscale)) + x_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))])) + y_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))])) + + # scope + if isinstance(scope, str): + raise PlotlyError("'scope' must be a list/tuple/sequence") + + scope_names = [] + extra_states = [ + "Alaska", + "Commonwealth of the Northern Mariana Islands", + "Puerto Rico", + "Guam", + "United States Virgin Islands", + "American Samoa", + ] + for state in scope: + if state.lower() == "usa": + scope_names = df["STATE_NAME"].unique() + scope_names = list(scope_names) + for ex_st in extra_states: + try: + scope_names.remove(ex_st) + except ValueError: + pass + else: + if state in st_to_state_name_dict.keys(): + state = st_to_state_name_dict[state] + scope_names.append(state) + df_state = df_state[df_state["STATE_NAME"].isin(scope_names)] + + plot_data = [] + x_centroids = [] + y_centroids = [] + centroid_text = [] + 
fips_not_in_shapefile = [] + if not binning_endpoints: + for index, f in enumerate(fips): + level = values[index] + try: + fips_polygon_map[f].type + + ( + x_traces, + y_traces, + x_centroids, + y_centroids, + centroid_text, + ) = _calculations( + df, + fips, + values, + index, + f, + simplify_county, + level, + x_centroids, + y_centroids, + centroid_text, + x_traces, + y_traces, + fips_polygon_map, + ) + except KeyError: + fips_not_in_shapefile.append(f) + + else: + for index, f in enumerate(fips): + for j, inter in enumerate(intervals): + if inter[0] < values[index] <= inter[1]: + break + level = LEVELS[j] + + try: + fips_polygon_map[f].type + + ( + x_traces, + y_traces, + x_centroids, + y_centroids, + centroid_text, + ) = _calculations( + df, + fips, + values, + index, + f, + simplify_county, + level, + x_centroids, + y_centroids, + centroid_text, + x_traces, + y_traces, + fips_polygon_map, + ) + except KeyError: + fips_not_in_shapefile.append(f) + + if len(fips_not_in_shapefile) > 0: + msg = ( + "Unrecognized FIPS Values\n\nWhoops! It looks like you are " + "trying to pass at least one FIPS value that is not in " + "our shapefile of FIPS and data for the counties. 
Your " + "choropleth will still show up but these counties cannot " + "be shown.\nUnrecognized FIPS are: {}".format(fips_not_in_shapefile) + ) + warnings.warn(msg) + + x_states = [] + y_states = [] + for index, row in df_state.iterrows(): + if df_state["geometry"][index].type == "Polygon": + x = row.geometry.simplify(simplify_state).exterior.xy[0].tolist() + y = row.geometry.simplify(simplify_state).exterior.xy[1].tolist() + x_states = x_states + x + y_states = y_states + y + elif df_state["geometry"][index].type == "MultiPolygon": + x = [ + poly.simplify(simplify_state).exterior.xy[0].tolist() + for poly in df_state["geometry"][index].geoms + ] + y = [ + poly.simplify(simplify_state).exterior.xy[1].tolist() + for poly in df_state["geometry"][index].geoms + ] + for segment in range(len(x)): + x_states = x_states + x[segment] + y_states = y_states + y[segment] + x_states.append(np.nan) + y_states.append(np.nan) + x_states.append(np.nan) + y_states.append(np.nan) + + for lev in LEVELS: + county_data = dict( + type="scatter", + mode="lines", + x=x_traces[lev], + y=y_traces[lev], + line=county_outline, + fill="toself", + fillcolor=color_lookup[lev], + name=lev, + hoverinfo="none", + ) + plot_data.append(county_data) + + if show_hover: + hover_points = dict( + type="scatter", + showlegend=False, + legendgroup="centroids", + x=x_centroids, + y=y_centroids, + text=centroid_text, + name="US Counties", + mode="markers", + marker={"color": "white", "opacity": 0}, + hoverinfo="text", + ) + centroids_on_select = dict( + selected=dict(marker=centroid_marker), + unselected=dict(marker=dict(opacity=0)), + ) + hover_points.update(centroids_on_select) + plot_data.append(hover_points) + + if show_state_data: + state_data = dict( + type="scatter", + legendgroup="States", + line=state_outline, + x=x_states, + y=y_states, + hoverinfo="text", + showlegend=False, + mode="lines", + ) + plot_data.append(state_data) + + DEFAULT_LAYOUT = dict( + hovermode="closest", + xaxis=dict( + 
autorange=False, + range=USA_XRANGE, + showgrid=False, + zeroline=False, + fixedrange=True, + showticklabels=False, + ), + yaxis=dict( + autorange=False, + range=USA_YRANGE, + showgrid=False, + zeroline=False, + fixedrange=True, + showticklabels=False, + ), + margin=dict(t=40, b=20, r=20, l=20), + width=900, + height=450, + dragmode="select", + legend=dict(traceorder="reversed", xanchor="right", yanchor="top", x=1, y=1), + annotations=[], + ) + fig = dict(data=plot_data, layout=DEFAULT_LAYOUT) + fig["layout"].update(layout_options) + fig["layout"]["annotations"].append( + dict( + x=1, + y=1.05, + xref="paper", + yref="paper", + xanchor="right", + showarrow=False, + text="<b>" + legend_title + "</b>", + ) + ) + + if len(scope) == 1 and scope[0].lower() == "usa": + xaxis_range_low = -125.0 + xaxis_range_high = -55.0 + yaxis_range_low = 25.0 + yaxis_range_high = 49.0 + else: + xaxis_range_low = float("inf") + xaxis_range_high = float("-inf") + yaxis_range_low = float("inf") + yaxis_range_high = float("-inf") + for trace in fig["data"]: + if all(isinstance(n, Number) for n in trace["x"]): + calc_x_min = min(trace["x"] or [float("inf")]) + calc_x_max = max(trace["x"] or [float("-inf")]) + if calc_x_min < xaxis_range_low: + xaxis_range_low = calc_x_min + if calc_x_max > xaxis_range_high: + xaxis_range_high = calc_x_max + if all(isinstance(n, Number) for n in trace["y"]): + calc_y_min = min(trace["y"] or [float("inf")]) + calc_y_max = max(trace["y"] or [float("-inf")]) + if calc_y_min < yaxis_range_low: + yaxis_range_low = calc_y_min + if calc_y_max > yaxis_range_high: + yaxis_range_high = calc_y_max + + # camera zoom + fig["layout"]["xaxis"]["range"] = [xaxis_range_low, xaxis_range_high] + fig["layout"]["yaxis"]["range"] = [yaxis_range_low, yaxis_range_high] + + # aspect ratio + if asp is None: + usa_x_range = USA_XRANGE[1] - USA_XRANGE[0] + usa_y_range = USA_YRANGE[1] - USA_YRANGE[0] + asp = usa_x_range / usa_y_range + + # based on your figure + width = float( + 
def create_dendrogram(
    X,
    orientation="bottom",
    labels=None,
    colorscale=None,
    distfun=None,
    linkagefun=lambda x: sch.linkage(x, "complete"),
    hovertext=None,
    color_threshold=None,
):
    """
    Function that returns a dendrogram Plotly figure object. This is a thin
    wrapper around scipy.cluster.hierarchy.dendrogram.

    See also https://dash.plot.ly/dash-bio/clustergram.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for the dendrogram tree.
                              Requires 8 colors to be specified, the 7th of
                              which is ignored.  With scipy>=1.5.0, the 2nd, 3rd
                              and 6th are used twice as often as the others.
                              Given a shorter list, the missing values are
                              replaced with defaults and with a longer list the
                              extra values are ignored.
    :param (function) distfun: Function to compute the pairwise distance from
                               the observations
    :param (function) linkagefun: Function to compute the linkage matrix from
                               the pairwise distances
    :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
                               clusters
    :param (double) color_threshold: Value at which the separation of clusters will be made

    :raises: (ImportError) If scipy, scipy.spatial or
        scipy.cluster.hierarchy could not be imported
    :raises: (PlotlyError) If X.shape does not have exactly two dimensions

    Example 1: Simple bottom oriented dendrogram

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np

    >>> X = np.random.rand(10,10)
    >>> fig = create_dendrogram(X)
    >>> fig.show()

    Example 2: Dendrogram to put on the left of the heatmap

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np

    >>> X = np.random.rand(5,5)
    >>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    >>> dendro = create_dendrogram(X, orientation='right', labels=names)
    >>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
    >>> dendro.show()

    Example 3: Dendrogram with Pandas

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np
    >>> import pandas as pd

    >>> Index= ['A','B','C','D','E','F','G','H','I','J']
    >>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    >>> fig = create_dendrogram(df, labels=Index)
    >>> fig.show()
    """
    # The optional-import handles are falsy when the module is unavailable.
    if not scp or not scs or not sch:
        # Built from adjacent literals: the previous backslash continuation
        # inside the string leaked source indentation into the message.
        raise ImportError(
            "FigureFactory.create_dendrogram requires scipy, "
            "scipy.spatial and scipy.cluster.hierarchy"
        )

    s = X.shape
    if len(s) != 2:
        # The exception used to be constructed but never raised, which made
        # this check a silent no-op.
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(
        X,
        orientation,
        labels,
        colorscale,
        distfun=distfun,
        linkagefun=linkagefun,
        hovertext=hovertext,
        color_threshold=color_threshold,
    )

    return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
def get_color_dict(self, colorscale):
    """
    Build the colour lookup used to paint the dendrogram clusters.

    :param (list) colorscale: Colors to use for the plot in rgb format, or
        None to fall back to the built-in eight-colour palette.
    :rtype (dict): Ordered mapping from every colour code handled here (the
        legacy single letters followed by 'C0'..'C9') to a plot colour.

    """
    # Legacy single-letter codes returned for dendrograms (default
    # above_threshold_color plus the default palette, see
    # scipy/cluster/hierarchy.py), listed alphabetically because that is the
    # order in which colorscale entries are handed out below.
    # TODO: 'w' doesn't seem to be in the default color palette in
    # scipy/cluster/hierarchy.py
    palette = OrderedDict(
        [
            ("b", "blue"),
            ("c", "cyan"),
            ("g", "green"),
            ("k", "black"),
            ("m", "magenta"),
            ("r", "red"),
            ("w", "white"),
            ("y", "yellow"),
        ]
    )

    if colorscale is not None:
        chosen = colorscale
    else:
        chosen = [
            "rgb(0,116,217)",  # blue
            "rgb(35,205,205)",  # cyan
            "rgb(61,153,112)",  # green
            "rgb(40,35,35)",  # black
            "rgb(133,20,75)",  # magenta
            "rgb(255,65,54)",  # red
            "rgb(255,255,255)",  # white
            "rgb(255,220,0)",  # yellow
        ]

    # Replace the named colours position by position; codes beyond the end
    # of a short colorscale keep their named default.
    available = len(chosen)
    for position, code in enumerate(list(palette)):
        if position < available:
            palette[code] = chosen[position]

    # scipy 1.5.0 switched from letter codes to cyclic 'C0', 'C1', ... codes.
    # To keep colours consistent across scipy versions each new code is
    # aliased to an old one; the pairing was found by inspecting
    # scipy/cluster/hierarchy.py.
    cyclic_aliases = (
        ("C0", "b"),
        ("C1", "g"),
        ("C2", "r"),
        ("C3", "c"),
        ("C4", "m"),
        ("C5", "y"),
        ("C6", "k"),
        ("C7", "g"),
        ("C8", "r"),
        ("C9", "c"),
    )
    for alias, legacy in cyclic_aliases:
        # Fall back to an arbitrary default if the legacy code is missing.
        palette[alias] = palette.get(legacy, "rgb(0,116,217)")

    return palette
+ + """ + self.layout.update( + { + "showlegend": False, + "autosize": False, + "hovermode": "closest", + "width": width, + "height": height, + } + ) + + self.set_axis_layout(self.xaxis) + self.set_axis_layout(self.yaxis) + + return self.layout + + def get_dendrogram_traces( + self, X, colorscale, distfun, linkagefun, hovertext, color_threshold + ): + """ + Calculates all the elements needed for plotting a dendrogram. + + :param (ndarray) X: Matrix of observations as array of arrays + :param (list) colorscale: Color scale for dendrogram tree clusters + :param (function) distfun: Function to compute the pairwise distance + from the observations + :param (function) linkagefun: Function to compute the linkage matrix + from the pairwise distances + :param (list) hovertext: List of hovertext for constituent traces of dendrogram + :rtype (tuple): Contains all the traces in the following order: + (a) trace_list: List of Plotly trace objects for dendrogram tree + (b) icoord: All X points of the dendrogram tree as array of arrays + with length 4 + (c) dcoord: All Y points of the dendrogram tree as array of arrays + with length 4 + (d) ordered_labels: leaf labels in the order they are going to + appear on the plot + (e) P['leaves']: left-to-right traversal of the leaves + + """ + d = distfun(X) + Z = linkagefun(d) + P = sch.dendrogram( + Z, + orientation=self.orientation, + labels=self.labels, + no_plot=True, + color_threshold=color_threshold, + ) + + icoord = np.array(P["icoord"]) + dcoord = np.array(P["dcoord"]) + ordered_labels = np.array(P["ivl"]) + color_list = np.array(P["color_list"]) + colors = self.get_color_dict(colorscale) + + trace_list = [] + + for i in range(len(icoord)): + # xs and ys are arrays of 4 points that make up the '∩' shapes + # of the dendrogram tree + if self.orientation in ["top", "bottom"]: + xs = icoord[i] + else: + xs = dcoord[i] + + if self.orientation in ["top", "bottom"]: + ys = dcoord[i] + else: + ys = icoord[i] + color_key = color_list[i] 
def validate_distplot(hist_data, curve_type):
    """
    Distplot-specific validations

    Only the first entry of hist_data is type-checked.

    :param hist_data: the datasets handed to create_distplot
    :param (str) curve_type: 'kde' or 'normal'
    :raises: (PlotlyError) If hist_data is empty, or if its first entry is
        not a list (or a numpy array / pandas Series when those packages
        are importable)
    :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
        'normal').
    :raises: (ImportError) If scipy is not available
    """
    hist_data_types = (list,)
    if np:
        hist_data_types += (np.ndarray,)
    if pd:
        hist_data_types += (pd.core.series.Series,)

    # Guard the indexing below: an empty container used to escape as a bare
    # IndexError. len() is used rather than truthiness because a numpy array
    # of datasets has no unambiguous truth value.
    if len(hist_data) == 0:
        raise exceptions.PlotlyError(
            "hist_data must contain at least one dataset, i.e. x = [[1, 2, 3]]"
        )

    if not isinstance(hist_data[0], hist_data_types):
        raise exceptions.PlotlyError(
            "Oops, this function was written "
            "to handle multiple datasets, if "
            "you want to plot just one, make "
            "sure your hist_data variable is "
            "still a list of lists, i.e. x = "
            "[1, 2, 3] -> x = [[1, 2, 3]]"
        )

    curve_opts = ("kde", "normal")
    if curve_type not in curve_opts:
        raise exceptions.PlotlyError("curve_type must be defined as 'kde' or 'normal'")

    if not scipy:
        raise ImportError("FigureFactory.create_distplot requires scipy")
Default = True + :param (bool) show_curve: Add curve to distplot? Default = True + :param (bool) show_rug: Add rug to distplot? Default = True + :param (list[str]) colors: Colors for traces. + :param (list[list]) rug_text: Hovertext values for rug_plot, + :return (dict): Representation of a distplot figure. + + Example 1: Simple distplot of 1 data set + + >>> from plotly.figure_factory import create_distplot + + >>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, + ... 3.5, 4.1, 4.4, 4.5, 4.5, + ... 5.0, 5.0, 5.2, 5.5, 5.5, + ... 5.5, 5.5, 5.5, 6.1, 7.0]] + >>> group_labels = ['distplot example'] + >>> fig = create_distplot(hist_data, group_labels) + >>> fig.show() + + + Example 2: Two data sets and added rug text + + >>> from plotly.figure_factory import create_distplot + >>> # Add histogram data + >>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, + ... -0.9, -0.07, 1.95, 0.9, -0.2, + ... -0.5, 0.3, 0.4, -0.37, 0.6] + >>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, + ... 1.0, 0.8, 1.7, 0.5, 0.8, + ... -0.3, 1.2, 0.56, 0.3, 2.2] + + >>> # Group data together + >>> hist_data = [hist1_x, hist2_x] + + >>> group_labels = ['2012', '2013'] + + >>> # Add text + >>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', + ... 'f1', 'g1', 'h1', 'i1', 'j1', + ... 'k1', 'l1', 'm1', 'n1', 'o1'] + + >>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', + ... 'f2', 'g2', 'h2', 'i2', 'j2', + ... 'k2', 'l2', 'm2', 'n2', 'o2'] + + >>> # Group text together + >>> rug_text_all = [rug_text_1, rug_text_2] + + >>> # Create distplot + >>> fig = create_distplot( + ... 
hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) + + >>> # Add title + >>> fig.update_layout(title='Dist Plot') # doctest: +SKIP + >>> fig.show() + + + Example 3: Plot with normal curve and hide rug plot + + >>> from plotly.figure_factory import create_distplot + >>> import numpy as np + + >>> x1 = np.random.randn(190) + >>> x2 = np.random.randn(200)+1 + >>> x3 = np.random.randn(200)-1 + >>> x4 = np.random.randn(210)+2 + + >>> hist_data = [x1, x2, x3, x4] + >>> group_labels = ['2012', '2013', '2014', '2015'] + + >>> fig = create_distplot( + ... hist_data, group_labels, curve_type='normal', + ... show_rug=False, bin_size=.4) + + + Example 4: Distplot with Pandas + + >>> from plotly.figure_factory import create_distplot + >>> import numpy as np + >>> import pandas as pd + + >>> df = pd.DataFrame({'2012': np.random.randn(200), + ... '2013': np.random.randn(200)+1}) + >>> fig = create_distplot([df[c] for c in df.columns], df.columns) + >>> fig.show() + """ + if colors is None: + colors = [] + if rug_text is None: + rug_text = [] + + validate_distplot(hist_data, curve_type) + utils.validate_equal_length(hist_data, group_labels) + + if isinstance(bin_size, (float, int)): + bin_size = [bin_size] * len(hist_data) + + data = [] + if show_hist: + hist = _Distplot( + hist_data, + histnorm, + group_labels, + bin_size, + curve_type, + colors, + rug_text, + show_hist, + show_curve, + ).make_hist() + + data.append(hist) + + if show_curve: + if curve_type == "normal": + curve = _Distplot( + hist_data, + histnorm, + group_labels, + bin_size, + curve_type, + colors, + rug_text, + show_hist, + show_curve, + ).make_normal() + else: + curve = _Distplot( + hist_data, + histnorm, + group_labels, + bin_size, + curve_type, + colors, + rug_text, + show_hist, + show_curve, + ).make_kde() + + data.append(curve) + + if show_rug: + rug = _Distplot( + hist_data, + histnorm, + group_labels, + bin_size, + curve_type, + colors, + rug_text, + show_hist, + show_curve, + ).make_rug() + + 
data.append(rug) + layout = graph_objs.Layout( + barmode="overlay", + hovermode="closest", + legend=dict(traceorder="reversed"), + xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False), + yaxis1=dict(domain=[0.35, 1], anchor="free", position=0.0), + yaxis2=dict(domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False), + ) + else: + layout = graph_objs.Layout( + barmode="overlay", + hovermode="closest", + legend=dict(traceorder="reversed"), + xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False), + yaxis1=dict(domain=[0.0, 1], anchor="free", position=0.0), + ) + + data = sum(data, []) + return graph_objs.Figure(data=data, layout=layout) + + +class _Distplot(object): + """ + Refer to TraceFactory.create_distplot() for docstring + """ + + def __init__( + self, + hist_data, + histnorm, + group_labels, + bin_size, + curve_type, + colors, + rug_text, + show_hist, + show_curve, + ): + self.hist_data = hist_data + self.histnorm = histnorm + self.group_labels = group_labels + self.bin_size = bin_size + self.show_hist = show_hist + self.show_curve = show_curve + self.trace_number = len(hist_data) + if rug_text: + self.rug_text = rug_text + else: + self.rug_text = [None] * self.trace_number + + self.start = [] + self.end = [] + if colors: + self.colors = colors + else: + self.colors = [ + "rgb(31, 119, 180)", + "rgb(255, 127, 14)", + "rgb(44, 160, 44)", + "rgb(214, 39, 40)", + "rgb(148, 103, 189)", + "rgb(140, 86, 75)", + "rgb(227, 119, 194)", + "rgb(127, 127, 127)", + "rgb(188, 189, 34)", + "rgb(23, 190, 207)", + ] + self.curve_x = [None] * self.trace_number + self.curve_y = [None] * self.trace_number + + for trace in self.hist_data: + self.start.append(min(trace) * 1.0) + self.end.append(max(trace) * 1.0) + + def make_hist(self): + """ + Makes the histogram(s) for FigureFactory.create_distplot(). 
+ + :rtype (list) hist: list of histogram representations + """ + hist = [None] * self.trace_number + + for index in range(self.trace_number): + hist[index] = dict( + type="histogram", + x=self.hist_data[index], + xaxis="x1", + yaxis="y1", + histnorm=self.histnorm, + name=self.group_labels[index], + legendgroup=self.group_labels[index], + marker=dict(color=self.colors[index % len(self.colors)]), + autobinx=False, + xbins=dict( + start=self.start[index], + end=self.end[index], + size=self.bin_size[index], + ), + opacity=0.7, + ) + return hist + + def make_kde(self): + """ + Makes the kernel density estimation(s) for create_distplot(). + + This is called when curve_type = 'kde' in create_distplot(). + + :rtype (list) curve: list of kde representations + """ + curve = [None] * self.trace_number + for index in range(self.trace_number): + self.curve_x[index] = [ + self.start[index] + x * (self.end[index] - self.start[index]) / 500 + for x in range(500) + ] + self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])( + self.curve_x[index] + ) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict( + type="scatter", + x=self.curve_x[index], + y=self.curve_y[index], + xaxis="x1", + yaxis="y1", + mode="lines", + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index % len(self.colors)]), + ) + return curve + + def make_normal(self): + """ + Makes the normal curve(s) for create_distplot(). + + This is called when curve_type = 'normal' in create_distplot(). 
+ + :rtype (list) curve: list of normal curve representations + """ + curve = [None] * self.trace_number + mean = [None] * self.trace_number + sd = [None] * self.trace_number + + for index in range(self.trace_number): + mean[index], sd[index] = scipy_stats.norm.fit(self.hist_data[index]) + self.curve_x[index] = [ + self.start[index] + x * (self.end[index] - self.start[index]) / 500 + for x in range(500) + ] + self.curve_y[index] = scipy_stats.norm.pdf( + self.curve_x[index], loc=mean[index], scale=sd[index] + ) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict( + type="scatter", + x=self.curve_x[index], + y=self.curve_y[index], + xaxis="x1", + yaxis="y1", + mode="lines", + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index % len(self.colors)]), + ) + return curve + + def make_rug(self): + """ + Makes the rug plot(s) for create_distplot(). 
+ + :rtype (list) rug: list of rug plot representations + """ + rug = [None] * self.trace_number + for index in range(self.trace_number): + rug[index] = dict( + type="scatter", + x=self.hist_data[index], + y=([self.group_labels[index]] * len(self.hist_data[index])), + xaxis="x1", + yaxis="y2", + mode="markers", + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=(False if self.show_hist or self.show_curve else True), + text=self.rug_text[index], + marker=dict( + color=self.colors[index % len(self.colors)], symbol="line-ns-open" + ), + ) + return rug diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py new file mode 100644 index 0000000..06dc71d --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_facet_grid.py @@ -0,0 +1,1195 @@ +from plotly import exceptions, optional_imports +import plotly.colors as clrs +from plotly.figure_factory import utils +from plotly.subplots import make_subplots + +import math +from numbers import Number + +pd = optional_imports.get_module("pandas") + +TICK_COLOR = "#969696" +AXIS_TITLE_COLOR = "#0f0f0f" +AXIS_TITLE_SIZE = 12 +GRID_COLOR = "#ffffff" +LEGEND_COLOR = "#efefef" +PLOT_BGCOLOR = "#ededed" +ANNOT_RECT_COLOR = "#d0d0d0" +LEGEND_BORDER_WIDTH = 1 +LEGEND_ANNOT_X = 1.05 +LEGEND_ANNOT_Y = 0.5 +MAX_TICKS_PER_AXIS = 5 +THRES_FOR_FLIPPED_FACET_TITLES = 10 +GRID_WIDTH = 1 + +VALID_TRACE_TYPES = ["scatter", "scattergl", "histogram", "bar", "box"] + +CUSTOM_LABEL_ERROR = ( + "If you are using a dictionary for custom labels for the facet row/col, " + "make sure each key in that column of the dataframe is in your facet " + "labels. 
The keys you need are {}" +) + + +def _is_flipped(num): + if num >= THRES_FOR_FLIPPED_FACET_TITLES: + flipped = True + else: + flipped = False + return flipped + + +def _return_label(original_label, facet_labels, facet_var): + if isinstance(facet_labels, dict): + label = facet_labels[original_label] + elif isinstance(facet_labels, str): + label = "{}: {}".format(facet_var, original_label) + else: + label = original_label + return label + + +def _legend_annotation(color_name): + legend_title = dict( + textangle=0, + xanchor="left", + yanchor="middle", + x=LEGEND_ANNOT_X, + y=1.03, + showarrow=False, + xref="paper", + yref="paper", + text="factor({})".format(color_name), + font=dict(size=13, color="#000000"), + ) + return legend_title + + +def _annotation_dict( + text, lane, num_of_lanes, SUBPLOT_SPACING, row_col="col", flipped=True +): + temp = (1 - (num_of_lanes - 1) * SUBPLOT_SPACING) / (num_of_lanes) + if not flipped: + xanchor = "center" + yanchor = "middle" + if row_col == "col": + x = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp + y = 1.03 + textangle = 0 + elif row_col == "row": + y = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp + x = 1.03 + textangle = 90 + else: + if row_col == "col": + xanchor = "center" + yanchor = "bottom" + x = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp + y = 1.0 + textangle = 270 + elif row_col == "row": + xanchor = "left" + yanchor = "middle" + y = (lane - 1) * (temp + SUBPLOT_SPACING) + 0.5 * temp + x = 1.0 + textangle = 0 + + annotation_dict = dict( + textangle=textangle, + xanchor=xanchor, + yanchor=yanchor, + x=x, + y=y, + showarrow=False, + xref="paper", + yref="paper", + text=str(text), + font=dict(size=13, color=AXIS_TITLE_COLOR), + ) + return annotation_dict + + +def _axis_title_annotation(text, x_or_y_axis): + if x_or_y_axis == "x": + x_pos = 0.5 + y_pos = -0.1 + textangle = 0 + elif x_or_y_axis == "y": + x_pos = -0.1 + y_pos = 0.5 + textangle = 270 + + if not text: + text = "" + + annot = { + "font": 
{"color": "#000000", "size": AXIS_TITLE_SIZE}, + "showarrow": False, + "text": text, + "textangle": textangle, + "x": x_pos, + "xanchor": "center", + "xref": "paper", + "y": y_pos, + "yanchor": "middle", + "yref": "paper", + } + return annot + + +def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False, flipped_cols=False): + shapes_list = [] + for key in fig["layout"].to_plotly_json().keys(): + if "axis" in key and fig["layout"][key]["domain"] != [0.0, 1.0]: + shape = { + "fillcolor": annot_rect_color, + "layer": "below", + "line": {"color": annot_rect_color, "width": 1}, + "type": "rect", + "xref": "paper", + "yref": "paper", + } + + if "xaxis" in key: + shape["x0"] = fig["layout"][key]["domain"][0] + shape["x1"] = fig["layout"][key]["domain"][1] + shape["y0"] = 1.005 + shape["y1"] = 1.05 + + if flipped_cols: + shape["y1"] += 0.5 + shapes_list.append(shape) + + elif "yaxis" in key: + shape["x0"] = 1.005 + shape["x1"] = 1.05 + shape["y0"] = fig["layout"][key]["domain"][0] + shape["y1"] = fig["layout"][key]["domain"][1] + + if flipped_rows: + shape["x1"] += 1 + shapes_list.append(shape) + + fig["layout"]["shapes"] = shapes_list + + +def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker): + if trace_type in ["scatter", "scattergl"]: + trace["mode"] = "markers" + trace["marker"] = dict(color=color, **kwargs_marker) + return trace + + +def _facet_grid_color_categorical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colormap, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, +): + fig = make_subplots( + rows=num_of_rows, + cols=num_of_cols, + shared_xaxes=True, + shared_yaxes=True, + horizontal_spacing=SUBPLOT_SPACING, + vertical_spacing=SUBPLOT_SPACING, + print_grid=False, + ) + + annotations = [] + if not facet_row and not facet_col: + color_groups = list(df.groupby(color_name)) + for 
group in color_groups: + trace = dict( + type=trace_type, + name=group[0], + marker=dict(color=colormap[group[0]]), + **kwargs_trace, + ) + if x: + trace["x"] = group[1][x] + if y: + trace["y"] = group[1][y] + trace = _make_trace_for_scatter( + trace, trace_type, colormap[group[0]], **kwargs_marker + ) + + fig.append_trace(trace, 1, 1) + + elif (facet_row and not facet_col) or (not facet_row and facet_col): + groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col)) + for j, group in enumerate(groups_by_facet): + for color_val in df[color_name].unique(): + data_by_color = group[1][group[1][color_name] == color_val] + trace = dict( + type=trace_type, + name=color_val, + marker=dict(color=colormap[color_val]), + **kwargs_trace, + ) + if x: + trace["x"] = data_by_color[x] + if y: + trace["y"] = data_by_color[y] + trace = _make_trace_for_scatter( + trace, trace_type, colormap[color_val], **kwargs_marker + ) + + fig.append_trace( + trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1 + ) + + label = _return_label( + group[0], + facet_row_labels if facet_row else facet_col_labels, + facet_row if facet_row else facet_col, + ) + + annotations.append( + _annotation_dict( + label, + num_of_rows - j if facet_row else j + 1, + num_of_rows if facet_row else num_of_cols, + SUBPLOT_SPACING, + "row" if facet_row else "col", + flipped_rows, + ) + ) + + elif facet_row and facet_col: + groups_by_facets = list(df.groupby([facet_row, facet_col])) + tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets} + + row_values = df[facet_row].unique() + col_values = df[facet_col].unique() + color_vals = df[color_name].unique() + for row_count, x_val in enumerate(row_values): + for col_count, y_val in enumerate(col_values): + try: + group = tuple_to_facet_group[(x_val, y_val)] + except KeyError: + group = pd.DataFrame( + [[None, None, None]], columns=[x, y, color_name] + ) + + for color_val in color_vals: + if group.values.tolist() != [[None, None, None]]: + 
def _facet_grid_color_numerical(
    df,
    x,
    y,
    facet_row,
    facet_col,
    color_name,
    colormap,
    num_of_rows,
    num_of_cols,
    facet_row_labels,
    facet_col_labels,
    trace_type,
    flipped_rows,
    flipped_cols,
    show_boxes,
    SUBPLOT_SPACING,
    marker_color,
    kwargs_trace,
    kwargs_marker,
):
    """
    Build the subplot grid for a facet grid coloured by a numerical column.

    One trace is added per facet cell, with the marker colour taken from
    df[color_name] and mapped through `colormap`. Facet-strip annotations
    are collected alongside. `show_boxes` and `marker_color` are accepted
    but not used by this function.

    :return: tuple (fig, annotations) - the figure returned by
        make_subplots with the traces appended, and the list of annotation
        dicts for the facet row/column titles.
    """
    fig = make_subplots(
        rows=num_of_rows,
        cols=num_of_cols,
        shared_xaxes=True,
        shared_yaxes=True,
        horizontal_spacing=SUBPLOT_SPACING,
        vertical_spacing=SUBPLOT_SPACING,
        print_grid=False,
    )

    annotations = []
    if not facet_row and not facet_col:
        # No faceting: a single trace in the only cell.
        trace = dict(
            type=trace_type,
            marker=dict(color=df[color_name], colorscale=colormap, showscale=True),
            **kwargs_trace,
        )
        if x:
            trace["x"] = df[x]
        if y:
            trace["y"] = df[y]
        trace = _make_trace_for_scatter(
            trace, trace_type, df[color_name], **kwargs_marker
        )

        fig.append_trace(trace, 1, 1)

    elif (facet_row and not facet_col) or (not facet_row and facet_col):
        # Exactly one facet variable: one trace per group along that axis.
        groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col))
        for j, group in enumerate(groups_by_facet):
            # NOTE(review): x/y come from the group but the marker colour is
            # the whole df[color_name] column - confirm this is intended.
            trace = dict(
                type=trace_type,
                marker=dict(
                    color=df[color_name],
                    colorscale=colormap,
                    showscale=True,
                    colorbar=dict(x=1.15),
                ),
                **kwargs_trace,
            )
            if x:
                trace["x"] = group[1][x]
            if y:
                trace["y"] = group[1][y]
            trace = _make_trace_for_scatter(
                trace, trace_type, df[color_name], **kwargs_marker
            )

            fig.append_trace(
                trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1
            )

            labels = facet_row_labels if facet_row else facet_col_labels
            label = _return_label(
                group[0], labels, facet_row if facet_row else facet_col
            )

            annotations.append(
                _annotation_dict(
                    label,
                    # Row lanes are numbered in reverse (num_of_rows - j).
                    num_of_rows - j if facet_row else j + 1,
                    num_of_rows if facet_row else num_of_cols,
                    SUBPLOT_SPACING,
                    "row" if facet_row else "col",
                    flipped=flipped_rows,
                )
            )

    elif facet_row and facet_col:
        groups_by_facets = list(df.groupby([facet_row, facet_col]))
        tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets}

        row_values = df[facet_row].unique()
        col_values = df[facet_col].unique()
        for row_count, x_val in enumerate(row_values):
            for col_count, y_val in enumerate(col_values):
                try:
                    group = tuple_to_facet_group[(x_val, y_val)]
                except KeyError:
                    # No data for this (row, col) pair: use a one-row
                    # all-None frame as a placeholder for the cell.
                    group = pd.DataFrame(
                        [[None, None, None]], columns=[x, y, color_name]
                    )

                if group.values.tolist() != [[None, None, None]]:
                    trace = dict(
                        type=trace_type,
                        marker=dict(
                            color=df[color_name],
                            colorscale=colormap,
                            # Only first-row traces show the colour scale.
                            showscale=(row_count == 0),
                            colorbar=dict(x=1.15),
                        ),
                        **kwargs_trace,
                    )

                else:
                    trace = dict(type=trace_type, showlegend=False, **kwargs_trace)

                if x:
                    trace["x"] = group[x]
                if y:
                    trace["y"] = group[y]
                trace = _make_trace_for_scatter(
                    trace, trace_type, df[color_name], **kwargs_marker
                )

                fig.append_trace(trace, row_count + 1, col_count + 1)
                if row_count == 0:
                    # Column titles are emitted once, while walking row 0.
                    label = _return_label(
                        col_values[col_count], facet_col_labels, facet_col
                    )
                    annotations.append(
                        _annotation_dict(
                            label,
                            col_count + 1,
                            num_of_cols,
                            SUBPLOT_SPACING,
                            row_col="col",
                            flipped=flipped_cols,
                        )
                    )
            label = _return_label(row_values[row_count], facet_row_labels, facet_row)
            annotations.append(
                _annotation_dict(
                    # Use the resolved label; passing the raw row value here
                    # made facet_row_labels a no-op for this function.
                    label,
                    num_of_rows - row_count,
                    num_of_rows,
                    SUBPLOT_SPACING,
                    row_col="row",
                    flipped=flipped_rows,
                )
            )

    return fig, annotations
_return_label( + group[0], + facet_row_labels if facet_row else facet_col_labels, + facet_row if facet_row else facet_col, + ) + + annotations.append( + _annotation_dict( + label, + num_of_rows - j if facet_row else j + 1, + num_of_rows if facet_row else num_of_cols, + SUBPLOT_SPACING, + "row" if facet_row else "col", + flipped_rows, + ) + ) + + elif facet_row and facet_col: + groups_by_facets = list(df.groupby([facet_row, facet_col])) + tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets} + + row_values = df[facet_row].unique() + col_values = df[facet_col].unique() + for row_count, x_val in enumerate(row_values): + for col_count, y_val in enumerate(col_values): + try: + group = tuple_to_facet_group[(x_val, y_val)] + except KeyError: + group = pd.DataFrame([[None, None]], columns=[x, y]) + trace = dict( + type=trace_type, + marker=dict(color=marker_color, line=kwargs_marker["line"]), + **kwargs_trace, + ) + if x: + trace["x"] = group[x] + if y: + trace["y"] = group[y] + trace = _make_trace_for_scatter( + trace, trace_type, marker_color, **kwargs_marker + ) + + fig.append_trace(trace, row_count + 1, col_count + 1) + if row_count == 0: + label = _return_label( + col_values[col_count], facet_col_labels, facet_col + ) + annotations.append( + _annotation_dict( + label, + col_count + 1, + num_of_cols, + SUBPLOT_SPACING, + row_col="col", + flipped=flipped_cols, + ) + ) + + label = _return_label(row_values[row_count], facet_row_labels, facet_row) + annotations.append( + _annotation_dict( + label, + num_of_rows - row_count, + num_of_rows, + SUBPLOT_SPACING, + row_col="row", + flipped=flipped_rows, + ) + ) + + return fig, annotations + + +def create_facet_grid( + df, + x=None, + y=None, + facet_row=None, + facet_col=None, + color_name=None, + colormap=None, + color_is_cat=False, + facet_row_labels=None, + facet_col_labels=None, + height=None, + width=None, + trace_type="scatter", + scales="fixed", + dtick_x=None, + dtick_y=None, + show_boxes=True, + 
ggplot2=False, + binsize=1, + **kwargs, +): + """ + Returns figure for facet grid; **this function is deprecated**, since + plotly.express functions should be used instead, for example + + >>> import plotly.express as px + >>> tips = px.data.tips() + >>> fig = px.scatter(tips, + ... x='total_bill', + ... y='tip', + ... facet_row='sex', + ... facet_col='smoker', + ... color='size') + + + :param (pd.DataFrame) df: the dataframe of columns for the facet grid. + :param (str) x: the name of the dataframe column for the x axis data. + :param (str) y: the name of the dataframe column for the y axis data. + :param (str) facet_row: the name of the dataframe column that is used to + facet the grid into row panels. + :param (str) facet_col: the name of the dataframe column that is used to + facet the grid into column panels. + :param (str) color_name: the name of your dataframe column that will + function as the colormap variable. + :param (str|list|dict) colormap: the param that determines how the + color_name column colors the data. If the dataframe contains numeric + data, then a dictionary of colors will group the data categorically + while a Plotly Colorscale name or a custom colorscale will treat it + numerically. To learn more about colors and types of colormap, run + `help(plotly.colors)`. + :param (bool) color_is_cat: determines whether a numerical column for the + colormap will be treated as categorical (True) or sequential (False). + Default = False. + :param (str|dict) facet_row_labels: set to either 'name' or a dictionary + of all the unique values in the faceting row mapped to some text to + show up in the label annotations. If None, labeling works like usual. + :param (str|dict) facet_col_labels: set to either 'name' or a dictionary + of all the values in the faceting row mapped to some text to show up + in the label annotations. If None, labeling works like usual. + :param (int) height: the height of the facet grid figure. 
+ :param (int) width: the width of the facet grid figure. + :param (str) trace_type: decides the type of plot to appear in the + facet grid. The options are 'scatter', 'scattergl', 'histogram', + 'bar', and 'box'. + Default = 'scatter'. + :param (str) scales: determines if axes have fixed ranges or not. Valid + settings are 'fixed' (all axes fixed), 'free_x' (x axis free only), + 'free_y' (y axis free only) or 'free' (both axes free). + :param (float) dtick_x: determines the distance between each tick on the + x-axis. Default is None which means dtick_x is set automatically. + :param (float) dtick_y: determines the distance between each tick on the + y-axis. Default is None which means dtick_y is set automatically. + :param (bool) show_boxes: draws grey boxes behind the facet titles. + :param (bool) ggplot2: draws the facet grid in the style of `ggplot2`. See + http://ggplot2.tidyverse.org/reference/facet_grid.html for reference. + Default = False + :param (int) binsize: groups all data into bins of a given length. + :param (dict) kwargs: a dictionary of scatterplot arguments. + + Examples 1: One Way Faceting + + >>> import plotly.figure_factory as ff + >>> import pandas as pd + >>> mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt') + + >>> fig = ff.create_facet_grid( + ... mpg, + ... x='displ', + ... y='cty', + ... facet_col='cyl', + ... ) + >>> fig.show() + + Example 2: Two Way Faceting + + >>> import plotly.figure_factory as ff + + >>> import pandas as pd + + >>> mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt') + + >>> fig = ff.create_facet_grid( + ... mpg, + ... x='displ', + ... y='cty', + ... facet_row='drv', + ... facet_col='cyl', + ... 
) + >>> fig.show() + + Example 3: Categorical Coloring + + >>> import plotly.figure_factory as ff + >>> import pandas as pd + >>> mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv') + >>> mtcars.cyl = mtcars.cyl.astype(str) + >>> fig = ff.create_facet_grid( + ... mtcars, + ... x='mpg', + ... y='wt', + ... facet_col='cyl', + ... color_name='cyl', + ... color_is_cat=True, + ... ) + >>> fig.show() + + + """ + if not pd: + raise ImportError("'pandas' must be installed for this figure_factory.") + + if not isinstance(df, pd.DataFrame): + raise exceptions.PlotlyError("You must input a pandas DataFrame.") + + # make sure all columns are of homogenous datatype + utils.validate_dataframe(df) + + if trace_type in ["scatter", "scattergl"]: + if not x or not y: + raise exceptions.PlotlyError( + "You need to input 'x' and 'y' if you are you are using a " + "trace_type of 'scatter' or 'scattergl'." + ) + + for key in [x, y, facet_row, facet_col, color_name]: + if key is not None: + try: + df[key] + except KeyError: + raise exceptions.PlotlyError( + "x, y, facet_row, facet_col and color_name must be keys " + "in your dataframe." + ) + # autoscale histogram bars + if trace_type not in ["scatter", "scattergl"]: + scales = "free" + + # validate scales + if scales not in ["fixed", "free_x", "free_y", "free"]: + raise exceptions.PlotlyError( + "'scales' must be set to 'fixed', 'free_x', 'free_y' and 'free'." 
+ ) + + if trace_type not in VALID_TRACE_TYPES: + raise exceptions.PlotlyError( + "'trace_type' must be in {}".format(VALID_TRACE_TYPES) + ) + + if trace_type == "histogram": + SUBPLOT_SPACING = 0.06 + else: + SUBPLOT_SPACING = 0.015 + + # seperate kwargs for marker and else + if "marker" in kwargs: + kwargs_marker = kwargs["marker"] + else: + kwargs_marker = {} + marker_color = kwargs_marker.pop("color", None) + kwargs.pop("marker", None) + kwargs_trace = kwargs + + if "size" not in kwargs_marker: + if ggplot2: + kwargs_marker["size"] = 5 + else: + kwargs_marker["size"] = 8 + + if "opacity" not in kwargs_marker: + if not ggplot2: + kwargs_trace["opacity"] = 0.6 + + if "line" not in kwargs_marker: + if not ggplot2: + kwargs_marker["line"] = {"color": "darkgrey", "width": 1} + else: + kwargs_marker["line"] = {} + + # default marker size + if not ggplot2: + if not marker_color: + marker_color = "rgb(31, 119, 180)" + else: + marker_color = "rgb(0, 0, 0)" + + num_of_rows = 1 + num_of_cols = 1 + flipped_rows = False + flipped_cols = False + if facet_row: + num_of_rows = len(df[facet_row].unique()) + flipped_rows = _is_flipped(num_of_rows) + if isinstance(facet_row_labels, dict): + for key in df[facet_row].unique(): + if key not in facet_row_labels.keys(): + unique_keys = df[facet_row].unique().tolist() + raise exceptions.PlotlyError(CUSTOM_LABEL_ERROR.format(unique_keys)) + if facet_col: + num_of_cols = len(df[facet_col].unique()) + flipped_cols = _is_flipped(num_of_cols) + if isinstance(facet_col_labels, dict): + for key in df[facet_col].unique(): + if key not in facet_col_labels.keys(): + unique_keys = df[facet_col].unique().tolist() + raise exceptions.PlotlyError(CUSTOM_LABEL_ERROR.format(unique_keys)) + show_legend = False + if color_name: + if isinstance(df[color_name].iloc[0], str) or color_is_cat: + show_legend = True + if isinstance(colormap, dict): + clrs.validate_colors_dict(colormap, "rgb") + + for val in df[color_name].unique(): + if val not in 
colormap.keys(): + raise exceptions.PlotlyError( + "If using 'colormap' as a dictionary, make sure " + "all the values of the colormap column are in " + "the keys of your dictionary." + ) + else: + # use default plotly colors for dictionary + default_colors = clrs.DEFAULT_PLOTLY_COLORS + colormap = {} + j = 0 + for val in df[color_name].unique(): + if j >= len(default_colors): + j = 0 + colormap[val] = default_colors[j] + j += 1 + fig, annotations = _facet_grid_color_categorical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colormap, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + + elif isinstance(df[color_name].iloc[0], Number): + if isinstance(colormap, dict): + show_legend = True + clrs.validate_colors_dict(colormap, "rgb") + + for val in df[color_name].unique(): + if val not in colormap.keys(): + raise exceptions.PlotlyError( + "If using 'colormap' as a dictionary, make sure " + "all the values of the colormap column are in " + "the keys of your dictionary." 
+ ) + fig, annotations = _facet_grid_color_categorical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colormap, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + + elif isinstance(colormap, list): + colorscale_list = colormap + clrs.validate_colorscale(colorscale_list) + + fig, annotations = _facet_grid_color_numerical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colorscale_list, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + elif isinstance(colormap, str): + if colormap in clrs.PLOTLY_SCALES.keys(): + colorscale_list = clrs.PLOTLY_SCALES[colormap] + else: + raise exceptions.PlotlyError( + "If 'colormap' is a string, it must be the name " + "of a Plotly Colorscale. 
The available colorscale " + "names are {}".format(clrs.PLOTLY_SCALES.keys()) + ) + fig, annotations = _facet_grid_color_numerical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colorscale_list, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + else: + colorscale_list = clrs.PLOTLY_SCALES["Reds"] + fig, annotations = _facet_grid_color_numerical( + df, + x, + y, + facet_row, + facet_col, + color_name, + colorscale_list, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + + else: + fig, annotations = _facet_grid( + df, + x, + y, + facet_row, + facet_col, + num_of_rows, + num_of_cols, + facet_row_labels, + facet_col_labels, + trace_type, + flipped_rows, + flipped_cols, + show_boxes, + SUBPLOT_SPACING, + marker_color, + kwargs_trace, + kwargs_marker, + ) + + if not height: + height = max(600, 100 * num_of_rows) + if not width: + width = max(600, 100 * num_of_cols) + + fig["layout"].update( + height=height, width=width, title="", paper_bgcolor="rgb(251, 251, 251)" + ) + if ggplot2: + fig["layout"].update( + plot_bgcolor=PLOT_BGCOLOR, + paper_bgcolor="rgb(255, 255, 255)", + hovermode="closest", + ) + + # axis titles + x_title_annot = _axis_title_annotation(x, "x") + y_title_annot = _axis_title_annotation(y, "y") + + # annotations + annotations.append(x_title_annot) + annotations.append(y_title_annot) + + # legend + fig["layout"]["showlegend"] = show_legend + fig["layout"]["legend"]["bgcolor"] = LEGEND_COLOR + fig["layout"]["legend"]["borderwidth"] = LEGEND_BORDER_WIDTH + fig["layout"]["legend"]["x"] = 1.05 + fig["layout"]["legend"]["y"] = 1 + fig["layout"]["legend"]["yanchor"] = "top" + + if show_legend: + fig["layout"]["showlegend"] = show_legend + if 
ggplot2: + if color_name: + legend_annot = _legend_annotation(color_name) + annotations.append(legend_annot) + fig["layout"]["margin"]["r"] = 150 + + # assign annotations to figure + fig["layout"]["annotations"] = annotations + + # add shaded boxes behind axis titles + if show_boxes and ggplot2: + _add_shapes_to_fig(fig, ANNOT_RECT_COLOR, flipped_rows, flipped_cols) + + # all xaxis and yaxis labels + axis_labels = {"x": [], "y": []} + for key in fig["layout"]: + if "xaxis" in key: + axis_labels["x"].append(key) + elif "yaxis" in key: + axis_labels["y"].append(key) + + string_number_in_data = False + for var in [v for v in [x, y] if v]: + if isinstance(df[var].tolist()[0], str): + for item in df[var]: + try: + int(item) + string_number_in_data = True + except ValueError: + pass + + if string_number_in_data: + for x_y in axis_labels.keys(): + for axis_name in axis_labels[x_y]: + fig["layout"][axis_name]["type"] = "category" + + if scales == "fixed": + fixed_axes = ["x", "y"] + elif scales == "free_x": + fixed_axes = ["y"] + elif scales == "free_y": + fixed_axes = ["x"] + elif scales == "free": + fixed_axes = [] + + # fixed ranges + for x_y in fixed_axes: + min_ranges = [] + max_ranges = [] + for trace in fig["data"]: + if trace[x_y] is not None and len(trace[x_y]) > 0: + min_ranges.append(min(trace[x_y])) + max_ranges.append(max(trace[x_y])) + while None in min_ranges: + min_ranges.remove(None) + while None in max_ranges: + max_ranges.remove(None) + + min_range = min(min_ranges) + max_range = max(max_ranges) + + range_are_numbers = isinstance(min_range, Number) and isinstance( + max_range, Number + ) + + if range_are_numbers: + min_range = math.floor(min_range) + max_range = math.ceil(max_range) + + # extend widen frame by 5% on each side + min_range -= 0.05 * (max_range - min_range) + max_range += 0.05 * (max_range - min_range) + + if x_y == "x": + if dtick_x: + dtick = dtick_x + else: + dtick = math.floor((max_range - min_range) / MAX_TICKS_PER_AXIS) + elif x_y == 
"y": + if dtick_y: + dtick = dtick_y + else: + dtick = math.floor((max_range - min_range) / MAX_TICKS_PER_AXIS) + else: + dtick = 1 + + for axis_title in axis_labels[x_y]: + fig["layout"][axis_title]["dtick"] = dtick + fig["layout"][axis_title]["ticklen"] = 0 + fig["layout"][axis_title]["zeroline"] = False + if ggplot2: + fig["layout"][axis_title]["tickwidth"] = 1 + fig["layout"][axis_title]["ticklen"] = 4 + fig["layout"][axis_title]["gridwidth"] = GRID_WIDTH + + fig["layout"][axis_title]["gridcolor"] = GRID_COLOR + fig["layout"][axis_title]["gridwidth"] = 2 + fig["layout"][axis_title]["tickfont"] = { + "color": TICK_COLOR, + "size": 10, + } + + # insert ranges into fig + if x_y in fixed_axes: + for key in fig["layout"]: + if "{}axis".format(x_y) in key and range_are_numbers: + fig["layout"][key]["range"] = [min_range, max_range] + + return fig diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py new file mode 100644 index 0000000..2fe393f --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_gantt.py @@ -0,0 +1,1034 @@ +from numbers import Number + +import copy + +from plotly import exceptions, optional_imports +import plotly.colors as clrs +from plotly.figure_factory import utils +import plotly.graph_objects as go + +pd = optional_imports.get_module("pandas") + +REQUIRED_GANTT_KEYS = ["Task", "Start", "Finish"] + + +def _get_corner_points(x0, y0, x1, y1): + """ + Returns the corner points of a scatter rectangle + + :param x0: x-start + :param y0: y-lower + :param x1: x-end + :param y1: y-upper + :return: ([x], [y]), tuple of lists containing the x and y values + """ + + return ([x0, x1, x1, x0], [y0, y0, y1, y1]) + + +def validate_gantt(df): + """ + Validates the inputted dataframe or list + """ + if pd and isinstance(df, pd.core.frame.DataFrame): + # validate that df has all the required keys + for key in REQUIRED_GANTT_KEYS: + if key not in df: + 
raise exceptions.PlotlyError( + "The columns in your dataframe must include the " + "following keys: {0}".format(", ".join(REQUIRED_GANTT_KEYS)) + ) + + num_of_rows = len(df.index) + chart = [] + for index in range(num_of_rows): + task_dict = {} + for key in df: + task_dict[key] = df.iloc[index][key] + chart.append(task_dict) + + return chart + + # validate if df is a list + if not isinstance(df, list): + raise exceptions.PlotlyError( + "You must input either a dataframe or a list of dictionaries." + ) + + # validate if df is empty + if len(df) <= 0: + raise exceptions.PlotlyError( + "Your list is empty. It must contain at least one dictionary." + ) + if not isinstance(df[0], dict): + raise exceptions.PlotlyError("Your list must only include dictionaries.") + return df + + +def gantt( + chart, + colors, + title, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=False, + show_hover_fill=True, + show_colorbar=True, +): + """ + Refer to create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + + for index in range(len(chart)): + task = dict( + x0=chart[index]["Start"], + x1=chart[index]["Finish"], + name=chart[index]["Task"], + ) + if "Description" in chart[index]: + task["description"] = chart[index]["Description"] + tasks.append(task) + + # create a scatter trace for every task group + scatter_data_dict = dict() + marker_data_dict = dict() + + if show_hover_fill: + hoverinfo = "name" + else: + hoverinfo = "skip" + + scatter_data_template = { + "x": [], + "y": [], + "mode": "none", + "fill": "toself", + "hoverinfo": hoverinfo, + } + + marker_data_template = { + "x": [], + "y": [], + "mode": "markers", + "text": [], + "marker": dict(color="", size=1, opacity=0), + "name": "", + "showlegend": False, + } + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]["name"] + # Is added to 
task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + color_index = 0 + for index in range(len(tasks)): + tn = tasks[index]["name"] + del tasks[index]["name"] + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]["y0"] = groupID - bar_width + tasks[index]["y1"] = groupID + bar_width + + # check if colors need to be looped + if color_index >= len(colors): + color_index = 0 + tasks[index]["fillcolor"] = colors[color_index] + color_id = tasks[index]["fillcolor"] + + if color_id not in scatter_data_dict: + scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template) + + scatter_data_dict[color_id]["fillcolor"] = color_id + scatter_data_dict[color_id]["name"] = str(tn) + scatter_data_dict[color_id]["legendgroup"] = color_id + + # if there are already values append the gap + if len(scatter_data_dict[color_id]["x"]) > 0: + # a gap on the scatterplot separates the rectangles from each other + scatter_data_dict[color_id]["x"].append( + scatter_data_dict[color_id]["x"][-1] + ) + scatter_data_dict[color_id]["y"].append(None) + + xs, ys = _get_corner_points( + tasks[index]["x0"], + tasks[index]["y0"], + tasks[index]["x1"], + tasks[index]["y1"], + ) + + scatter_data_dict[color_id]["x"] += xs + scatter_data_dict[color_id]["y"] += ys + + # append dummy markers for showing start and end of interval + if color_id not in marker_data_dict: + marker_data_dict[color_id] = copy.deepcopy(marker_data_template) + marker_data_dict[color_id]["marker"]["color"] = color_id + marker_data_dict[color_id]["legendgroup"] = color_id + + 
marker_data_dict[color_id]["x"].append(tasks[index]["x0"]) + marker_data_dict[color_id]["x"].append(tasks[index]["x1"]) + marker_data_dict[color_id]["y"].append(groupID) + marker_data_dict[color_id]["y"].append(groupID) + + if "description" in tasks[index]: + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + del tasks[index]["description"] + else: + marker_data_dict[color_id]["text"].append(None) + marker_data_dict[color_id]["text"].append(None) + + color_index += 1 + + showlegend = show_colorbar + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode="closest", + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list( + [ + dict(count=7, label="1w", step="day", stepmode="backward"), + dict(count=1, label="1m", step="month", stepmode="backward"), + dict(count=6, label="6m", step="month", stepmode="backward"), + dict(count=1, label="YTD", step="year", stepmode="todate"), + dict(count=1, label="1y", step="year", stepmode="backward"), + dict(step="all"), + ] + ) + ), + type="date", + ), + ) + + data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)] + data += [marker_data_dict[k] for k in sorted(marker_data_dict)] + + # fig = dict( + # data=data, layout=layout + # ) + fig = go.Figure(data=data, layout=layout) + return fig + + +def gantt_colorscale( + chart, + colors, + title, + index_col, + show_colorbar, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=False, + show_hover_fill=True, +): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] 
+ if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict( + x0=chart[index]["Start"], + x1=chart[index]["Finish"], + name=chart[index]["Task"], + ) + if "Description" in chart[index]: + task["description"] = chart[index]["Description"] + tasks.append(task) + + # create a scatter trace for every task group + scatter_data_dict = dict() + # create scatter traces for the start- and endpoints + marker_data_dict = dict() + + if show_hover_fill: + hoverinfo = "name" + else: + hoverinfo = "skip" + + scatter_data_template = { + "x": [], + "y": [], + "mode": "none", + "fill": "toself", + "showlegend": False, + "hoverinfo": hoverinfo, + "legendgroup": "", + } + + marker_data_template = { + "x": [], + "y": [], + "mode": "markers", + "text": [], + "marker": dict(color="", size=1, opacity=0), + "name": "", + "showlegend": False, + "legendgroup": "", + } + + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + # compute the color for task based on indexing column + if isinstance(chart[0][index_col], Number): + # check that colors has at least 2 colors + if len(colors) < 2: + raise exceptions.PlotlyError( + "You must use at least 2 colors in 'colors' if you " + "are using a colorscale. However only the first two " + "colors given will be used for the lower and upper " + "bounds on the colormap." 
+ ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]["name"] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]["name"] + del tasks[index]["name"] + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]["y0"] = groupID - bar_width + tasks[index]["y1"] = groupID + bar_width + + # unlabel color + colors = clrs.color_parser(colors, clrs.unlabel_rgb) + lowcolor = colors[0] + highcolor = colors[1] + + intermed = (chart[index][index_col]) / 100.0 + intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed) + intermed_color = clrs.color_parser(intermed_color, clrs.label_rgb) + tasks[index]["fillcolor"] = intermed_color + color_id = tasks[index]["fillcolor"] + + if color_id not in scatter_data_dict: + scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template) + + scatter_data_dict[color_id]["fillcolor"] = color_id + scatter_data_dict[color_id]["name"] = str(chart[index][index_col]) + scatter_data_dict[color_id]["legendgroup"] = color_id + + # relabel colors with 'rgb' + colors = clrs.color_parser(colors, clrs.label_rgb) + + # if there are already values append the gap + if len(scatter_data_dict[color_id]["x"]) > 0: + # a gap on the scatterplot separates the rectangles from each other + scatter_data_dict[color_id]["x"].append( + scatter_data_dict[color_id]["x"][-1] + ) + scatter_data_dict[color_id]["y"].append(None) + + xs, ys = _get_corner_points( + tasks[index]["x0"], + tasks[index]["y0"], + tasks[index]["x1"], + 
tasks[index]["y1"], + ) + + scatter_data_dict[color_id]["x"] += xs + scatter_data_dict[color_id]["y"] += ys + + # append dummy markers for showing start and end of interval + if color_id not in marker_data_dict: + marker_data_dict[color_id] = copy.deepcopy(marker_data_template) + marker_data_dict[color_id]["marker"]["color"] = color_id + marker_data_dict[color_id]["legendgroup"] = color_id + + marker_data_dict[color_id]["x"].append(tasks[index]["x0"]) + marker_data_dict[color_id]["x"].append(tasks[index]["x1"]) + marker_data_dict[color_id]["y"].append(groupID) + marker_data_dict[color_id]["y"].append(groupID) + + if "description" in tasks[index]: + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + del tasks[index]["description"] + else: + marker_data_dict[color_id]["text"].append(None) + marker_data_dict[color_id]["text"].append(None) + + # add colorbar to one of the traces randomly just for display + if show_colorbar is True: + k = list(marker_data_dict.keys())[0] + marker_data_dict[k]["marker"].update( + dict( + colorscale=[[0, colors[0]], [1, colors[1]]], + showscale=True, + cmax=100, + cmin=0, + ) + ) + + if isinstance(chart[0][index_col], str): + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + if len(colors) < len(index_vals): + raise exceptions.PlotlyError( + "Error. The number of colors in 'colors' must be no less " + "than the number of unique index values in your group " + "column." 
+ ) + + # make a dictionary assignment to each index value + index_vals_dict = {} + # define color index + c_index = 0 + for key in index_vals: + if c_index > len(colors) - 1: + c_index = 0 + index_vals_dict[key] = colors[c_index] + c_index += 1 + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]["name"] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]["name"] + del tasks[index]["name"] + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]["y0"] = groupID - bar_width + tasks[index]["y1"] = groupID + bar_width + + tasks[index]["fillcolor"] = index_vals_dict[chart[index][index_col]] + color_id = tasks[index]["fillcolor"] + + if color_id not in scatter_data_dict: + scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template) + + scatter_data_dict[color_id]["fillcolor"] = color_id + scatter_data_dict[color_id]["legendgroup"] = color_id + scatter_data_dict[color_id]["name"] = str(chart[index][index_col]) + + # relabel colors with 'rgb' + colors = clrs.color_parser(colors, clrs.label_rgb) + + # if there are already values append the gap + if len(scatter_data_dict[color_id]["x"]) > 0: + # a gap on the scatterplot separates the rectangles from each other + scatter_data_dict[color_id]["x"].append( + scatter_data_dict[color_id]["x"][-1] + ) + scatter_data_dict[color_id]["y"].append(None) + + xs, ys = _get_corner_points( + tasks[index]["x0"], + tasks[index]["y0"], + tasks[index]["x1"], + tasks[index]["y1"], + ) + + scatter_data_dict[color_id]["x"] += xs 
+ scatter_data_dict[color_id]["y"] += ys + + # append dummy markers for showing start and end of interval + if color_id not in marker_data_dict: + marker_data_dict[color_id] = copy.deepcopy(marker_data_template) + marker_data_dict[color_id]["marker"]["color"] = color_id + marker_data_dict[color_id]["legendgroup"] = color_id + + marker_data_dict[color_id]["x"].append(tasks[index]["x0"]) + marker_data_dict[color_id]["x"].append(tasks[index]["x1"]) + marker_data_dict[color_id]["y"].append(groupID) + marker_data_dict[color_id]["y"].append(groupID) + + if "description" in tasks[index]: + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + del tasks[index]["description"] + else: + marker_data_dict[color_id]["text"].append(None) + marker_data_dict[color_id]["text"].append(None) + + if show_colorbar is True: + showlegend = True + for k in scatter_data_dict: + scatter_data_dict[k]["showlegend"] = showlegend + # add colorbar to one of the traces randomly just for display + # if show_colorbar is True: + # k = list(marker_data_dict.keys())[0] + # marker_data_dict[k]["marker"].update( + # dict( + # colorscale=[[0, colors[0]], [1, colors[1]]], + # showscale=True, + # cmax=100, + # cmin=0, + # ) + # ) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode="closest", + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list( + [ + dict(count=7, label="1w", step="day", stepmode="backward"), + dict(count=1, label="1m", step="month", stepmode="backward"), + dict(count=6, label="6m", step="month", stepmode="backward"), + dict(count=1, label="YTD", step="year", stepmode="todate"), + dict(count=1, label="1y", step="year", 
stepmode="backward"), + dict(step="all"), + ] + ) + ), + type="date", + ), + ) + + data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)] + data += [marker_data_dict[k] for k in sorted(marker_data_dict)] + + # fig = dict( + # data=data, layout=layout + # ) + fig = go.Figure(data=data, layout=layout) + return fig + + +def gantt_dict( + chart, + colors, + title, + index_col, + show_colorbar, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=False, + show_hover_fill=True, +): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict( + x0=chart[index]["Start"], + x1=chart[index]["Finish"], + name=chart[index]["Task"], + ) + if "Description" in chart[index]: + task["description"] = chart[index]["Description"] + tasks.append(task) + + # create a scatter trace for every task group + scatter_data_dict = dict() + # create scatter traces for the start- and endpoints + marker_data_dict = dict() + + if show_hover_fill: + hoverinfo = "name" + else: + hoverinfo = "skip" + + scatter_data_template = { + "x": [], + "y": [], + "mode": "none", + "fill": "toself", + "hoverinfo": hoverinfo, + "legendgroup": "", + } + + marker_data_template = { + "x": [], + "y": [], + "mode": "markers", + "text": [], + "marker": dict(color="", size=1, opacity=0), + "name": "", + "showlegend": False, + } + + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + # verify each value in index column appears in colors dictionary + for key in index_vals: + if key not in colors: + raise exceptions.PlotlyError( + "If you are using colors as a dictionary, all of its " + "keys must be all the values in the index column." 
+ ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]["name"] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]["name"] + del tasks[index]["name"] + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]["y0"] = groupID - bar_width + tasks[index]["y1"] = groupID + bar_width + + tasks[index]["fillcolor"] = colors[chart[index][index_col]] + color_id = tasks[index]["fillcolor"] + + if color_id not in scatter_data_dict: + scatter_data_dict[color_id] = copy.deepcopy(scatter_data_template) + + scatter_data_dict[color_id]["legendgroup"] = color_id + scatter_data_dict[color_id]["fillcolor"] = color_id + + # if there are already values append the gap + if len(scatter_data_dict[color_id]["x"]) > 0: + # a gap on the scatterplot separates the rectangles from each other + scatter_data_dict[color_id]["x"].append( + scatter_data_dict[color_id]["x"][-1] + ) + scatter_data_dict[color_id]["y"].append(None) + + xs, ys = _get_corner_points( + tasks[index]["x0"], + tasks[index]["y0"], + tasks[index]["x1"], + tasks[index]["y1"], + ) + + scatter_data_dict[color_id]["x"] += xs + scatter_data_dict[color_id]["y"] += ys + + # append dummy markers for showing start and end of interval + if color_id not in marker_data_dict: + marker_data_dict[color_id] = copy.deepcopy(marker_data_template) + marker_data_dict[color_id]["marker"]["color"] = color_id + marker_data_dict[color_id]["legendgroup"] = color_id + + marker_data_dict[color_id]["x"].append(tasks[index]["x0"]) + 
marker_data_dict[color_id]["x"].append(tasks[index]["x1"]) + marker_data_dict[color_id]["y"].append(groupID) + marker_data_dict[color_id]["y"].append(groupID) + + if "description" in tasks[index]: + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + marker_data_dict[color_id]["text"].append(tasks[index]["description"]) + del tasks[index]["description"] + else: + marker_data_dict[color_id]["text"].append(None) + marker_data_dict[color_id]["text"].append(None) + + if show_colorbar is True: + showlegend = True + + for index_value in index_vals: + scatter_data_dict[colors[index_value]]["name"] = str(index_value) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode="closest", + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list( + [ + dict(count=7, label="1w", step="day", stepmode="backward"), + dict(count=1, label="1m", step="month", stepmode="backward"), + dict(count=6, label="6m", step="month", stepmode="backward"), + dict(count=1, label="YTD", step="year", stepmode="todate"), + dict(count=1, label="1y", step="year", stepmode="backward"), + dict(step="all"), + ] + ) + ), + type="date", + ), + ) + + data = [scatter_data_dict[k] for k in sorted(scatter_data_dict)] + data += [marker_data_dict[k] for k in sorted(marker_data_dict)] + + # fig = dict( + # data=data, layout=layout + # ) + fig = go.Figure(data=data, layout=layout) + return fig + + +def create_gantt( + df, + colors=None, + index_col=None, + show_colorbar=False, + reverse_colors=False, + title="Gantt Chart", + bar_width=0.2, + showgrid_x=False, + showgrid_y=False, + height=600, + width=None, + tasks=None, + task_names=None, + data=None, + group_tasks=False, + show_hover_fill=True, +): + """ + **deprecated**, use 
instead + :func:`plotly.express.timeline`. + + Returns figure for a gantt chart + + :param (array|list) df: input data for gantt chart. Must be either a + a dataframe or a list. If dataframe, the columns must include + 'Task', 'Start' and 'Finish'. Other columns can be included and + used for indexing. If a list, its elements must be dictionaries + with the same required column headers: 'Task', 'Start' and + 'Finish'. + :param (str|list|dict|tuple) colors: either a plotly scale name, an + rgb or hex color, a color tuple or a list of colors. An rgb color + is of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colors is a list, it must + contain the valid color types aforementioned as its members. + If a dictionary, all values of the indexing column must be keys in + colors. + :param (str|float) index_col: the column header (if df is a data + frame) that will function as the indexing column. If df is a list, + index_col must be one of the keys in all the items of df. + :param (bool) show_colorbar: determines if colorbar will be visible. + Only applies if values in the index column are numeric. + :param (bool) show_hover_fill: enables/disables the hovertext for the + filled area of the chart. + :param (bool) reverse_colors: reverses the order of selected colors + :param (str) title: the title of the chart + :param (float) bar_width: the width of the horizontal bars in the plot + :param (bool) showgrid_x: show/hide the x-axis grid + :param (bool) showgrid_y: show/hide the y-axis grid + :param (float) height: the height of the chart + :param (float) width: the width of the chart + + Example 1: Simple Gantt Chart + + >>> from plotly.figure_factory import create_gantt + + >>> # Make data for chart + >>> df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'), + ... dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'), + ... 
dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')] + + >>> # Create a figure + >>> fig = create_gantt(df) + >>> fig.show() + + + Example 2: Index by Column with Numerical Entries + + >>> from plotly.figure_factory import create_gantt + + >>> # Make data for chart + >>> df = [dict(Task="Job A", Start='2009-01-01', + ... Finish='2009-02-30', Complete=10), + ... dict(Task="Job B", Start='2009-03-05', + ... Finish='2009-04-15', Complete=60), + ... dict(Task="Job C", Start='2009-02-20', + ... Finish='2009-05-30', Complete=95)] + + >>> # Create a figure with Plotly colorscale + >>> fig = create_gantt(df, colors='Blues', index_col='Complete', + ... show_colorbar=True, bar_width=0.5, + ... showgrid_x=True, showgrid_y=True) + >>> fig.show() + + + Example 3: Index by Column with String Entries + + >>> from plotly.figure_factory import create_gantt + + >>> # Make data for chart + >>> df = [dict(Task="Job A", Start='2009-01-01', + ... Finish='2009-02-30', Resource='Apple'), + ... dict(Task="Job B", Start='2009-03-05', + ... Finish='2009-04-15', Resource='Grape'), + ... dict(Task="Job C", Start='2009-02-20', + ... Finish='2009-05-30', Resource='Banana')] + + >>> # Create a figure with Plotly colorscale + >>> fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'], + ... index_col='Resource', reverse_colors=True, + ... show_colorbar=True) + >>> fig.show() + + + Example 4: Use a dictionary for colors + + >>> from plotly.figure_factory import create_gantt + >>> # Make data for chart + >>> df = [dict(Task="Job A", Start='2009-01-01', + ... Finish='2009-02-30', Resource='Apple'), + ... dict(Task="Job B", Start='2009-03-05', + ... Finish='2009-04-15', Resource='Grape'), + ... dict(Task="Job C", Start='2009-02-20', + ... Finish='2009-05-30', Resource='Banana')] + + >>> # Make a dictionary of colors + >>> colors = {'Apple': 'rgb(255, 0, 0)', + ... 'Grape': 'rgb(170, 14, 200)', + ... 
'Banana': (1, 1, 0.2)} + + >>> # Create a figure with Plotly colorscale + >>> fig = create_gantt(df, colors=colors, index_col='Resource', + ... show_colorbar=True) + + >>> fig.show() + + Example 5: Use a pandas dataframe + + >>> from plotly.figure_factory import create_gantt + >>> import pandas as pd + + >>> # Make data as a dataframe + >>> df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10], + ... ['Fast', '2011-01-01', '2012-06-05', 55], + ... ['Eat', '2012-01-05', '2013-07-05', 94]], + ... columns=['Task', 'Start', 'Finish', 'Complete']) + + >>> # Create a figure with Plotly colorscale + >>> fig = create_gantt(df, colors='Blues', index_col='Complete', + ... show_colorbar=True, bar_width=0.5, + ... showgrid_x=True, showgrid_y=True) + >>> fig.show() + """ + # validate gantt input data + chart = validate_gantt(df) + + if index_col: + if index_col not in chart[0]: + raise exceptions.PlotlyError( + "In order to use an indexing column and assign colors to " + "the values of the index, you must choose an actual " + "column name in the dataframe or key if a list of " + "dictionaries is being used." + ) + + # validate gantt index column + index_list = [] + for dictionary in chart: + index_list.append(dictionary[index_col]) + utils.validate_index(index_list) + + # Validate colors + if isinstance(colors, dict): + colors = clrs.validate_colors_dict(colors, "rgb") + else: + colors = clrs.validate_colors(colors, "rgb") + + if reverse_colors is True: + colors.reverse() + + if not index_col: + if isinstance(colors, dict): + raise exceptions.PlotlyError( + "Error. You have set colors to a dictionary but have not " + "picked an index. An index is required if you are " + "assigning colors to particular values in a dictionary." 
+ ) + fig = gantt( + chart, + colors, + title, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=group_tasks, + show_hover_fill=show_hover_fill, + show_colorbar=show_colorbar, + ) + return fig + else: + if not isinstance(colors, dict): + fig = gantt_colorscale( + chart, + colors, + title, + index_col, + show_colorbar, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=group_tasks, + show_hover_fill=show_hover_fill, + ) + return fig + else: + fig = gantt_dict( + chart, + colors, + title, + index_col, + show_colorbar, + bar_width, + showgrid_x, + showgrid_y, + height, + width, + tasks=None, + task_names=None, + data=None, + group_tasks=group_tasks, + show_hover_fill=show_hover_fill, + ) + return fig diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py new file mode 100644 index 0000000..c763522 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_hexbin_mapbox.py @@ -0,0 +1,526 @@ +from plotly.express._core import build_dataframe +from plotly.express._doc import make_docstring +from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox +import narwhals.stable.v1 as nw +import numpy as np + + +def _project_latlon_to_wgs84(lat, lon): + """ + Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map + """ + x = lon * np.pi / 180 + y = np.arctanh(np.sin(lat * np.pi / 180)) + return x, y + + +def _project_wgs84_to_latlon(x, y): + """ + Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map + """ + lon = x * 180 / np.pi + lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi + return lat, lon + + +def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim): + """ + Get the mapbox zoom level given bounds and a figure dimension + Source: 
https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds + """ + + scale = ( + 2 # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256 + ) + WORLD_DIM = {"height": 256 * scale, "width": 256 * scale} + ZOOM_MAX = 18 + + def latRad(lat): + sin = np.sin(lat * np.pi / 180) + radX2 = np.log((1 + sin) / (1 - sin)) / 2 + return max(min(radX2, np.pi), -np.pi) / 2 + + def zoom(mapPx, worldPx, fraction): + return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2) + + latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi + + lngDiff = lon_max - lon_min + lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360 + + latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction) + lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction) + + return min(latZoom, lngZoom, ZOOM_MAX) + + +def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count): + """ + Computes the aggregation at hexagonal bin level. + Also defines the coordinates of the hexagons for plotting. + The binning is inspired by matplotlib's implementation. 
+ + Parameters + ---------- + x : np.ndarray + Array of x values (shape N) + y : np.ndarray + Array of y values (shape N) + x_range : np.ndarray + Min and max x (shape 2) + y_range : np.ndarray + Min and max y (shape 2) + color : np.ndarray + Metric to aggregate at hexagon level (shape N) + nx : int + Number of hexagons horizontally + agg_func : function + Numpy compatible aggregator, this function must take a one-dimensional + np.ndarray as input and output a scalar + min_count : int + Minimum number of points in the hexagon for the hexagon to be displayed + + Returns + ------- + np.ndarray + X coordinates of each hexagon (shape M x 6) + np.ndarray + Y coordinates of each hexagon (shape M x 6) + np.ndarray + Centers of the hexagons (shape M x 2) + np.ndarray + Aggregated value in each hexagon (shape M) + + """ + xmin = x_range.min() + xmax = x_range.max() + ymin = y_range.min() + ymax = y_range.max() + + # In the x-direction, the hexagons exactly cover the region from + # xmin to xmax. Need some padding to avoid roundoff errors. 
+ padding = 1.0e-9 * (xmax - xmin) + xmin -= padding + xmax += padding + + Dx = xmax - xmin + Dy = ymax - ymin + if Dx == 0 and Dy > 0: + dx = Dy / nx + elif Dx == 0 and Dy == 0: + dx, _ = _project_latlon_to_wgs84(1, 1) + else: + dx = Dx / nx + dy = dx * np.sqrt(3) + ny = np.ceil(Dy / dy).astype(int) + + # Center the hexagons vertically since we only want regular hexagons + ymin -= (ymin + dy * ny - ymax) / 2 + + x = (x - xmin) / dx + y = (y - ymin) / dy + ix1 = np.round(x).astype(int) + iy1 = np.round(y).astype(int) + ix2 = np.floor(x).astype(int) + iy2 = np.floor(y).astype(int) + + nx1 = nx + 1 + ny1 = ny + 1 + nx2 = nx + ny2 = ny + n = nx1 * ny1 + nx2 * ny2 + + d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2 + d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2 + bdist = d1 < d2 + + if color is None: + lattice1 = np.zeros((nx1, ny1)) + lattice2 = np.zeros((nx2, ny2)) + c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist + c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist + np.add.at(lattice1, (ix1[c1], iy1[c1]), 1) + np.add.at(lattice2, (ix2[c2], iy2[c2]), 1) + if min_count is not None: + lattice1[lattice1 < min_count] = np.nan + lattice2[lattice2 < min_count] = np.nan + accum = np.concatenate([lattice1.ravel(), lattice2.ravel()]) + good_idxs = ~np.isnan(accum) + else: + if min_count is None: + min_count = 1 + + # create accumulation arrays + lattice1 = np.empty((nx1, ny1), dtype=object) + for i in range(nx1): + for j in range(ny1): + lattice1[i, j] = [] + lattice2 = np.empty((nx2, ny2), dtype=object) + for i in range(nx2): + for j in range(ny2): + lattice2[i, j] = [] + + for i in range(len(x)): + if bdist[i]: + if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1: + lattice1[ix1[i], iy1[i]].append(color[i]) + else: + if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2: + lattice2[ix2[i], iy2[i]].append(color[i]) + + for i in range(nx1): + for j in range(ny1): + vals = lattice1[i, j] + if len(vals) >= min_count: + lattice1[i, j] = agg_func(vals) + else: 
+ lattice1[i, j] = np.nan + for i in range(nx2): + for j in range(ny2): + vals = lattice2[i, j] + if len(vals) >= min_count: + lattice2[i, j] = agg_func(vals) + else: + lattice2[i, j] = np.nan + + accum = np.hstack( + (lattice1.astype(float).ravel(), lattice2.astype(float).ravel()) + ) + good_idxs = ~np.isnan(accum) + + agreggated_value = accum[good_idxs] + + centers = np.zeros((n, 2), float) + centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1) + centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1) + centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2) + centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5 + centers[:, 0] *= dx + centers[:, 1] *= dy + centers[:, 0] += xmin + centers[:, 1] += ymin + centers = centers[good_idxs] + + # Define normalised regular hexagon coordinates + hx = [0, 0.5, 0.5, 0, -0.5, -0.5] + hy = [ + -0.5 / np.cos(np.pi / 6), + -0.5 * np.tan(np.pi / 6), + 0.5 * np.tan(np.pi / 6), + 0.5 / np.cos(np.pi / 6), + 0.5 * np.tan(np.pi / 6), + -0.5 * np.tan(np.pi / 6), + ] + + # Number of hexagons needed + m = len(centers) + + # Coordinates for all hexagonal patches + hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0]) + hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1]) + + return hxs, hys, centers, agreggated_value + + +def _compute_wgs84_hexbin( + lat=None, + lon=None, + lat_range=None, + lon_range=None, + color=None, + nx=None, + agg_func=None, + min_count=None, + native_namespace=None, +): + """ + Computes the lat-lon aggregation at hexagonal bin level. + Latitude and longitude need to be projected to WGS84 before aggregating + in order to display regular hexagons on the map. 
+ + Parameters + ---------- + lat : np.ndarray + Array of latitudes (shape N) + lon : np.ndarray + Array of longitudes (shape N) + lat_range : np.ndarray + Min and max latitudes (shape 2) + lon_range : np.ndarray + Min and max longitudes (shape 2) + color : np.ndarray + Metric to aggregate at hexagon level (shape N) + nx : int + Number of hexagons horizontally + agg_func : function + Numpy compatible aggregator, this function must take a one-dimensional + np.ndarray as input and output a scalar + min_count : int + Minimum number of points in the hexagon for the hexagon to be displayed + + Returns + ------- + np.ndarray + Lat coordinates of each hexagon (shape M x 6) + np.ndarray + Lon coordinates of each hexagon (shape M x 6) + nw.Series + Unique id for each hexagon, to be used in the geojson data (shape M) + np.ndarray + Aggregated value in each hexagon (shape M) + + """ + # Project to WGS 84 + x, y = _project_latlon_to_wgs84(lat, lon) + + if lat_range is None: + lat_range = np.array([lat.min(), lat.max()]) + if lon_range is None: + lon_range = np.array([lon.min(), lon.max()]) + + x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range) + + hxs, hys, centers, agreggated_value = _compute_hexbin( + x, y, x_range, y_range, color, nx, agg_func, min_count + ) + + # Convert back to lat-lon + hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys) + + # Create unique feature id based on hexagon center + centers = centers.astype(str) + hexagons_ids = ( + nw.from_dict( + {"x1": centers[:, 0], "x2": centers[:, 1]}, + native_namespace=native_namespace, + ) + .select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=",")) + .get_column("hexagons_ids") + ) + + return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value + + +def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None): + """ + Creates a geojson of hexagonal features based on the outputs of + _compute_wgs84_hexbin + """ + features = [] + if ids is None: + ids = 
np.arange(len(hexagons_lats)) + for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids): + points = np.array([lon, lat]).T.tolist() + points.append(points[0]) + features.append( + dict( + type="Feature", + id=idx, + geometry=dict(type="Polygon", coordinates=[points]), + ) + ) + return dict(type="FeatureCollection", features=features) + + +def create_hexbin_mapbox( + data_frame=None, + lat=None, + lon=None, + color=None, + nx_hexagon=5, + agg_func=None, + animation_frame=None, + color_discrete_sequence=None, + color_discrete_map={}, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=None, + center=None, + mapbox_style=None, + title=None, + template=None, + width=None, + height=None, + min_count=None, + show_original_data=False, + original_data_marker=None, +): + """ + Returns a figure aggregating scattered points into connected hexagons + """ + args = build_dataframe(args=locals(), constructor=None) + native_namespace = nw.get_native_namespace(args["data_frame"]) + if agg_func is None: + agg_func = np.mean + + lat_range = ( + args["data_frame"] + .select( + nw.min(args["lat"]).name.suffix("_min"), + nw.max(args["lat"]).name.suffix("_max"), + ) + .to_numpy() + .squeeze() + ) + + lon_range = ( + args["data_frame"] + .select( + nw.min(args["lon"]).name.suffix("_min"), + nw.max(args["lon"]).name.suffix("_max"), + ) + .to_numpy() + .squeeze() + ) + + hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin( + lat=args["data_frame"].get_column(args["lat"]).to_numpy(), + lon=args["data_frame"].get_column(args["lon"]).to_numpy(), + lat_range=lat_range, + lon_range=lon_range, + color=None, + nx=nx_hexagon, + agg_func=agg_func, + min_count=min_count, + native_namespace=native_namespace, + ) + + geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids) + + if zoom is None: + if height is None and width is None: + mapDim = dict(height=450, width=450) + elif height is None 
and width is not None: + mapDim = dict(height=450, width=width) + elif height is not None and width is None: + mapDim = dict(height=height, width=height) + else: + mapDim = dict(height=height, width=width) + zoom = _getBoundsZoomLevel( + lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim + ) + + if center is None: + center = dict(lat=lat_range.mean(), lon=lon_range.mean()) + + if args["animation_frame"] is not None: + groups = dict( + args["data_frame"] + .group_by(args["animation_frame"], drop_null_keys=True) + .__iter__() + ) + else: + groups = {(0,): args["data_frame"]} + + agg_data_frame_list = [] + for key, df in groups.items(): + _, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin( + lat=df.get_column(args["lat"]).to_numpy(), + lon=df.get_column(args["lon"]).to_numpy(), + lat_range=lat_range, + lon_range=lon_range, + color=df.get_column(args["color"]).to_numpy() if args["color"] else None, + nx=nx_hexagon, + agg_func=agg_func, + min_count=min_count, + native_namespace=native_namespace, + ) + agg_data_frame_list.append( + nw.from_dict( + { + "frame": [key[0]] * len(hexagons_ids), + "locations": hexagons_ids, + "color": aggregated_value, + }, + native_namespace=native_namespace, + ) + ) + + agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns( + color=nw.col("color").cast(nw.Int64) + ) + + if range_color is None: + range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()] + + fig = choropleth_mapbox( + data_frame=agg_data_frame.to_native(), + geojson=geojson, + locations="locations", + color="color", + hover_data={"color": True, "locations": False, "frame": False}, + animation_frame=("frame" if args["animation_frame"] is not None else None), + color_discrete_sequence=color_discrete_sequence, + color_discrete_map=color_discrete_map, + labels=labels, + color_continuous_scale=color_continuous_scale, + range_color=range_color, + color_continuous_midpoint=color_continuous_midpoint, + opacity=opacity, + 
zoom=zoom, + center=center, + mapbox_style=mapbox_style, + title=title, + template=template, + width=width, + height=height, + ) + + if show_original_data: + original_fig = scatter_mapbox( + data_frame=( + args["data_frame"].sort( + by=args["animation_frame"], descending=False, nulls_last=True + ) + if args["animation_frame"] is not None + else args["data_frame"] + ).to_native(), + lat=args["lat"], + lon=args["lon"], + animation_frame=args["animation_frame"], + ) + original_fig.data[0].hoverinfo = "skip" + original_fig.data[0].hovertemplate = None + original_fig.data[0].marker = original_data_marker + + fig.add_trace(original_fig.data[0]) + + if args["animation_frame"] is not None: + for i in range(len(original_fig.frames)): + original_fig.frames[i].data[0].hoverinfo = "skip" + original_fig.frames[i].data[0].hovertemplate = None + original_fig.frames[i].data[0].marker = original_data_marker + + fig.frames[i].data = [ + fig.frames[i].data[0], + original_fig.frames[i].data[0], + ] + + return fig + + +create_hexbin_mapbox.__doc__ = make_docstring( + create_hexbin_mapbox, + override_dict=dict( + nx_hexagon=["int", "Number of hexagons (horizontally) to be created"], + agg_func=[ + "function", + "Numpy array aggregator, it must take as input a 1D array", + "and output a scalar value.", + ], + min_count=[ + "int", + "Minimum number of points in a hexagon for it to be displayed.", + "If None and color is not set, display all hexagons.", + "If None and color is set, only display hexagons that contain points.", + ], + show_original_data=[ + "bool", + "Whether to show the original data on top of the hexbin aggregation.", + ], + original_data_marker=["dict", "Scattermapbox marker options."], + ), +) diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py new file mode 100644 index 0000000..a0b1b48 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ohlc.py @@ -0,0 
+1,295 @@ +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +# Default colours for finance charts +_DEFAULT_INCREASING_COLOR = "#3D9970" # http://clrs.cc +_DEFAULT_DECREASING_COLOR = "#FF4136" + + +def validate_ohlc(open, high, low, close, direction, **kwargs): + """ + ohlc and candlestick specific validations + + Specifically, this checks that the high value is the greatest value and + the low value is the lowest value in each unit. + + See FigureFactory.create_ohlc() or FigureFactory.create_candlestick() + for params + + :raises: (PlotlyError) If the high value is not the greatest value in + each unit. + :raises: (PlotlyError) If the low value is not the lowest value in each + unit. + :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing' + """ + for lst in [open, low, close]: + for index in range(len(high)): + if high[index] < lst[index]: + raise exceptions.PlotlyError( + "Oops! Looks like some of " + "your high values are less " + "the corresponding open, " + "low, or close values. " + "Double check that your data " + "is entered in O-H-L-C order" + ) + + for lst in [open, high, close]: + for index in range(len(low)): + if low[index] > lst[index]: + raise exceptions.PlotlyError( + "Oops! Looks like some of " + "your low values are greater " + "than the corresponding high" + ", open, or close values. 
" + "Double check that your data " + "is entered in O-H-L-C order" + ) + + direction_opts = ("increasing", "decreasing", "both") + if direction not in direction_opts: + raise exceptions.PlotlyError( + "direction must be defined as 'increasing', 'decreasing', or 'both'" + ) + + +def make_increasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes increasing ohlc sticks + + _make_increasing_ohlc() and _make_decreasing_ohlc separate the + increasing trace from the decreasing trace so kwargs (such as + color) can be passed separately to increasing or decreasing traces + when direction is set to 'increasing' or 'decreasing' in + FigureFactory.create_candlestick() + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc + sticks. + """ + (flat_increase_x, flat_increase_y, text_increase) = _OHLC( + open, high, low, close, dates + ).get_increase() + + if "name" in kwargs: + showlegend = True + else: + kwargs.setdefault("name", "Increasing") + showlegend = False + + kwargs.setdefault("line", dict(color=_DEFAULT_INCREASING_COLOR, width=1)) + kwargs.setdefault("text", text_increase) + + ohlc_incr = dict( + type="scatter", + x=flat_increase_x, + y=flat_increase_y, + mode="lines", + showlegend=showlegend, + **kwargs, + ) + return ohlc_incr + + +def make_decreasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes decreasing ohlc sticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. 
+ + :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc + sticks. + """ + (flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC( + open, high, low, close, dates + ).get_decrease() + + kwargs.setdefault("line", dict(color=_DEFAULT_DECREASING_COLOR, width=1)) + kwargs.setdefault("text", text_decrease) + kwargs.setdefault("showlegend", False) + kwargs.setdefault("name", "Decreasing") + + ohlc_decr = dict( + type="scatter", x=flat_decrease_x, y=flat_decrease_y, mode="lines", **kwargs + ) + return ohlc_decr + + +def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs): + """ + **deprecated**, use instead the plotly.graph_objects trace + :class:`plotly.graph_objects.Ohlc` + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all units where the close value is greater than the + corresponding open value, and when the direction is 'decreasing', + the returned figure consists of all units where the close value is + less than or equal to the corresponding open value. When the + direction is 'both', both increasing and decreasing units are + returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of an ohlc chart figure. 
+ + Example 1: Simple OHLC chart from a Pandas DataFrame + + >>> from plotly.figure_factory import create_ohlc + >>> from datetime import datetime + + >>> import pandas as pd + >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv') + >>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index) + >>> fig.show() + """ + if dates is not None: + utils.validate_equal_length(open, high, low, close, dates) + else: + utils.validate_equal_length(open, high, low, close) + validate_ohlc(open, high, low, close, direction, **kwargs) + + if direction == "increasing": + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs) + data = [ohlc_incr] + elif direction == "decreasing": + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs) + data = [ohlc_decr] + else: + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, **kwargs) + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, **kwargs) + data = [ohlc_incr, ohlc_decr] + + layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest") + + return graph_objs.Figure(data=data, layout=layout) + + +class _OHLC(object): + """ + Refer to FigureFactory.create_ohlc_increase() for docstring. + """ + + def __init__(self, open, high, low, close, dates, **kwargs): + self.open = open + self.high = high + self.low = low + self.close = close + self.empty = [None] * len(open) + self.dates = dates + + self.all_x = [] + self.all_y = [] + self.increase_x = [] + self.increase_y = [] + self.decrease_x = [] + self.decrease_y = [] + self.get_all_xy() + self.separate_increase_decrease() + + def get_all_xy(self): + """ + Zip data to create OHLC shape + + OHLC shape: low to high vertical bar with + horizontal branches for open and close values. + If dates were added, the smallest date difference is calculated and + multiplied by .2 to get the length of the open and close branches. 
+ If no date data was provided, the x-axis is a list of integers and the + length of the open and close branches is .2. + """ + self.all_y = list( + zip( + self.open, + self.open, + self.high, + self.low, + self.close, + self.close, + self.empty, + ) + ) + if self.dates is not None: + date_dif = [] + for i in range(len(self.dates) - 1): + date_dif.append(self.dates[i + 1] - self.dates[i]) + date_dif_min = (min(date_dif)) / 5 + self.all_x = [ + [x - date_dif_min, x, x, x, x, x + date_dif_min, None] + for x in self.dates + ] + else: + self.all_x = [ + [x - 0.2, x, x, x, x, x + 0.2, None] for x in range(len(self.open)) + ] + + def separate_increase_decrease(self): + """ + Separate data into two groups: increase and decrease + + (1) Increase, where close > open and + (2) Decrease, where close <= open + """ + for index in range(len(self.open)): + if self.close[index] is None: + pass + elif self.close[index] > self.open[index]: + self.increase_x.append(self.all_x[index]) + self.increase_y.append(self.all_y[index]) + else: + self.decrease_x.append(self.all_x[index]) + self.decrease_y.append(self.all_y[index]) + + def get_increase(self): + """ + Flatten increase data and get increase text + + :rtype (list, list, list): flat_increase_x: x-values for the increasing + trace, flat_increase_y: y=values for the increasing trace and + text_increase: hovertext for the increasing trace + """ + flat_increase_x = utils.flatten(self.increase_x) + flat_increase_y = utils.flatten(self.increase_y) + text_increase = ("Open", "Open", "High", "Low", "Close", "Close", "") * ( + len(self.increase_x) + ) + + return flat_increase_x, flat_increase_y, text_increase + + def get_decrease(self): + """ + Flatten decrease data and get decrease text + + :rtype (list, list, list): flat_decrease_x: x-values for the decreasing + trace, flat_decrease_y: y=values for the decreasing trace and + text_decrease: hovertext for the decreasing trace + """ + flat_decrease_x = utils.flatten(self.decrease_x) + 
flat_decrease_y = utils.flatten(self.decrease_y) + text_decrease = ("Open", "Open", "High", "Low", "Close", "Close", "") * ( + len(self.decrease_x) + ) + + return flat_decrease_x, flat_decrease_y, text_decrease diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py new file mode 100644 index 0000000..fa18222 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_quiver.py @@ -0,0 +1,265 @@ +import math + +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +def create_quiver( + x, y, u, v, scale=0.1, arrow_scale=0.3, angle=math.pi / 9, scaleratio=None, **kwargs +): + """ + Returns data for a quiver plot. + + :param (list|ndarray) x: x coordinates of the arrow locations + :param (list|ndarray) y: y coordinates of the arrow locations + :param (list|ndarray) u: x components of the arrow vectors + :param (list|ndarray) v: y components of the arrow vectors + :param (float in [0,1]) scale: scales size of the arrows(ideally to + avoid overlap). Default = .1 + :param (float in [0,1]) arrow_scale: value multiplied to length of barb + to get length of arrowhead. Default = .3 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (positive float) scaleratio: the ratio between the scale of the y-axis + and the scale of the x-axis (scale_y / scale_x). Default = None, the + scale ratio is not fixed. + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of quiver figure. 
+ + Example 1: Trivial Quiver + + >>> from plotly.figure_factory import create_quiver + >>> import math + + >>> # 1 Arrow from (0,0) to (1,1) + >>> fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1) + >>> fig.show() + + + Example 2: Quiver plot using meshgrid + + >>> from plotly.figure_factory import create_quiver + + >>> import numpy as np + >>> import math + + >>> # Add data + >>> x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2)) + >>> u = np.cos(x)*y + >>> v = np.sin(x)*y + + >>> #Create quiver + >>> fig = create_quiver(x, y, u, v) + >>> fig.show() + + + Example 3: Styling the quiver plot + + >>> from plotly.figure_factory import create_quiver + >>> import numpy as np + >>> import math + + >>> # Add data + >>> x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5), + ... np.arange(-math.pi, math.pi, .5)) + >>> u = np.cos(x)*y + >>> v = np.sin(x)*y + + >>> # Create quiver + >>> fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6, + ... name='Wind Velocity', line=dict(width=1)) + + >>> # Add title to layout + >>> fig.update_layout(title='Quiver Plot') # doctest: +SKIP + >>> fig.show() + + + Example 4: Forcing a fix scale ratio to maintain the arrow length + + >>> from plotly.figure_factory import create_quiver + >>> import numpy as np + + >>> # Add data + >>> x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5)) + >>> u = x + >>> v = y + >>> angle = np.arctan(v / u) + >>> norm = 0.25 + >>> u = norm * np.cos(angle) + >>> v = norm * np.sin(angle) + + >>> # Create quiver with a fix scale ratio + >>> fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5) + >>> fig.show() + """ + utils.validate_equal_length(x, y, u, v) + utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale) + + if scaleratio is None: + quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle) + else: + quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio) + + barb_x, barb_y = quiver_obj.get_barbs() + arrow_x, 
arrow_y = quiver_obj.get_quiver_arrows() + + quiver_plot = graph_objs.Scatter( + x=barb_x + arrow_x, y=barb_y + arrow_y, mode="lines", **kwargs + ) + + data = [quiver_plot] + + if scaleratio is None: + layout = graph_objs.Layout(hovermode="closest") + else: + layout = graph_objs.Layout( + hovermode="closest", yaxis=dict(scaleratio=scaleratio, scaleanchor="x") + ) + + return graph_objs.Figure(data=data, layout=layout) + + +class _Quiver(object): + """ + Refer to FigureFactory.create_quiver() for docstring + """ + + def __init__(self, x, y, u, v, scale, arrow_scale, angle, scaleratio=1, **kwargs): + try: + x = utils.flatten(x) + except exceptions.PlotlyError: + pass + + try: + y = utils.flatten(y) + except exceptions.PlotlyError: + pass + + try: + u = utils.flatten(u) + except exceptions.PlotlyError: + pass + + try: + v = utils.flatten(v) + except exceptions.PlotlyError: + pass + + self.x = x + self.y = y + self.u = u + self.v = v + self.scale = scale + self.scaleratio = scaleratio + self.arrow_scale = arrow_scale + self.angle = angle + self.end_x = [] + self.end_y = [] + self.scale_uv() + barb_x, barb_y = self.get_barbs() + arrow_x, arrow_y = self.get_quiver_arrows() + + def scale_uv(self): + """ + Scales u and v to avoid overlap of the arrows. + + u and v are added to x and y to get the + endpoints of the arrows so a smaller scale value will + result in less overlap of arrows. + """ + self.u = [i * self.scale * self.scaleratio for i in self.u] + self.v = [i * self.scale for i in self.v] + + def get_barbs(self): + """ + Creates x and y startpoint and endpoint pairs + + After finding the endpoint of each barb this zips startpoint and + endpoint pairs to create 2 lists: x_values for barbs and y values + for barbs + + :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint + x_value pairs separated by a None to create the barb of the arrow, + and list of startpoint and endpoint y_value pairs separated by a + None to create the barb of the arrow. 
+ """ + self.end_x = [i + j for i, j in zip(self.x, self.u)] + self.end_y = [i + j for i, j in zip(self.y, self.v)] + empty = [None] * len(self.x) + barb_x = utils.flatten(zip(self.x, self.end_x, empty)) + barb_y = utils.flatten(zip(self.y, self.end_y, empty)) + return barb_x, barb_y + + def get_quiver_arrows(self): + """ + Creates lists of x and y values to plot the arrows + + Gets length of each barb then calculates the length of each side of + the arrow. Gets angle of barb and applies angle to each side of the + arrowhead. Next uses arrow_scale to scale the length of arrowhead and + creates x and y values for arrowhead point1 and point2. Finally x and y + values for point1, endpoint and point2s for each arrowhead are + separated by a None and zipped to create lists of x and y values for + the arrows. + + :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2 + x_values separated by a None to create the arrowhead and list of + point1, endpoint, point2 y_values separated by a None to create + the barb of the arrow. 
+ """ + dif_x = [i - j for i, j in zip(self.end_x, self.x)] + dif_y = [i - j for i, j in zip(self.end_y, self.y)] + + # Get barb lengths(default arrow length = 30% barb length) + barb_len = [None] * len(self.x) + for index in range(len(barb_len)): + barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index]) + + # Make arrow lengths + arrow_len = [None] * len(self.x) + arrow_len = [i * self.arrow_scale for i in barb_len] + + # Get barb angles + barb_ang = [None] * len(self.x) + for index in range(len(barb_ang)): + barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio) + + # Set angles to create arrow + ang1 = [i + self.angle for i in barb_ang] + ang2 = [i - self.angle for i in barb_ang] + + cos_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + cos_ang1[index] = math.cos(ang1[index]) + seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)] + + sin_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + sin_ang1[index] = math.sin(ang1[index]) + seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)] + + cos_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + cos_ang2[index] = math.cos(ang2[index]) + seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)] + + sin_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + sin_ang2[index] = math.sin(ang2[index]) + seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)] + + # Set coordinates to create arrow + for index in range(len(self.end_x)): + point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)] + point1_y = [i - j for i, j in zip(self.end_y, seg1_y)] + point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)] + point2_y = [i - j for i, j in zip(self.end_y, seg2_y)] + + # Combine lists to create arrow + empty = [None] * len(self.end_x) + arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty)) + arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty)) + return arrow_x, arrow_y diff --git 
def endpts_to_intervals(endpts):
    """
    Returns a list of intervals for categorical colormaps

    Accepts a list or tuple of sequentially increasing numbers and returns
    a list representation of the mathematical intervals with these numbers
    as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]

    :param (list|tuple) endpts: strictly increasing numbers
    :rtype (list): one [lower, upper] pair per interval, starting at -inf
        and ending at inf

    :raises: (PlotlyError) If input is not a list or tuple
    :raises: (PlotlyError) If the input is empty
    :raises: (PlotlyError) If the input contains a string
    :raises: (PlotlyError) If any number does not increase after the
        previous one in the sequence
    """
    # Check the container type first, before len() or indexing is attempted,
    # so unsupported inputs get the documented PlotlyError.
    if not isinstance(endpts, (tuple, list)):
        raise exceptions.PlotlyError(
            "The intervals_endpts argument must "
            "be a list or tuple of a sequence "
            "of increasing numbers."
        )

    length = len(endpts)
    # Without at least one endpoint there is nothing to build intervals from
    if length == 0:
        raise exceptions.PlotlyError(
            "The intervals_endpts argument must "
            "be a list or tuple of a sequence "
            "of increasing numbers."
        )

    # Check if endpts contains only numbers
    for item in endpts:
        if isinstance(item, str):
            raise exceptions.PlotlyError(
                "The intervals_endpts argument "
                "must be a list or tuple of a "
                "sequence of increasing "
                "numbers."
            )

    # Check if numbers in endpts are increasing
    for k in range(length - 1):
        if endpts[k] >= endpts[k + 1]:
            raise exceptions.PlotlyError(
                "The intervals_endpts argument "
                "must be a list or tuple of a "
                "sequence of increasing "
                "numbers."
            )

    # -inf up to the first endpoint, each consecutive pair, then last to +inf
    intervals = [[float("-inf"), endpts[0]]]
    for k in range(length - 1):
        intervals.append([endpts[k], endpts[k + 1]])
    intervals.append([endpts[length - 1], float("inf")])
    return intervals
def scatterplot(dataframe, headers, diag, size, height, width, title, **kwargs):
    """
    Refer to FigureFactory.create_scatterplotmatrix() for docstring

    Returns fig for scatterplotmatrix without index

    One trace is built per (row, column) cell of a dim x dim subplot grid,
    where dim is len(dataframe). A cell whose x and y data compare equal
    gets a histogram or box plot when diag asks for one; every other cell
    gets a scatter trace.

    :param dataframe: sequence of data columns, one per matrix row/column
    :param headers: axis titles, one per column
    :param (str) diag: 'scatter', 'histogram' or 'box'
    :param size: marker size for the scatter traces
    :param height: figure height
    :param width: figure width
    :param title: figure title
    :param kwargs: passed through to plotly.graph_objs.Scatter
    """
    dim = len(dataframe)
    fig = make_subplots(rows=dim, cols=dim, print_grid=False)
    trace_list = []
    # Insert traces into trace_list
    for listy in dataframe:
        for listx in dataframe:
            if (listx == listy) and (diag == "histogram"):
                trace = graph_objs.Histogram(x=listx, showlegend=False)
            elif (listx == listy) and (diag == "box"):
                trace = graph_objs.Box(y=listx, name=None, showlegend=False)
            else:
                if "marker" in kwargs:
                    # NOTE: writes the size into the marker supplied by the caller
                    kwargs["marker"]["size"] = size
                    trace = graph_objs.Scatter(
                        x=listx, y=listy, mode="markers", showlegend=False, **kwargs
                    )
                else:
                    trace = graph_objs.Scatter(
                        x=listx,
                        y=listy,
                        mode="markers",
                        marker=dict(size=size),
                        showlegend=False,
                        **kwargs,
                    )
            # Exactly one trace per cell, so trace_list stays aligned with
            # the row-major walk over the subplot grid below.
            trace_list.append(trace)

    trace_index = 0
    indices = range(1, dim + 1)
    for y_index in indices:
        for x_index in indices:
            fig.append_trace(trace_list[trace_index], y_index, x_index)
            trace_index += 1

    # Insert headers into the figure
    # x titles go on the bottom row of subplots, y titles on the first column
    for j in range(dim):
        xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j)
        fig["layout"][xaxis_key].update(title=headers[j])
    for j in range(dim):
        yaxis_key = "yaxis{}".format(1 + (dim * j))
        fig["layout"][yaxis_key].update(title=headers[j])

    fig["layout"].update(height=height, width=width, title=title, showlegend=True)

    hide_tick_labels_from_box_subplots(fig)

    return fig
showlegend=True + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[name]), + showlegend=True, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + showlegend=True, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + marker=dict(size=size, color=theme[name]), + showlegend=True, + **kwargs, + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict(color=theme[name]), + showlegend=False, + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[name]), + showlegend=False, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + showlegend=False, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + marker=dict(size=size, color=theme[name]), + showlegend=False, + **kwargs, + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace(trace_list[trace_index][name], y_index, x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j) + fig["layout"][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = "yaxis{}".format(1 + (dim * j)) + 
fig["layout"][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == "histogram": + fig["layout"].update( + height=height, width=width, title=title, showlegend=True, barmode="stack" + ) + return fig + + else: + fig["layout"].update(height=height, width=width, title=title, showlegend=True) + return fig + + +def scatterplot_theme( + dataframe, + headers, + diag, + size, + height, + width, + title, + index, + index_vals, + endpts, + colormap, + colormap_type, + **kwargs, +): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix with both index and colormap picked + + """ + + # Check if index is made of string values + if isinstance(index_vals[0], str): + unique_index_vals = [] + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals.append(name) + n_colors_len = len(unique_index_vals) + + # Convert colormap to list of n RGB tuples + if colormap_type == "seq": + foo = clrs.color_parser(colormap, clrs.unlabel_rgb) + foo = clrs.n_colors(foo[0], foo[1], n_colors_len) + theme = clrs.color_parser(foo, clrs.label_rgb) + + if colormap_type == "cat": + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # create a dictionary for index_vals + unique_index_vals = {} + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals[name] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for name in sorted(unique_index_vals.keys()): + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if index_vals[j] == name: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 
"histogram"): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict(color=theme[c_indx]), + showlegend=True, + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[c_indx]), + showlegend=True, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + showlegend=True, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + marker=dict(size=size, color=theme[c_indx]), + showlegend=True, + **kwargs, + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict(color=theme[c_indx]), + showlegend=False, + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[c_indx]), + showlegend=False, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + showlegend=False, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=name, + marker=dict(size=size, color=theme[c_indx]), + showlegend=False, + **kwargs, + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace(trace_list[trace_index][name], y_index, x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 
"xaxis{}".format((dim * dim) - dim + 1 + j) + fig["layout"][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = "yaxis{}".format(1 + (dim * j)) + fig["layout"][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == "histogram": + fig["layout"].update( + height=height, + width=width, + title=title, + showlegend=True, + barmode="stack", + ) + return fig + + elif diag == "box": + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + else: + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + else: + if endpts: + intervals = utils.endpts_to_intervals(endpts) + + # Convert colormap to list of n RGB tuples + if colormap_type == "seq": + foo = clrs.color_parser(colormap, clrs.unlabel_rgb) + foo = clrs.n_colors(foo[0], foo[1], len(intervals)) + theme = clrs.color_parser(foo, clrs.label_rgb) + + if colormap_type == "cat": + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + interval_labels = {} + for interval in intervals: + interval_labels[str(interval)] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for interval in intervals: + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if interval[0] < index_vals[j] <= interval[1]: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict(color=theme[c_indx]), + showlegend=True, + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[c_indx]), + 
showlegend=True, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + (kwargs["marker"]["color"]) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=str(interval), + showlegend=True, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=str(interval), + marker=dict(size=size, color=theme[c_indx]), + showlegend=True, + **kwargs, + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict(color=theme[c_indx]), + showlegend=False, + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict(color=theme[c_indx]), + showlegend=False, + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + (kwargs["marker"]["color"]) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=str(interval), + showlegend=False, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode="markers", + name=str(interval), + marker=dict(size=size, color=theme[c_indx]), + showlegend=False, + **kwargs, + ) + # Push the trace into dictionary + interval_labels[str(interval)] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(interval_labels) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for interval in intervals: + fig.append_trace( + trace_list[trace_index][str(interval)], y_index, x_index + ) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j) + fig["layout"][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = "yaxis{}".format(1 + (dim * j)) + fig["layout"][yaxis_key].update(title=headers[j]) + + 
hide_tick_labels_from_box_subplots(fig) + + if diag == "histogram": + fig["layout"].update( + height=height, + width=width, + title=title, + showlegend=True, + barmode="stack", + ) + return fig + + elif diag == "box": + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + else: + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + else: + theme = colormap + + # add a copy of rgb color to theme if it contains one color + if len(theme) <= 1: + theme.append(theme[0]) + + color = [] + for incr in range(len(theme)): + color.append([1.0 / (len(theme) - 1) * incr, theme[incr]]) + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Run through all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=listx, marker=dict(color=theme[0]), showlegend=False + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + y=listx, marker=dict(color=theme[0]), showlegend=False + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = index_vals + kwargs["marker"]["colorscale"] = color + kwargs["marker"]["showscale"] = True + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode="markers", + showlegend=False, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode="markers", + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=True, + ), + showlegend=False, + **kwargs, + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == "histogram"): + trace = graph_objs.Histogram( + x=listx, marker=dict(color=theme[0]), showlegend=False + ) + elif (listx == listy) and (diag == "box"): + trace = graph_objs.Box( + 
y=listx, marker=dict(color=theme[0]), showlegend=False + ) + else: + if "marker" in kwargs: + kwargs["marker"]["size"] = size + kwargs["marker"]["color"] = index_vals + kwargs["marker"]["colorscale"] = color + kwargs["marker"]["showscale"] = False + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode="markers", + showlegend=False, + **kwargs, + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode="markers", + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=False, + ), + showlegend=False, + **kwargs, + ) + # Push the trace into list + trace_list.append(trace) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + fig.append_trace(trace_list[trace_index], y_index, x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = "xaxis{}".format((dim * dim) - dim + 1 + j) + fig["layout"][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = "yaxis{}".format(1 + (dim * j)) + fig["layout"][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == "histogram": + fig["layout"].update( + height=height, + width=width, + title=title, + showlegend=True, + barmode="stack", + ) + return fig + + elif diag == "box": + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + else: + fig["layout"].update( + height=height, width=width, title=title, showlegend=True + ) + return fig + + +def create_scatterplotmatrix( + df, + index=None, + endpts=None, + diag="scatter", + height=500, + width=500, + size=6, + title="Scatterplot Matrix", + colormap=None, + colormap_type="cat", + dataframe=None, + headers=None, + index_vals=None, + **kwargs, +): + """ + Returns data for a scatterplot matrix; + **deprecated**, + use instead the plotly.graph_objects trace + :class:`plotly.graph_objects.Splom`. 
+ + :param (array) df: array of the data with column headers + :param (str) index: name of the index column in data array + :param (list|tuple) endpts: takes an increasing sequece of numbers + that defines intervals on the real line. They are used to group + the entries in an index of numbers into their corresponding + interval and therefore can be treated as categorical data + :param (str) diag: sets the chart type for the main diagonal plots. + The options are 'scatter', 'histogram' and 'box'. + :param (int|float) height: sets the height of the chart + :param (int|float) width: sets the width of the chart + :param (float) size: sets the marker size (in px) + :param (str) title: the title label of the scatterplot matrix + :param (str|tuple|list|dict) colormap: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colormap is a list, it must contain valid color types as its + members. + If colormap is a dictionary, all the string entries in + the index column must be a key in colormap. In this case, the + colormap_type is forced to 'cat' or categorical + :param (str) colormap_type: determines how colormap is interpreted. + Valid choices are 'seq' (sequential) and 'cat' (categorical). If + 'seq' is selected, only the first two colors in colormap will be + considered (when colormap is a list) and the index values will be + linearly interpolated between those two colors. This option is + forced if all index values are numeric. 
+ If 'cat' is selected, a color from colormap will be assigned to + each category from index, including the intervals if endpts is + being used + :param (dict) **kwargs: a dictionary of scatterplot arguments + The only forbidden parameters are 'size', 'color' and + 'colorscale' in 'marker' + + Example 1: Vanilla Scatterplot Matrix + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + + >>> # Create dataframe + >>> df = pd.DataFrame(np.random.randn(10, 2), + ... columns=['Column 1', 'Column 2']) + + >>> # Create scatterplot matrix + >>> fig = create_scatterplotmatrix(df) + >>> fig.show() + + + Example 2: Indexing a Column + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + + >>> # Create dataframe with index + >>> df = pd.DataFrame(np.random.randn(10, 2), + ... columns=['A', 'B']) + + >>> # Add another column of strings to the dataframe + >>> df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + ... 'grape', 'pear', 'pear', 'apple', 'pear']) + + >>> # Create scatterplot matrix + >>> fig = create_scatterplotmatrix(df, index='Fruit', size=10) + >>> fig.show() + + + Example 3: Styling the Diagonal Subplots + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + + >>> # Create dataframe with index + >>> df = pd.DataFrame(np.random.randn(10, 4), + ... columns=['A', 'B', 'C', 'D']) + + >>> # Add another column of strings to the dataframe + >>> df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + ... 'grape', 'pear', 'pear', 'apple', 'pear']) + + >>> # Create scatterplot matrix + >>> fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000, + ... 
width=1000) + >>> fig.show() + + + Example 4: Use a Theme to Style the Subplots + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + + >>> # Create dataframe with random data + >>> df = pd.DataFrame(np.random.randn(100, 3), + ... columns=['A', 'B', 'C']) + + >>> # Create scatterplot matrix using a built-in + >>> # Plotly palette scale and indexing column 'A' + >>> fig = create_scatterplotmatrix(df, diag='histogram', index='A', + ... colormap='Blues', height=800, width=800) + >>> fig.show() + + + Example 5: Example 4 with Interval Factoring + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + + >>> # Create dataframe with random data + >>> df = pd.DataFrame(np.random.randn(100, 3), + ... columns=['A', 'B', 'C']) + + >>> # Create scatterplot matrix using a list of 2 rgb tuples + >>> # and endpoints at -1, 0 and 1 + >>> fig = create_scatterplotmatrix(df, diag='histogram', index='A', + ... colormap=['rgb(140, 255, 50)', + ... 'rgb(170, 60, 115)', '#6c4774', + ... (0.5, 0.1, 0.8)], + ... endpts=[-1, 0, 1], height=800, width=800) + >>> fig.show() + + + Example 6: Using the colormap as a Dictionary + + >>> from plotly.graph_objs import graph_objs + >>> from plotly.figure_factory import create_scatterplotmatrix + + >>> import numpy as np + >>> import pandas as pd + >>> import random + + >>> # Create dataframe with random data + >>> df = pd.DataFrame(np.random.randn(100, 3), + ... columns=['Column A', + ... 'Column B', + ... 'Column C']) + + >>> # Add new color column to dataframe + >>> new_column = [] + >>> strange_colors = ['turquoise', 'limegreen', 'goldenrod'] + + >>> for j in range(100): + ... 
new_column.append(random.choice(strange_colors)) + >>> df['Colors'] = pd.Series(new_column, index=df.index) + + >>> # Create scatterplot matrix using a dictionary of hex color values + >>> # which correspond to actual color names in 'Colors' column + >>> fig = create_scatterplotmatrix( + ... df, diag='box', index='Colors', + ... colormap= dict( + ... turquoise = '#00F5FF', + ... limegreen = '#32CD32', + ... goldenrod = '#DAA520' + ... ), + ... colormap_type='cat', + ... height=800, width=800 + ... ) + >>> fig.show() + """ + # TODO: protected until #282 + if dataframe is None: + dataframe = [] + if headers is None: + headers = [] + if index_vals is None: + index_vals = [] + + validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs) + + # Validate colormap + if isinstance(colormap, dict): + colormap = clrs.validate_colors_dict(colormap, "rgb") + elif isinstance(colormap, str) and "rgb" not in colormap and "#" not in colormap: + if colormap not in clrs.PLOTLY_SCALES.keys(): + raise exceptions.PlotlyError( + "If 'colormap' is a string, it must be the name " + "of a Plotly Colorscale. 
The available colorscale " + "names are {}".format(clrs.PLOTLY_SCALES.keys()) + ) + else: + # TODO change below to allow the correct Plotly colorscale + colormap = clrs.colorscale_to_colors(clrs.PLOTLY_SCALES[colormap]) + # keep only first and last item - fix later + colormap = [colormap[0]] + [colormap[-1]] + colormap = clrs.validate_colors(colormap, "rgb") + else: + colormap = clrs.validate_colors(colormap, "rgb") + + if not index: + for name in df: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + # Check for same data-type in df columns + utils.validate_dataframe(dataframe) + figure = scatterplot( + dataframe, headers, diag, size, height, width, title, **kwargs + ) + return figure + else: + # Validate index selection + if index not in df: + raise exceptions.PlotlyError( + "Make sure you set the index " + "input variable to one of the " + "column names of your " + "dataframe." + ) + index_vals = df[index].values.tolist() + for name in df: + if name != index: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + + # check for same data-type in each df column + utils.validate_dataframe(dataframe) + utils.validate_index(index_vals) + + # check if all colormap keys are in the index + # if colormap is a dictionary + if isinstance(colormap, dict): + for key in colormap: + if not all(index in colormap for index in index_vals): + raise exceptions.PlotlyError( + "If colormap is a " + "dictionary, all the " + "names in the index " + "must be keys." 
def validate_streamline(x, y):
    """
    Streamline-specific validations

    Specifically, this checks that x and y are both evenly spaced,
    and that the package numpy is available.

    Every consecutive interval is compared with the first interval
    (``values[1] - values[0]``); a deviation larger than 0.0001 in either
    direction is rejected. Sequences with fewer than two elements have no
    interval to compare and pass unchanged.

    See FigureFactory.create_streamline() for params

    :raises: (ImportError) If numpy is not available.
    :raises: (PlotlyError) If x is not evenly spaced.
    :raises: (PlotlyError) If y is not evenly spaced.
    """
    # The sentinel used for a missing optional module is not visible from
    # this function, so any falsy value is treated as "numpy unavailable".
    if not np:
        raise ImportError("FigureFactory.create_streamline requires numpy")

    for name, values in (("x", x), ("y", y)):
        if len(values) < 2:
            continue
        step = values[1] - values[0]
        for index in range(len(values) - 1):
            # abs() is essential: without it an interval *shorter* than the
            # first one yields a negative deviation and is never caught.
            if abs((values[index + 1] - values[index]) - step) > 0.0001:
                raise exceptions.PlotlyError(
                    "{} must be a 1 dimensional, evenly spaced array".format(name)
                )
+ + :param (list|ndarray) x: 1 dimensional, evenly spaced list or array + :param (list|ndarray) y: 1 dimensional, evenly spaced list or array + :param (ndarray) u: 2 dimensional array + :param (ndarray) v: 2 dimensional array + :param (float|int) density: controls the density of streamlines in + plot. This is multiplied by 30 to scale similiarly to other + available streamline functions such as matplotlib. + Default = 1 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of streamline figure. + + Example 1: Plot simple streamline and increase arrow size + + >>> from plotly.figure_factory import create_streamline + >>> import plotly.graph_objects as go + >>> import numpy as np + >>> import math + + >>> # Add data + >>> x = np.linspace(-3, 3, 100) + >>> y = np.linspace(-3, 3, 100) + >>> Y, X = np.meshgrid(x, y) + >>> u = -1 - X**2 + Y + >>> v = 1 + X - Y**2 + >>> u = u.T # Transpose + >>> v = v.T # Transpose + + >>> # Create streamline + >>> fig = create_streamline(x, y, u, v, arrow_scale=.1) + >>> fig.show() + + Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython + + >>> from plotly.figure_factory import create_streamline + >>> import numpy as np + >>> import math + + >>> # Add data + >>> N = 50 + >>> x_start, x_end = -2.0, 2.0 + >>> y_start, y_end = -1.0, 1.0 + >>> x = np.linspace(x_start, x_end, N) + >>> y = np.linspace(y_start, y_end, N) + >>> X, Y = np.meshgrid(x, y) + >>> ss = 5.0 + >>> x_s, y_s = -1.0, 0.0 + + >>> # Compute the velocity field on the mesh grid + >>> u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2) + >>> v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2) + + >>> # Create streamline + >>> fig = create_streamline(x, y, u_s, v_s, 
density=2, name='streamline') + + >>> # Add source point + >>> point = go.Scatter(x=[x_s], y=[y_s], mode='markers', + ... marker_size=14, name='source point') + + >>> fig.add_trace(point) # doctest: +SKIP + >>> fig.show() + """ + utils.validate_equal_length(x, y) + utils.validate_equal_length(u, v) + validate_streamline(x, y) + utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale) + + streamline_x, streamline_y = _Streamline( + x, y, u, v, density, angle, arrow_scale + ).sum_streamlines() + arrow_x, arrow_y = _Streamline( + x, y, u, v, density, angle, arrow_scale + ).get_streamline_arrows() + + streamline = graph_objs.Scatter( + x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode="lines", **kwargs + ) + + data = [streamline] + layout = graph_objs.Layout(hovermode="closest") + + return graph_objs.Figure(data=data, layout=layout) + + +class _Streamline(object): + """ + Refer to FigureFactory.create_streamline() for docstring + """ + + def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs): + self.x = np.array(x) + self.y = np.array(y) + self.u = np.array(u) + self.v = np.array(v) + self.angle = angle + self.arrow_scale = arrow_scale + self.density = int(30 * density) # Scale similarly to other functions + self.delta_x = self.x[1] - self.x[0] + self.delta_y = self.y[1] - self.y[0] + self.val_x = self.x + self.val_y = self.y + + # Set up spacing + self.blank = np.zeros((self.density, self.density)) + self.spacing_x = len(self.x) / float(self.density - 1) + self.spacing_y = len(self.y) / float(self.density - 1) + self.trajectories = [] + + # Rescale speed onto axes-coordinates + self.u = self.u / (self.x[-1] - self.x[0]) + self.v = self.v / (self.y[-1] - self.y[0]) + self.speed = np.sqrt(self.u**2 + self.v**2) + + # Rescale u and v for integrations. 
+ self.u *= len(self.x) + self.v *= len(self.y) + self.st_x = [] + self.st_y = [] + self.get_streamlines() + streamline_x, streamline_y = self.sum_streamlines() + arrows_x, arrows_y = self.get_streamline_arrows() + + def blank_pos(self, xi, yi): + """ + Set up positions for trajectories to be used with rk4 function. + """ + return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5)) + + def value_at(self, a, xi, yi): + """ + Set up for RK4 function, based on Bokeh's streamline code + """ + if isinstance(xi, np.ndarray): + self.x = xi.astype(int) + self.y = yi.astype(int) + else: + self.val_x = int(xi) + self.val_y = int(yi) + a00 = a[self.val_y, self.val_x] + a01 = a[self.val_y, self.val_x + 1] + a10 = a[self.val_y + 1, self.val_x] + a11 = a[self.val_y + 1, self.val_x + 1] + xt = xi - self.val_x + yt = yi - self.val_y + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + return a0 * (1 - yt) + a1 * yt + + def rk4_integrate(self, x0, y0): + """ + RK4 forward and back trajectories from the initial conditions. 
+ + Adapted from Bokeh's streamline -uses Runge-Kutta method to fill + x and y trajectories then checks length of traj (s in units of axes) + """ + + def f(xi, yi): + dt_ds = 1.0 / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def g(xi, yi): + dt_ds = 1.0 / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return -ui * dt_ds, -vi * dt_ds + + def check(xi, yi): + return (0 <= xi < len(self.x) - 1) and (0 <= yi < len(self.y) - 1) + + xb_changes = [] + yb_changes = [] + + def rk4(x0, y0, f): + ds = 0.01 + stotal = 0 + xi = x0 + yi = y0 + xb, yb = self.blank_pos(xi, yi) + xf_traj = [] + yf_traj = [] + while check(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + 0.5 * ds * k1x, yi + 0.5 * ds * k1y) + k3x, k3y = f(xi + 0.5 * ds * k2x, yi + 0.5 * ds * k2y) + k4x, k4y = f(xi + ds * k3x, yi + ds * k3y) + except IndexError: + break + xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0 + yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0 + if not check(xi, yi): + break + stotal += ds + new_xb, new_yb = self.blank_pos(xi, yi) + if new_xb != xb or new_yb != yb: + if self.blank[new_yb, new_xb] == 0: + self.blank[new_yb, new_xb] = 1 + xb_changes.append(new_xb) + yb_changes.append(new_yb) + xb = new_xb + yb = new_yb + else: + break + if stotal > 2: + break + return stotal, xf_traj, yf_traj + + sf, xf_traj, yf_traj = rk4(x0, y0, f) + sb, xb_traj, yb_traj = rk4(x0, y0, g) + stotal = sf + sb + x_traj = xb_traj[::-1] + xf_traj[1:] + y_traj = yb_traj[::-1] + yf_traj[1:] + + if len(x_traj) < 1: + return None + if stotal > 0.2: + initxb, inityb = self.blank_pos(x0, y0) + self.blank[inityb, initxb] = 1 + return x_traj, y_traj + else: + for xb, yb in zip(xb_changes, yb_changes): + self.blank[yb, xb] = 0 + return None + + def traj(self, xb, yb): + """ + Integrate trajectories + + :param 
(int) xb: results of passing xi through self.blank_pos + :param (int) xy: results of passing yi through self.blank_pos + + Calculate each trajectory based on rk4 integrate method. + """ + + if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density: + return + if self.blank[yb, xb] == 0: + t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y) + if t is not None: + self.trajectories.append(t) + + def get_streamlines(self): + """ + Get streamlines by building trajectory set. + """ + for indent in range(self.density // 2): + for xi in range(self.density - 2 * indent): + self.traj(xi + indent, indent) + self.traj(xi + indent, self.density - 1 - indent) + self.traj(indent, xi + indent) + self.traj(self.density - 1 - indent, xi + indent) + + self.st_x = [ + np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories + ] + self.st_y = [ + np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories + ] + + for index in range(len(self.st_x)): + self.st_x[index] = self.st_x[index].tolist() + self.st_x[index].append(np.nan) + + for index in range(len(self.st_y)): + self.st_y[index] = self.st_y[index].tolist() + self.st_y[index].append(np.nan) + + def get_streamline_arrows(self): + """ + Makes an arrow for each streamline. + + Gets angle of streamline at 1/3 mark and creates arrow coordinates + based off of user defined angle and arrow_scale. + + :param (array) st_x: x-values for all streamlines + :param (array) st_y: y-values for all streamlines + :param (angle in radians) angle: angle of arrowhead. 
Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :rtype (list, list) arrows_x: x-values to create arrowhead and + arrows_y: y-values to create arrowhead + """ + arrow_end_x = np.empty((len(self.st_x))) + arrow_end_y = np.empty((len(self.st_y))) + arrow_start_x = np.empty((len(self.st_x))) + arrow_start_y = np.empty((len(self.st_y))) + for index in range(len(self.st_x)): + arrow_end_x[index] = self.st_x[index][int(len(self.st_x[index]) / 3)] + arrow_start_x[index] = self.st_x[index][ + (int(len(self.st_x[index]) / 3)) - 1 + ] + arrow_end_y[index] = self.st_y[index][int(len(self.st_y[index]) / 3)] + arrow_start_y[index] = self.st_y[index][ + (int(len(self.st_y[index]) / 3)) - 1 + ] + + dif_x = arrow_end_x - arrow_start_x + dif_y = arrow_end_y - arrow_start_y + + orig_err = np.geterr() + np.seterr(divide="ignore", invalid="ignore") + streamline_ang = np.arctan(dif_y / dif_x) + np.seterr(**orig_err) + + ang1 = streamline_ang + (self.angle) + ang2 = streamline_ang - (self.angle) + + seg1_x = np.cos(ang1) * self.arrow_scale + seg1_y = np.sin(ang1) * self.arrow_scale + seg2_x = np.cos(ang2) * self.arrow_scale + seg2_y = np.sin(ang2) * self.arrow_scale + + point1_x = np.empty((len(dif_x))) + point1_y = np.empty((len(dif_y))) + point2_x = np.empty((len(dif_x))) + point2_y = np.empty((len(dif_y))) + + for index in range(len(dif_x)): + if dif_x[index] >= 0: + point1_x[index] = arrow_end_x[index] - seg1_x[index] + point1_y[index] = arrow_end_y[index] - seg1_y[index] + point2_x[index] = arrow_end_x[index] - seg2_x[index] + point2_y[index] = arrow_end_y[index] - seg2_y[index] + else: + point1_x[index] = arrow_end_x[index] + seg1_x[index] + point1_y[index] = arrow_end_y[index] + seg1_y[index] + point2_x[index] = arrow_end_x[index] + seg2_x[index] + point2_y[index] = arrow_end_y[index] + seg2_y[index] + + space = np.empty((len(point1_x))) + space[:] = np.nan + + # Combine arrays into array + arrows_x = np.array([point1_x, 
arrow_end_x, point2_x, space]) + arrows_x = arrows_x.flatten("F") + arrows_x = arrows_x.tolist() + + # Combine arrays into array + arrows_y = np.array([point1_y, arrow_end_y, point2_y, space]) + arrows_y = arrows_y.flatten("F") + arrows_y = arrows_y.tolist() + + return arrows_x, arrows_y + + def sum_streamlines(self): + """ + Makes all streamlines readable as a single trace. + + :rtype (list, list): streamline_x: all x values for each streamline + combined into single list and streamline_y: all y values for each + streamline combined into single list + """ + streamline_x = sum(self.st_x, []) + streamline_y = sum(self.st_y, []) + return streamline_x, streamline_y diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py new file mode 100644 index 0000000..bda731f --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_table.py @@ -0,0 +1,280 @@ +from plotly import exceptions, optional_imports +from plotly.graph_objs import graph_objs + +pd = optional_imports.get_module("pandas") + + +def validate_table(table_text, font_colors): + """ + Table-specific validations + + Check that font_colors is supplied correctly (1, 3, or len(text) + colors). + + :raises: (PlotlyError) If font_colors is supplied incorretly. + + See FigureFactory.create_table() for params + """ + font_colors_len_options = [1, 3, len(table_text)] + if len(font_colors) not in font_colors_len_options: + raise exceptions.PlotlyError( + "Oops, font_colors should be a list of length 1, 3 or len(text)" + ) + + +def create_table( + table_text, + colorscale=None, + font_colors=None, + index=False, + index_title="", + annotation_offset=0.45, + height_constant=30, + hoverinfo="none", + **kwargs, +): + """ + Function that creates data tables. + + See also the plotly.graph_objects trace + :class:`plotly.graph_objects.Table` + + :param (pandas.Dataframe | list[list]) text: data for table. 
+ :param (str|list[list]) colorscale: Colorscale for table where the + color at value 0 is the header color, .5 is the first table color + and 1 is the second table color. (Set .5 and 1 to avoid the striped + table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'], + [1, '#ffffff']] + :param (list) font_colors: Color for fonts in table. Can be a single + color, three colors, or a color for each row in the table. + Default=['#000000'] (black text for the entire table) + :param (int) height_constant: Constant multiplied by # of rows to + create table height. Default=30. + :param (bool) index: Create (header-colored) index column index from + Pandas dataframe or list[0] for each list in text. Default=False. + :param (string) index_title: Title for index column. Default=''. + :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. + These kwargs describe other attributes about the annotated Heatmap + trace such as the colorscale. For more information on valid kwargs + call help(plotly.graph_objs.Heatmap) + + Example 1: Simple Plotly Table + + >>> from plotly.figure_factory import create_table + + >>> text = [['Country', 'Year', 'Population'], + ... ['US', 2000, 282200000], + ... ['Canada', 2000, 27790000], + ... ['US', 2010, 309000000], + ... ['Canada', 2010, 34000000]] + + >>> table = create_table(text) + >>> table.show() + + Example 2: Table with Custom Coloring + + >>> from plotly.figure_factory import create_table + >>> text = [['Country', 'Year', 'Population'], + ... ['US', 2000, 282200000], + ... ['Canada', 2000, 27790000], + ... ['US', 2010, 309000000], + ... ['Canada', 2010, 34000000]] + >>> table = create_table(text, + ... colorscale=[[0, '#000000'], + ... [.5, '#80beff'], + ... [1, '#cce5ff']], + ... font_colors=['#ffffff', '#000000', + ... 
'#000000']) + >>> table.show() + + Example 3: Simple Plotly Table with Pandas + + >>> from plotly.figure_factory import create_table + >>> import pandas as pd + >>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t') + >>> df_p = df[0:25] + >>> table_simple = create_table(df_p) + >>> table_simple.show() + + """ + + # Avoiding mutables in the call signature + colorscale = ( + colorscale + if colorscale is not None + else [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]] + ) + font_colors = ( + font_colors if font_colors is not None else ["#ffffff", "#000000", "#000000"] + ) + + validate_table(table_text, font_colors) + table_matrix = _Table( + table_text, + colorscale, + font_colors, + index, + index_title, + annotation_offset, + **kwargs, + ).get_table_matrix() + annotations = _Table( + table_text, + colorscale, + font_colors, + index, + index_title, + annotation_offset, + **kwargs, + ).make_table_annotations() + + trace = dict( + type="heatmap", + z=table_matrix, + opacity=0.75, + colorscale=colorscale, + showscale=False, + hoverinfo=hoverinfo, + **kwargs, + ) + + data = [trace] + layout = dict( + annotations=annotations, + height=len(table_matrix) * height_constant + 50, + margin=dict(t=0, b=0, r=0, l=0), + yaxis=dict( + autorange="reversed", + zeroline=False, + gridwidth=2, + ticks="", + dtick=1, + tick0=0.5, + showticklabels=False, + ), + xaxis=dict( + zeroline=False, + gridwidth=2, + ticks="", + dtick=1, + tick0=-0.5, + showticklabels=False, + ), + ) + return graph_objs.Figure(data=data, layout=layout) + + +class _Table(object): + """ + Refer to TraceFactory.create_table() for docstring + """ + + def __init__( + self, + table_text, + colorscale, + font_colors, + index, + index_title, + annotation_offset, + **kwargs, + ): + if pd and isinstance(table_text, pd.DataFrame): + headers = table_text.columns.tolist() + table_text_index = table_text.index.tolist() + table_text = 
table_text.values.tolist() + table_text.insert(0, headers) + if index: + table_text_index.insert(0, index_title) + for i in range(len(table_text)): + table_text[i].insert(0, table_text_index[i]) + self.table_text = table_text + self.colorscale = colorscale + self.font_colors = font_colors + self.index = index + self.annotation_offset = annotation_offset + self.x = range(len(table_text[0])) + self.y = range(len(table_text)) + + def get_table_matrix(self): + """ + Create z matrix to make heatmap with striped table coloring + + :rtype (list[list]) table_matrix: z matrix to make heatmap with striped + table coloring. + """ + header = [0] * len(self.table_text[0]) + odd_row = [0.5] * len(self.table_text[0]) + even_row = [1] * len(self.table_text[0]) + table_matrix = [None] * len(self.table_text) + table_matrix[0] = header + for i in range(1, len(self.table_text), 2): + table_matrix[i] = odd_row + for i in range(2, len(self.table_text), 2): + table_matrix[i] = even_row + if self.index: + for array in table_matrix: + array[0] = 0 + return table_matrix + + def get_table_font_color(self): + """ + Fill font-color array. + + Table text color can vary by row so this extends a single color or + creates an array to set a header color and two alternating colors to + create the striped table pattern. + + :rtype (list[list]) all_font_colors: list of font colors for each row + in table. 
+ """ + if len(self.font_colors) == 1: + all_font_colors = self.font_colors * len(self.table_text) + elif len(self.font_colors) == 3: + all_font_colors = list(range(len(self.table_text))) + all_font_colors[0] = self.font_colors[0] + for i in range(1, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[1] + for i in range(2, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[2] + elif len(self.font_colors) == len(self.table_text): + all_font_colors = self.font_colors + else: + all_font_colors = ["#000000"] * len(self.table_text) + return all_font_colors + + def make_table_annotations(self): + """ + Generate annotations to fill in table text + + :rtype (list) annotations: list of annotations for each cell of the + table. + """ + all_font_colors = _Table.get_table_font_color(self) + annotations = [] + for n, row in enumerate(self.table_text): + for m, val in enumerate(row): + # Bold text in header and index + format_text = ( + "<b>" + str(val) + "</b>" + if n == 0 or self.index and m < 1 + else str(val) + ) + # Match font color of index to font color of header + font_color = ( + self.font_colors[0] if self.index and m == 0 else all_font_colors[n] + ) + annotations.append( + graph_objs.layout.Annotation( + text=format_text, + x=self.x[m] - self.annotation_offset, + y=self.y[n], + xref="x1", + yref="y1", + align="left", + xanchor="left", + font=dict(color=font_color), + showarrow=False, + ) + ) + return annotations diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py new file mode 100644 index 0000000..4cdcf17 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_ternary_contour.py @@ -0,0 +1,692 @@ +import plotly.colors as clrs +from plotly.graph_objs import graph_objs as go +from plotly import exceptions +from plotly import optional_imports + +from skimage import measure + +np = 
optional_imports.get_module("numpy") +scipy_interp = optional_imports.get_module("scipy.interpolate") + +# -------------------------- Layout ------------------------------ + + +def _ternary_layout( + title="Ternary contour plot", width=550, height=525, pole_labels=["a", "b", "c"] +): + """ + Layout of ternary contour plot, to be passed to ``go.FigureWidget`` + object. + + Parameters + ========== + title : str or None + Title of ternary plot + width : int + Figure width. + height : int + Figure height. + pole_labels : str, default ['a', 'b', 'c'] + Names of the three poles of the triangle. + """ + return dict( + title=title, + width=width, + height=height, + ternary=dict( + sum=1, + aaxis=dict( + title=dict(text=pole_labels[0]), min=0.01, linewidth=2, ticks="outside" + ), + baxis=dict( + title=dict(text=pole_labels[1]), min=0.01, linewidth=2, ticks="outside" + ), + caxis=dict( + title=dict(text=pole_labels[2]), min=0.01, linewidth=2, ticks="outside" + ), + ), + showlegend=False, + ) + + +# ------------- Transformations of coordinates ------------------- + + +def _replace_zero_coords(ternary_data, delta=0.0005): + """ + Replaces zero ternary coordinates with delta and normalize the new + triplets (a, b, c). + + Parameters + ---------- + + ternary_data : ndarray of shape (N, 3) + + delta : float + Small float to regularize logarithm. + + Notes + ----- + Implements a method + by J. A. Martin-Fernandez, C. Barcelo-Vidal, V. Pawlowsky-Glahn, + Dealing with zeros and missing values in compositional data sets + using nonparametric imputation, Mathematical Geology 35 (2003), + pp 253-278. 
+ """ + zero_mask = ternary_data == 0 + is_any_coord_zero = np.any(zero_mask, axis=0) + + unity_complement = 1 - delta * is_any_coord_zero + if np.any(unity_complement) < 0: + raise ValueError( + "The provided value of delta led to negative" + "ternary coords.Set a smaller delta" + ) + ternary_data = np.where(zero_mask, delta, unity_complement * ternary_data) + return ternary_data + + +def _ilr_transform(barycentric): + """ + Perform Isometric Log-Ratio on barycentric (compositional) data. + + Parameters + ---------- + barycentric: ndarray of shape (3, N) + Barycentric coordinates. + + References + ---------- + "An algebraic method to compute isometric logratio transformation and + back transformation of compositional data", Jarauta-Bragulat, E., + Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the + Intl Assoc for Math Geology, 2003, pp 31-30. + """ + barycentric = np.asarray(barycentric) + x_0 = np.log(barycentric[0] / barycentric[1]) / np.sqrt(2) + x_1 = ( + 1.0 / np.sqrt(6) * np.log(barycentric[0] * barycentric[1] / barycentric[2] ** 2) + ) + ilr_tdata = np.stack((x_0, x_1)) + return ilr_tdata + + +def _ilr_inverse(x): + """ + Perform inverse Isometric Log-Ratio (ILR) transform to retrieve + barycentric (compositional) data. + + Parameters + ---------- + x : array of shape (2, N) + Coordinates in ILR space. + + References + ---------- + "An algebraic method to compute isometric logratio transformation and + back transformation of compositional data", Jarauta-Bragulat, E., + Buenestado, P.; Hervada-Sala, C., in Proc. of the Annual Conf. of the + Intl Assoc for Math Geology, 2003, pp 31-30. 
+ """ + x = np.array(x) + matrix = np.array([[0.5, 1, 1.0], [-0.5, 1, 1.0], [0.0, 0.0, 1.0]]) + s = np.sqrt(2) / 2 + t = np.sqrt(3 / 2) + Sk = np.einsum("ik, kj -> ij", np.array([[s, t], [-s, t]]), x) + Z = -np.log(1 + np.exp(Sk).sum(axis=0)) + log_barycentric = np.einsum( + "ik, kj -> ij", matrix, np.stack((2 * s * x[0], t * x[1], Z)) + ) + iilr_tdata = np.exp(log_barycentric) + return iilr_tdata + + +def _transform_barycentric_cartesian(): + """ + Returns the transformation matrix from barycentric to Cartesian + coordinates and conversely. + """ + # reference triangle + tri_verts = np.array([[0.5, np.sqrt(3) / 2], [0, 0], [1, 0]]) + M = np.array([tri_verts[:, 0], tri_verts[:, 1], np.ones(3)]) + return M, np.linalg.inv(M) + + +def _prepare_barycentric_coord(b_coords): + """ + Check ternary coordinates and return the right barycentric coordinates. + """ + if not isinstance(b_coords, (list, np.ndarray)): + raise ValueError( + "Data should be either an array of shape (n,m)," + "or a list of n m-lists, m=2 or 3" + ) + b_coords = np.asarray(b_coords) + if b_coords.shape[0] not in (2, 3): + raise ValueError( + "A point should have 2 (a, b) or 3 (a, b, c)barycentric coordinates" + ) + if ( + (len(b_coords) == 3) + and not np.allclose(b_coords.sum(axis=0), 1, rtol=0.01) + and not np.allclose(b_coords.sum(axis=0), 100, rtol=0.01) + ): + msg = "The sum of coordinates should be 1 or 100 for all data points" + raise ValueError(msg) + + if len(b_coords) == 2: + A, B = b_coords + C = 1 - (A + B) + else: + A, B, C = b_coords / b_coords.sum(axis=0) + if np.any(np.stack((A, B, C)) < 0): + raise ValueError("Barycentric coordinates should be positive.") + return np.stack((A, B, C)) + + +def _compute_grid(coordinates, values, interp_mode="ilr"): + """ + Transform data points with Cartesian or ILR mapping, then Compute + interpolation on a regular grid. + + Parameters + ========== + + coordinates : array-like + Barycentric coordinates of data points. 
+ values : 1-d array-like + Data points, field to be represented as contours. + interp_mode : 'ilr' (default) or 'cartesian' + Defines how data are interpolated to compute contours. + """ + if interp_mode == "cartesian": + M, invM = _transform_barycentric_cartesian() + coord_points = np.einsum("ik, kj -> ij", M, coordinates) + elif interp_mode == "ilr": + coordinates = _replace_zero_coords(coordinates) + coord_points = _ilr_transform(coordinates) + else: + raise ValueError("interp_mode should be cartesian or ilr") + xx, yy = coord_points[:2] + x_min, x_max = xx.min(), xx.max() + y_min, y_max = yy.min(), yy.max() + n_interp = max(200, int(np.sqrt(len(values)))) + gr_x = np.linspace(x_min, x_max, n_interp) + gr_y = np.linspace(y_min, y_max, n_interp) + grid_x, grid_y = np.meshgrid(gr_x, gr_y) + # We use cubic interpolation, except outside of the convex hull + # of data points where we use nearest neighbor values. + grid_z = scipy_interp.griddata( + coord_points[:2].T, values, (grid_x, grid_y), method="cubic" + ) + return grid_z, gr_x, gr_y + + +# ----------------------- Contour traces ---------------------- + + +def _polygon_area(x, y): + return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + + +def _colors(ncontours, colormap=None): + """ + Return a list of ``ncontours`` colors from the ``colormap`` colorscale. + """ + if colormap in clrs.PLOTLY_SCALES.keys(): + cmap = clrs.PLOTLY_SCALES[colormap] + else: + raise exceptions.PlotlyError( + "Colorscale must be a valid Plotly Colorscale." 
+ "The available colorscale names are {}".format(clrs.PLOTLY_SCALES.keys()) + ) + values = np.linspace(0, 1, ncontours) + vals_cmap = np.array([pair[0] for pair in cmap]) + cols = np.array([pair[1] for pair in cmap]) + inds = np.searchsorted(vals_cmap, values) + if "#" in cols[0]: # for Viridis + cols = [clrs.label_rgb(clrs.hex_to_rgb(col)) for col in cols] + + colors = [cols[0]] + for ind, val in zip(inds[1:], values[1:]): + val1, val2 = vals_cmap[ind - 1], vals_cmap[ind] + interm = (val - val1) / (val2 - val1) + col = clrs.find_intermediate_color( + cols[ind - 1], cols[ind], interm, colortype="rgb" + ) + colors.append(col) + return colors + + +def _is_invalid_contour(x, y): + """ + Utility function for _contour_trace + + Contours with an area of the order as 1 pixel are considered spurious. + """ + too_small = np.all(np.abs(x - x[0]) < 2) and np.all(np.abs(y - y[0]) < 2) + return too_small + + +def _extract_contours(im, values, colors): + """ + Utility function for _contour_trace. + + In ``im`` only one part of the domain has valid values (corresponding + to a subdomain where barycentric coordinates are well defined). When + computing contours, we need to assign values outside of this domain. + We can choose a value either smaller than all the values inside the + valid domain, or larger. This value must be chose with caution so that + no spurious contours are added. For example, if the boundary of the valid + domain has large values and the outer value is set to a small one, all + intermediate contours will be added at the boundary. + + Therefore, we compute the two sets of contours (with an outer value + smaller of larger than all values in the valid domain), and choose + the value resulting in a smaller total number of contours. There might + be a faster way to do this, but it works... 
+ """ + mask_nan = np.isnan(im) + im_min, im_max = ( + im[np.logical_not(mask_nan)].min(), + im[np.logical_not(mask_nan)].max(), + ) + zz_min = np.copy(im) + zz_min[mask_nan] = 2 * im_min + zz_max = np.copy(im) + zz_max[mask_nan] = 2 * im_max + all_contours1, all_values1, all_areas1, all_colors1 = [], [], [], [] + all_contours2, all_values2, all_areas2, all_colors2 = [], [], [], [] + for i, val in enumerate(values): + contour_level1 = measure.find_contours(zz_min, val) + contour_level2 = measure.find_contours(zz_max, val) + all_contours1.extend(contour_level1) + all_contours2.extend(contour_level2) + all_values1.extend([val] * len(contour_level1)) + all_values2.extend([val] * len(contour_level2)) + all_areas1.extend( + [_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level1] + ) + all_areas2.extend( + [_polygon_area(contour.T[1], contour.T[0]) for contour in contour_level2] + ) + all_colors1.extend([colors[i]] * len(contour_level1)) + all_colors2.extend([colors[i]] * len(contour_level2)) + if len(all_contours1) <= len(all_contours2): + return all_contours1, all_values1, all_areas1, all_colors1 + else: + return all_contours2, all_values2, all_areas2, all_colors2 + + +def _add_outer_contour( + all_contours, + all_values, + all_areas, + all_colors, + values, + val_outer, + v_min, + v_max, + colors, + color_min, + color_max, +): + """ + Utility function for _contour_trace + + Adds the background color to fill gaps outside of computed contours. + + To compute the background color, the color of the contour with largest + area (``val_outer``) is used. As background color, we choose the next + color value in the direction of the extrema of the colormap. + + Then we add information for the outer contour for the different lists + provided as arguments. + + A discrete colormap with all used colors is also returned (to be used + by colorscale trace). 
+ """ + # The exact value of outer contour is not used when defining the trace + outer_contour = 20 * np.array([[0, 0, 1], [0, 1, 0.5]]).T + all_contours = [outer_contour] + all_contours + delta_values = np.diff(values)[0] + values = np.concatenate( + ([values[0] - delta_values], values, [values[-1] + delta_values]) + ) + colors = np.concatenate(([color_min], colors, [color_max])) + index = np.nonzero(values == val_outer)[0][0] + if index < len(values) / 2: + index -= 1 + else: + index += 1 + all_colors = [colors[index]] + all_colors + all_values = [values[index]] + all_values + all_areas = [0] + all_areas + used_colors = [color for color in colors if color in all_colors] + # Define discrete colorscale + color_number = len(used_colors) + scale = np.linspace(0, 1, color_number + 1) + discrete_cm = [] + for i, color in enumerate(used_colors): + discrete_cm.append([scale[i], used_colors[i]]) + discrete_cm.append([scale[i + 1], used_colors[i]]) + discrete_cm.append([scale[color_number], used_colors[color_number - 1]]) + + return all_contours, all_values, all_areas, all_colors, discrete_cm + + +def _contour_trace( + x, + y, + z, + ncontours=None, + colorscale="Electric", + linecolor="rgb(150,150,150)", + interp_mode="llr", + coloring=None, + v_min=0, + v_max=1, +): + """ + Contour trace in Cartesian coordinates. + + Parameters + ========== + + x, y : array-like + Cartesian coordinates + z : array-like + Field to be represented as contours. + ncontours : int or None + Number of contours to display (determined automatically if None). + colorscale : None or str (Plotly colormap) + colorscale of the contours. + linecolor : rgb color + Color used for lines. If ``colorscale`` is not None, line colors are + determined from ``colorscale`` instead. + interp_mode : 'ilr' (default) or 'cartesian' + Defines how data are interpolated to compute contours. If 'irl', + ILR (Isometric Log-Ratio) of compositional data is performed. 
If + 'cartesian', contours are determined in Cartesian space. + coloring : None or 'lines' + How to display contour. Filled contours if None, lines if ``lines``. + vmin, vmax : float + Bounds of interval of values used for the colorspace + + Notes + ===== + """ + # Prepare colors + # We do not take extrema, for example for one single contour + # the color will be the middle point of the colormap + colors = _colors(ncontours + 2, colorscale) + # Values used for contours, extrema are not used + # For example for a binary array [0, 1], the value of + # the contour for ncontours=1 is 0.5. + values = np.linspace(v_min, v_max, ncontours + 2) + color_min, color_max = colors[0], colors[-1] + colors = colors[1:-1] + values = values[1:-1] + + # Color of line contours + if linecolor is None: + linecolor = "rgb(150, 150, 150)" + else: + colors = [linecolor] * ncontours + + # Retrieve all contours + all_contours, all_values, all_areas, all_colors = _extract_contours( + z, values, colors + ) + + # Now sort contours by decreasing area + order = np.argsort(all_areas)[::-1] + + # Add outer contour + all_contours, all_values, all_areas, all_colors, discrete_cm = _add_outer_contour( + all_contours, + all_values, + all_areas, + all_colors, + values, + all_values[order[0]], + v_min, + v_max, + colors, + color_min, + color_max, + ) + order = np.concatenate(([0], order + 1)) + + # Compute traces, in the order of decreasing area + traces = [] + M, invM = _transform_barycentric_cartesian() + dx = (x.max() - x.min()) / x.size + dy = (y.max() - y.min()) / y.size + for index in order: + y_contour, x_contour = all_contours[index].T + val = all_values[index] + if interp_mode == "cartesian": + bar_coords = np.dot( + invM, + np.stack((dx * x_contour, dy * y_contour, np.ones(x_contour.shape))), + ) + elif interp_mode == "ilr": + bar_coords = _ilr_inverse( + np.stack((dx * x_contour + x.min(), dy * y_contour + y.min())) + ) + if index == 0: # outer triangle + a = np.array([1, 0, 0]) + b = 
np.array([0, 1, 0]) + c = np.array([0, 0, 1]) + else: + a, b, c = bar_coords + if _is_invalid_contour(x_contour, y_contour): + continue + + _col = all_colors[index] if coloring == "lines" else linecolor + trace = dict( + type="scatterternary", + a=a, + b=b, + c=c, + mode="lines", + line=dict(color=_col, shape="spline", width=1), + fill="toself", + fillcolor=all_colors[index], + showlegend=True, + hoverinfo="skip", + name="%.3f" % val, + ) + if coloring == "lines": + trace["fill"] = None + traces.append(trace) + + return traces, discrete_cm + + +# -------------------- Figure Factory for ternary contour ------------- + + +def create_ternary_contour( + coordinates, + values, + pole_labels=["a", "b", "c"], + width=500, + height=500, + ncontours=None, + showscale=False, + coloring=None, + colorscale="Bluered", + linecolor=None, + title=None, + interp_mode="ilr", + showmarkers=False, +): + """ + Ternary contour plot. + + Parameters + ---------- + + coordinates : list or ndarray + Barycentric coordinates of shape (2, N) or (3, N) where N is the + number of data points. The sum of the 3 coordinates is expected + to be 1 for all data points. + values : array-like + Data points of field to be represented as contours. + pole_labels : str, default ['a', 'b', 'c'] + Names of the three poles of the triangle. + width : int + Figure width. + height : int + Figure height. + ncontours : int or None + Number of contours to display (determined automatically if None). + showscale : bool, default False + If True, a colorbar showing the color scale is displayed. + coloring : None or 'lines' + How to display contour. Filled contours if None, lines if ``lines``. + colorscale : None or str (Plotly colormap) + colorscale of the contours. + linecolor : None or rgb color + Color used for lines. ``colorscale`` has to be set to None, otherwise + line colors are determined from ``colorscale``. 
+ title : str or None + Title of ternary plot + interp_mode : 'ilr' (default) or 'cartesian' + Defines how data are interpolated to compute contours. If 'irl', + ILR (Isometric Log-Ratio) of compositional data is performed. If + 'cartesian', contours are determined in Cartesian space. + showmarkers : bool, default False + If True, markers corresponding to input compositional points are + superimposed on contours, using the same colorscale. + + Examples + ======== + + Example 1: ternary contour plot with filled contours + + >>> import plotly.figure_factory as ff + >>> import numpy as np + >>> # Define coordinates + >>> a, b = np.mgrid[0:1:20j, 0:1:20j] + >>> mask = a + b <= 1 + >>> a = a[mask].ravel() + >>> b = b[mask].ravel() + >>> c = 1 - a - b + >>> # Values to be displayed as contours + >>> z = a * b * c + >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z) + >>> fig.show() + + It is also possible to give only two barycentric coordinates for each + point, since the sum of the three coordinates is one: + + >>> fig = ff.create_ternary_contour(np.stack((a, b)), z) + + + Example 2: ternary contour plot with line contours + + >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines') + + Example 3: customize number of contours + + >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8) + + Example 4: superimpose contour plot and original data as markers + + >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines', + ... showmarkers=True) + + Example 5: customize title and pole labels + + >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, + ... title='Ternary plot', + ... 
pole_labels=['clay', 'quartz', 'fledspar']) + """ + if scipy_interp is None: + raise ImportError( + """\ + The create_ternary_contour figure factory requires the scipy package""" + ) + sk_measure = optional_imports.get_module("skimage") + if sk_measure is None: + raise ImportError( + """\ + The create_ternary_contour figure factory requires the scikit-image + package""" + ) + if colorscale is None: + showscale = False + if ncontours is None: + ncontours = 5 + coordinates = _prepare_barycentric_coord(coordinates) + v_min, v_max = values.min(), values.max() + grid_z, gr_x, gr_y = _compute_grid(coordinates, values, interp_mode=interp_mode) + + layout = _ternary_layout( + pole_labels=pole_labels, width=width, height=height, title=title + ) + + contour_trace, discrete_cm = _contour_trace( + gr_x, + gr_y, + grid_z, + ncontours=ncontours, + colorscale=colorscale, + linecolor=linecolor, + interp_mode=interp_mode, + coloring=coloring, + v_min=v_min, + v_max=v_max, + ) + + fig = go.Figure(data=contour_trace, layout=layout) + + opacity = 1 if showmarkers else 0 + a, b, c = coordinates + hovertemplate = ( + pole_labels[0] + + ": %{a:.3f}<br>" + + pole_labels[1] + + ": %{b:.3f}<br>" + + pole_labels[2] + + ": %{c:.3f}<br>" + "z: %{marker.color:.3f}<extra></extra>" + ) + + fig.add_scatterternary( + a=a, + b=b, + c=c, + mode="markers", + marker={ + "color": values, + "colorscale": colorscale, + "line": {"color": "rgb(120, 120, 120)", "width": int(coloring != "lines")}, + }, + opacity=opacity, + hovertemplate=hovertemplate, + ) + if showscale: + if not showmarkers: + colorscale = discrete_cm + colorbar = dict( + { + "type": "scatterternary", + "a": [None], + "b": [None], + "c": [None], + "marker": { + "cmin": values.min(), + "cmax": values.max(), + "colorscale": colorscale, + "showscale": True, + }, + "mode": "markers", + } + ) + fig.add_trace(colorbar) + + return fig diff --git a/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py 
b/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py new file mode 100644 index 0000000..f935292 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/figure_factory/_trisurf.py @@ -0,0 +1,509 @@ +from plotly import exceptions, optional_imports +import plotly.colors as clrs +from plotly.graph_objs import graph_objs + +np = optional_imports.get_module("numpy") + + +def map_face2color(face, colormap, scale, vmin, vmax): + """ + Normalize facecolor values by vmin/vmax and return rgb-color strings + + This function takes a tuple color along with a colormap and a minimum + (vmin) and maximum (vmax) range of possible mean distances for the + given parametrized surface. It returns an rgb color based on the mean + distance between vmin and vmax + + """ + if vmin >= vmax: + raise exceptions.PlotlyError( + "Incorrect relation between vmin " + "and vmax. The vmin value cannot be " + "bigger than or equal to the value " + "of vmax." + ) + if len(colormap) == 1: + # color each triangle face with the same color in colormap + face_color = colormap[0] + face_color = clrs.convert_to_RGB_255(face_color) + face_color = clrs.label_rgb(face_color) + return face_color + if face == vmax: + # pick last color in colormap + face_color = colormap[-1] + face_color = clrs.convert_to_RGB_255(face_color) + face_color = clrs.label_rgb(face_color) + return face_color + else: + if scale is None: + # find the normalized distance t of a triangle face between + # vmin and vmax where the distance is between 0 and 1 + t = (face - vmin) / float((vmax - vmin)) + low_color_index = int(t / (1.0 / (len(colormap) - 1))) + + face_color = clrs.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + t * (len(colormap) - 1) - low_color_index, + ) + + face_color = clrs.convert_to_RGB_255(face_color) + face_color = clrs.label_rgb(face_color) + else: + # find the face color for a non-linearly interpolated scale + t = (face - vmin) / float((vmax - vmin)) + + 
low_color_index = 0 + for k in range(len(scale) - 1): + if scale[k] <= t < scale[k + 1]: + break + low_color_index += 1 + + low_scale_val = scale[low_color_index] + high_scale_val = scale[low_color_index + 1] + + face_color = clrs.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + (t - low_scale_val) / (high_scale_val - low_scale_val), + ) + + face_color = clrs.convert_to_RGB_255(face_color) + face_color = clrs.label_rgb(face_color) + return face_color + + +def trisurf( + x, + y, + z, + simplices, + show_colorbar, + edges_color, + scale, + colormap=None, + color_func=None, + plot_edges=False, + x_edge=None, + y_edge=None, + z_edge=None, + facecolor=None, +): + """ + Refer to FigureFactory.create_trisurf() for docstring + """ + # numpy import check + if not np: + raise ImportError("FigureFactory._trisurf() requires numpy imported.") + points3D = np.vstack((x, y, z)).T + simplices = np.atleast_2d(simplices) + + # vertices of the surface triangles + tri_vertices = points3D[simplices] + + # Define colors for the triangle faces + if color_func is None: + # mean values of z-coordinates of triangle vertices + mean_dists = tri_vertices[:, :, 2].mean(-1) + elif isinstance(color_func, (list, np.ndarray)): + # Pre-computed list / array of values to map onto color + if len(color_func) != len(simplices): + raise ValueError( + "If color_func is a list/array, it must " + "be the same length as simplices." 
+ ) + + # convert all colors in color_func to rgb + for index in range(len(color_func)): + if isinstance(color_func[index], str): + if "#" in color_func[index]: + foo = clrs.hex_to_rgb(color_func[index]) + color_func[index] = clrs.label_rgb(foo) + + if isinstance(color_func[index], tuple): + foo = clrs.convert_to_RGB_255(color_func[index]) + color_func[index] = clrs.label_rgb(foo) + + mean_dists = np.asarray(color_func) + else: + # apply user inputted function to calculate + # custom coloring for triangle vertices + mean_dists = [] + for triangle in tri_vertices: + dists = [] + for vertex in triangle: + dist = color_func(vertex[0], vertex[1], vertex[2]) + dists.append(dist) + mean_dists.append(np.mean(dists)) + mean_dists = np.asarray(mean_dists) + + # Check if facecolors are already strings and can be skipped + if isinstance(mean_dists[0], str): + facecolor = mean_dists + else: + min_mean_dists = np.min(mean_dists) + max_mean_dists = np.max(mean_dists) + + if facecolor is None: + facecolor = [] + for index in range(len(mean_dists)): + color = map_face2color( + mean_dists[index], colormap, scale, min_mean_dists, max_mean_dists + ) + facecolor.append(color) + + # Make sure facecolor is a list so output is consistent across Pythons + facecolor = np.asarray(facecolor) + ii, jj, kk = simplices.T + + triangles = graph_objs.Mesh3d( + x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name="" + ) + + mean_dists_are_numbers = not isinstance(mean_dists[0], str) + + if mean_dists_are_numbers and show_colorbar is True: + # make a colorscale from the colors + colorscale = clrs.make_colorscale(colormap, scale) + colorscale = clrs.convert_colorscale_to_rgb(colorscale) + + colorbar = graph_objs.Scatter3d( + x=x[:1], + y=y[:1], + z=z[:1], + mode="markers", + marker=dict( + size=0.1, + color=[min_mean_dists, max_mean_dists], + colorscale=colorscale, + showscale=True, + ), + hoverinfo="none", + showlegend=False, + ) + + # the triangle sides are not plotted + if plot_edges is 
False: + if mean_dists_are_numbers and show_colorbar is True: + return [triangles, colorbar] + else: + return [triangles] + + # define the lists x_edge, y_edge and z_edge, of x, y, resp z + # coordinates of edge end points for each triangle + # None separates data corresponding to two consecutive triangles + is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] + if any(is_none): + if not all(is_none): + raise ValueError( + "If any (x_edge, y_edge, z_edge) is None, all must be None" + ) + else: + x_edge = [] + y_edge = [] + z_edge = [] + + # Pull indices we care about, then add a None column to separate tris + ixs_triangles = [0, 1, 2, 0] + pull_edges = tri_vertices[:, ixs_triangles, :] + x_edge_pull = np.hstack( + [pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])] + ) + y_edge_pull = np.hstack( + [pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])] + ) + z_edge_pull = np.hstack( + [pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])] + ) + + # Now unravel the edges into a 1-d vector for plotting + x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) + y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) + z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) + + if not (len(x_edge) == len(y_edge) == len(z_edge)): + raise exceptions.PlotlyError( + "The lengths of x_edge, y_edge and z_edge are not the same." 
+ ) + + # define the lines for plotting + lines = graph_objs.Scatter3d( + x=x_edge, + y=y_edge, + z=z_edge, + mode="lines", + line=graph_objs.scatter3d.Line(color=edges_color, width=1.5), + showlegend=False, + ) + + if mean_dists_are_numbers and show_colorbar is True: + return [triangles, lines, colorbar] + else: + return [triangles, lines] + + +def create_trisurf( + x, + y, + z, + simplices, + colormap=None, + show_colorbar=True, + scale=None, + color_func=None, + title="Trisurf Plot", + plot_edges=True, + showbackground=True, + backgroundcolor="rgb(230, 230, 230)", + gridcolor="rgb(255, 255, 255)", + zerolinecolor="rgb(255, 255, 255)", + edges_color="rgb(50, 50, 50)", + height=800, + width=800, + aspectratio=None, +): + """ + Returns figure for a triangulated surface plot + + :param (array) x: data values of x in a 1D array + :param (array) y: data values of y in a 1D array + :param (array) z: data values of z in a 1D array + :param (array) simplices: an array of shape (ntri, 3) where ntri is + the number of triangles in the triangularization. Each row of the + array contains the indicies of the verticies of each triangle + :param (str|tuple|list) colormap: either a plotly scale name, an rgb + or hex color, a color tuple or a list of colors. An rgb color is + of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colormap is a list, it must + contain the valid color types aforementioned as its members + :param (bool) show_colorbar: determines if colorbar is visible + :param (list|array) scale: sets the scale values to be used if a non- + linearly interpolated colormap is desired. If left as None, a + linear interpolation between the colors will be excecuted + :param (function|list) color_func: The parameter that determines the + coloring of the surface. 
Takes either a function with 3 arguments + x, y, z or a list/array of color values the same length as + simplices. If None, coloring will only depend on the z axis + :param (str) title: title of the plot + :param (bool) plot_edges: determines if the triangles on the trisurf + are visible + :param (bool) showbackground: makes background in plot visible + :param (str) backgroundcolor: color of background. Takes a string of + the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) gridcolor: color of the gridlines besides the axes. Takes + a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 + inclusive + :param (str) zerolinecolor: color of the axes. Takes a string of the + form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) edges_color: color of the edges, if plot_edges is True + :param (int|float) height: the height of the plot (in pixels) + :param (int|float) width: the width of the plot (in pixels) + :param (dict) aspectratio: a dictionary of the aspect ratio values for + the x, y and z axes. 'x', 'y' and 'z' take (int|float) values + + Example 1: Sphere + + >>> # Necessary Imports for Trisurf + >>> import numpy as np + >>> from scipy.spatial import Delaunay + + >>> from plotly.figure_factory import create_trisurf + >>> from plotly.graph_objs import graph_objs + + >>> # Make data for plot + >>> u = np.linspace(0, 2*np.pi, 20) + >>> v = np.linspace(0, np.pi, 20) + >>> u,v = np.meshgrid(u,v) + >>> u = u.flatten() + >>> v = v.flatten() + + >>> x = np.sin(v)*np.cos(u) + >>> y = np.sin(v)*np.sin(u) + >>> z = np.cos(v) + + >>> points2D = np.vstack([u,v]).T + >>> tri = Delaunay(points2D) + >>> simplices = tri.simplices + + >>> # Create a figure + >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow", + ... 
simplices=simplices) + + Example 2: Torus + + >>> # Necessary Imports for Trisurf + >>> import numpy as np + >>> from scipy.spatial import Delaunay + + >>> from plotly.figure_factory import create_trisurf + >>> from plotly.graph_objs import graph_objs + + >>> # Make data for plot + >>> u = np.linspace(0, 2*np.pi, 20) + >>> v = np.linspace(0, 2*np.pi, 20) + >>> u,v = np.meshgrid(u,v) + >>> u = u.flatten() + >>> v = v.flatten() + + >>> x = (3 + (np.cos(v)))*np.cos(u) + >>> y = (3 + (np.cos(v)))*np.sin(u) + >>> z = np.sin(v) + + >>> points2D = np.vstack([u,v]).T + >>> tri = Delaunay(points2D) + >>> simplices = tri.simplices + + >>> # Create a figure + >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis", + ... simplices=simplices) + + Example 3: Mobius Band + + >>> # Necessary Imports for Trisurf + >>> import numpy as np + >>> from scipy.spatial import Delaunay + + >>> from plotly.figure_factory import create_trisurf + >>> from plotly.graph_objs import graph_objs + + >>> # Make data for plot + >>> u = np.linspace(0, 2*np.pi, 24) + >>> v = np.linspace(-1, 1, 8) + >>> u,v = np.meshgrid(u,v) + >>> u = u.flatten() + >>> v = v.flatten() + + >>> tp = 1 + 0.5*v*np.cos(u/2.) + >>> x = tp*np.cos(u) + >>> y = tp*np.sin(u) + >>> z = 0.5*v*np.sin(u/2.) + + >>> points2D = np.vstack([u,v]).T + >>> tri = Delaunay(points2D) + >>> simplices = tri.simplices + + >>> # Create a figure + >>> fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)], + ... 
simplices=simplices) + + Example 4: Using a Custom Colormap Function with Light Cone + + >>> # Necessary Imports for Trisurf + >>> import numpy as np + >>> from scipy.spatial import Delaunay + + >>> from plotly.figure_factory import create_trisurf + >>> from plotly.graph_objs import graph_objs + + >>> # Make data for plot + >>> u=np.linspace(-np.pi, np.pi, 30) + >>> v=np.linspace(-np.pi, np.pi, 30) + >>> u,v=np.meshgrid(u,v) + >>> u=u.flatten() + >>> v=v.flatten() + + >>> x = u + >>> y = u*np.cos(v) + >>> z = u*np.sin(v) + + >>> points2D = np.vstack([u,v]).T + >>> tri = Delaunay(points2D) + >>> simplices = tri.simplices + + >>> # Define distance function + >>> def dist_origin(x, y, z): + ... return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) + + >>> # Create a figure + >>> fig1 = create_trisurf(x=x, y=y, z=z, + ... colormap=['#FFFFFF', '#E4FFFE', + ... '#A4F6F9', '#FF99FE', + ... '#BA52ED'], + ... scale=[0, 0.6, 0.71, 0.89, 1], + ... simplices=simplices, + ... color_func=dist_origin) + + Example 5: Enter color_func as a list of colors + + >>> # Necessary Imports for Trisurf + >>> import numpy as np + >>> from scipy.spatial import Delaunay + >>> import random + + >>> from plotly.figure_factory import create_trisurf + >>> from plotly.graph_objs import graph_objs + + >>> # Make data for plot + >>> u=np.linspace(-np.pi, np.pi, 30) + >>> v=np.linspace(-np.pi, np.pi, 30) + >>> u,v=np.meshgrid(u,v) + >>> u=u.flatten() + >>> v=v.flatten() + + >>> x = u + >>> y = u*np.cos(v) + >>> z = u*np.sin(v) + + >>> points2D = np.vstack([u,v]).T + >>> tri = Delaunay(points2D) + >>> simplices = tri.simplices + + + >>> colors = [] + >>> color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd'] + + >>> for index in range(len(simplices)): + ... colors.append(random.choice(color_choices)) + + >>> fig = create_trisurf( + ... x, y, z, simplices, + ... color_func=colors, + ... show_colorbar=True, + ... edges_color='rgb(2, 85, 180)', + ... title=' Modern Art' + ... 
def _percentile(values, q, how):
    """
    Return the q-th percentile of `values` using the rounding rule `how`.

    NumPy 1.22 renamed the `interpolation` keyword of np.percentile to
    `method` and deprecated the old spelling. The new keyword is tried
    first; releases that do not know it raise TypeError, in which case the
    old keyword is used.
    """
    try:
        return np.percentile(values, q, method=how)
    except TypeError:
        # older NumPy: `method` is not an accepted keyword
        return np.percentile(values, q, interpolation=how)


def calc_stats(data):
    """
    Calculate statistics for use in violin plot.

    :param (list|array) data: numeric sample; converted to a float array.
    :rtype (dict): the keys are 'min' and 'max' (extremes of the data),
        'q1', 'q2' and 'q3' (25th, 50th and 75th percentiles, where q1 is
        rounded down and q3 rounded up to an actual data point) and
        'd1' and 'd2' (the whisker ends).
    """
    x = np.asarray(data, float)
    vals_min = np.min(x)
    vals_max = np.max(x)
    q2 = _percentile(x, 50, "linear")
    q1 = _percentile(x, 25, "lower")
    q3 = _percentile(x, 75, "higher")
    iqr = q3 - q1
    whisker_dist = 1.5 * iqr

    # in order to prevent drawing whiskers outside the interval
    # of data one defines the whisker positions as the most extreme
    # data points that still lie within 1.5 * IQR of the quartiles
    d1 = np.min(x[x >= (q1 - whisker_dist)])
    d2 = np.max(x[x <= (q3 + whisker_dist)])
    return {
        "min": vals_min,
        "max": vals_max,
        "q1": q1,
        "q2": q2,
        "q3": q3,
        "d1": d1,
        "d2": d2,
    }


def make_half_violin(x, y, fillcolor="#1f77b4", linecolor="rgb(0, 0, 0)"):
    """
    Produces a sideways probability distribution fig violin plot.

    :param x: horizontal coordinates of the outline (the pdf values).
    :param y: vertical coordinates of the outline (the data grid).
    :param (str) fillcolor: fill color of the half violin.
    :param (str) linecolor: color of the outline.
    :rtype: a filled, half-transparent Scatter line trace whose hover text
        shows each (pdf(y), y) pair with two decimals.
    """
    text = [
        "(pdf(y), y)=({:0.2f}, {:0.2f})".format(x[i], y[i]) for i in range(len(x))
    ]

    return graph_objs.Scatter(
        x=x,
        y=y,
        mode="lines",
        name="",
        text=text,
        fill="tonextx",
        fillcolor=fillcolor,
        line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape="spline"),
        hoverinfo="text",
        opacity=0.5,
    )


def make_violin_rugplot(vals, pdf_max, distance, color="#1f77b4"):
    """
    Returns a rugplot fig for a violin plot.

    Every value is drawn as a horizontal tick at x = -pdf_max - distance,
    i.e. to the left of the violin.
    """
    return graph_objs.Scatter(
        y=vals,
        x=[-pdf_max - distance] * len(vals),
        marker=graph_objs.scatter.Marker(color=color, symbol="line-ew-open"),
        mode="markers",
        name="",
        showlegend=False,
        hoverinfo="y",
    )


def make_non_outlier_interval(d1, d2):
    """
    Returns the scatterplot fig of most of a violin plot.

    Draws a thin vertical black line at x = 0 from d1 to d2.
    """
    return graph_objs.Scatter(
        x=[0, 0],
        y=[d1, d2],
        name="",
        mode="lines",
        line=graph_objs.scatter.Line(width=1.5, color="rgb(0,0,0)"),
    )
def make_quartiles(q1, q3):
    """
    Makes the upper and lower quartiles for a violin plot.

    Draws a thick vertical black line at x = 0 from q1 to q3 whose hover
    text shows both quartiles with two decimals.
    """
    return graph_objs.Scatter(
        x=[0, 0],
        y=[q1, q3],
        text=[
            "lower-quartile: " + "{:0.2f}".format(q1),
            "upper-quartile: " + "{:0.2f}".format(q3),
        ],
        mode="lines",
        line=graph_objs.scatter.Line(width=4, color="rgb(0,0,0)"),
        hoverinfo="text",
    )


def make_median(q2):
    """
    Formats the 'median' hovertext for a violin plot.

    The median is drawn as a single white square marker at (0, q2).
    """
    return graph_objs.Scatter(
        x=[0],
        y=[q2],
        text=["median: " + "{:0.2f}".format(q2)],
        mode="markers",
        marker=dict(symbol="square", color="rgb(255,255,255)"),
        hoverinfo="text",
    )


def make_XAxis(xaxis_title, xaxis_range):
    """
    Makes the x-axis for a violin plot.

    The axis has the given title and range; grid, zero line, axis line,
    ticks and tick labels are all switched off.
    """
    xaxis = graph_objs.layout.XAxis(
        title=xaxis_title,
        range=xaxis_range,
        showgrid=False,
        zeroline=False,
        showline=False,
        mirror=False,
        ticks="",
        showticklabels=False,
    )
    return xaxis


def make_YAxis(yaxis_title):
    """
    Makes the y-axis for a violin plot.

    The axis is auto-ranged with a visible line and tick labels, and
    without grid or zero line.
    """
    yaxis = graph_objs.layout.YAxis(
        title=yaxis_title,
        showticklabels=True,
        autorange=True,
        ticklen=4,
        showline=True,
        zeroline=False,
        showgrid=False,
        mirror=False,
    )
    return yaxis


def violinplot(vals, fillcolor="#1f77b4", rugplot=True):
    """
    Refer to FigureFactory.create_violin() for docstring.

    :param (list|array) vals: the numeric sample of one violin.
    :param (str) fillcolor: fill color of the violin (and of the rug).
    :param (bool) rugplot: whether a rug trace is appended.
    :rtype (tuple): (plot_data, plot_xrange) where plot_data is the list
        of traces making up the violin and plot_xrange is the
        [low, high] x-axis range that fits them.
    """
    vals = np.asarray(vals, float)
    # summary statistics (computed once and shared by all traces)
    stats = calc_stats(vals)

    # kernel density estimation of pdf
    pdf = scipy_stats.gaussian_kde(vals)
    # grid over the data interval
    xx = np.linspace(stats["min"], stats["max"], 100)
    # evaluate the pdf at the grid xx
    yy = pdf(xx)
    max_pdf = np.max(yy)
    # distance from the violin plot to rugplot
    distance = (2.0 * max_pdf) / 10 if rugplot else 0
    # range for x values in the plot
    plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
    plot_data = [
        # the two mirrored halves of the violin
        make_half_violin(-yy, xx, fillcolor=fillcolor),
        make_half_violin(yy, xx, fillcolor=fillcolor),
        make_non_outlier_interval(stats["d1"], stats["d2"]),
        make_quartiles(stats["q1"], stats["q3"]),
        make_median(stats["q2"]),
    ]
    if rugplot:
        plot_data.append(
            make_violin_rugplot(vals, max_pdf, distance=distance, color=fillcolor)
        )
    return plot_data, plot_xrange
def violin_no_colorscale(
    data,
    data_header,
    group_header,
    colors,
    use_colorscale,
    group_stats,
    rugplot,
    sort,
    height,
    width,
    title,
):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot without colorscale.

    One subplot column is created per group; the entries of `colors` are
    assigned to the groups in order and reused from the start when there
    are more groups than colors.
    """

    # collect all group names in order of first appearance
    group_name = list(dict.fromkeys(data[group_header]))
    if sort:
        group_name.sort()

    # group on the bare column label (not a one-element list) so that
    # get_group() can be called with a plain group name
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(
        rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
    )
    color_index = 0
    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)
        # start over with the first color once the list is exhausted
        if color_index >= len(colors):
            color_index = 0
        plot_data, plot_xrange = violinplot(
            vals, fillcolor=colors[color_index], rugplot=rugplot
        )
        for item in plot_data:
            fig.append_trace(item, 1, k + 1)
        color_index += 1

        # add violin plot labels
        fig["layout"].update(
            {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
        )

    # set the sharey axis style
    fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
    fig["layout"].update(
        title=title,
        showlegend=False,
        hovermode="closest",
        autosize=False,
        height=height,
        width=width,
    )

    return fig


def violin_colorscale(
    data,
    data_header,
    group_header,
    colors,
    use_colorscale,
    group_stats,
    rugplot,
    sort,
    height,
    width,
    title,
):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot with colorscale.

    Each violin is filled with the color found by interpolating between
    colors[0] and colors[1] according to where the group's entry in
    `group_stats` lies between the smallest and largest value of
    `group_stats`. A dummy marker trace is added to show the colorbar.

    :raises: (PlotlyError) If a group of the group_header column is not a
        key of group_stats.
    """

    # collect all group names in order of first appearance
    group_name = list(dict.fromkeys(data[group_header]))
    if sort:
        group_name.sort()

    # make sure all group names are keys in group_stats
    for group in group_name:
        if group not in group_stats:
            raise exceptions.PlotlyError(
                "All values/groups in the index "
                "column must be represented "
                "as a key in group_stats."
            )

    # group on the bare column label (not a one-element list) so that
    # get_group() can be called with a plain group name
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(
        rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
    )

    # prepare low and high color for colorscale
    lowcolor = clrs.color_parser(colors[0], clrs.unlabel_rgb)
    highcolor = clrs.color_parser(colors[1], clrs.unlabel_rgb)

    # find min and max values in group_stats
    group_stats_values = list(group_stats.values())
    max_value = max(group_stats_values)
    min_value = min(group_stats_values)
    value_span = max_value - min_value

    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)

        # find intermediate color from colorscale
        if value_span:
            intermed = (group_stats[gr] - min_value) / value_span
        else:
            # all stats are equal: there is no range to interpolate over,
            # so use the midpoint of the scale instead of dividing by zero
            intermed = 0.5
        intermed_color = clrs.find_intermediate_color(lowcolor, highcolor, intermed)

        plot_data, plot_xrange = violinplot(
            vals, fillcolor="rgb{}".format(intermed_color), rugplot=rugplot
        )
        for item in plot_data:
            fig.append_trace(item, 1, k + 1)
        fig["layout"].update(
            {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
        )
    # add colorbar to plot
    trace_dummy = graph_objs.Scatter(
        x=[0],
        y=[0],
        mode="markers",
        marker=dict(
            size=2,
            cmin=min_value,
            cmax=max_value,
            colorscale=[[0, colors[0]], [1, colors[1]]],
            showscale=True,
        ),
        showlegend=False,
    )
    fig.append_trace(trace_dummy, 1, L)

    # set the sharey axis style
    fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
    fig["layout"].update(
        title=title,
        showlegend=False,
        hovermode="closest",
        autosize=False,
        height=height,
        width=width,
    )

    return fig
def violin_dict(
    data,
    data_header,
    group_header,
    colors,
    use_colorscale,
    group_stats,
    rugplot,
    sort,
    height,
    width,
    title,
):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot without colorscale.

    Here `colors` is a dictionary that maps every group name to the fill
    color of its violin.

    :raises: (PlotlyError) If a group of the group_header column is not a
        key of colors.
    """

    # collect all group names in order of first appearance
    group_name = list(dict.fromkeys(data[group_header]))

    if sort:
        group_name.sort()

    # check if all group names appear in colors dict
    for group in group_name:
        if group not in colors:
            raise exceptions.PlotlyError(
                "If colors is a dictionary, all "
                "the group names must appear as "
                "keys in colors."
            )

    # group on the bare column label (not a one-element list) so that
    # get_group() can be called with a plain group name
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(
        rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=False
    )

    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)
        plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], rugplot=rugplot)
        for item in plot_data:
            fig.append_trace(item, 1, k + 1)

        # add violin plot labels
        fig["layout"].update(
            {"xaxis{}".format(k + 1): make_XAxis(group_name[k], plot_xrange)}
        )

    # set the sharey axis style
    fig["layout"].update({"yaxis{}".format(1): make_YAxis("")})
    fig["layout"].update(
        title=title,
        showlegend=False,
        hovermode="closest",
        autosize=False,
        height=height,
        width=width,
    )

    return fig


def create_violin(
    data,
    data_header=None,
    group_header=None,
    colors=None,
    use_colorscale=False,
    group_stats=None,
    rugplot=True,
    sort=False,
    height=450,
    width=600,
    title="Violin and Rug Plot",
):
    """
    **deprecated**, use instead the plotly.graph_objects trace
    :class:`plotly.graph_objects.Violin`.

    :param (list|array) data: accepts either a list of numerical values
        or a pandas dataframe with at least one column of numbers.
        A dataframe is required when 'group_header' is given.
    :param (str) data_header: the header of the data column to be used
        from an inputted pandas dataframe. Not applicable if 'data' is
        a list of numeric values.
    :param (str) group_header: applicable if grouping data by a variable.
        'group_header' must be set to the name of the grouping variable.
    :param (str|tuple|list|dict) colors: either a plotly scale name,
        an rgb or hex color, a color tuple, a list of colors or a
        dictionary. An rgb color is of the form 'rgb(x, y, z)' where
        x, y and z belong to the interval [0, 255] and a color tuple is a
        tuple of the form (a, b, c) where a, b and c belong to [0, 1].
        If colors is a list, it must contain valid color types as its
        members.
    :param (bool) use_colorscale: only applicable if grouping by another
        variable. Will implement a colorscale based on the first 2 colors
        of param colors. This means colors must be a list with at least 2
        colors in it (Plotly colorscales are accepted since they map to a
        list of two rgb colors). Default = False
    :param (dict) group_stats: a dictionary where each key is a unique
        value from the group_header column in data. Each value must be a
        number and will be used to color the violin plots if a colorscale
        is being used.
    :param (bool) rugplot: determines if a rugplot is drawn on the violin
        plot. Default = True
    :param (bool) sort: determines if violins are sorted
        alphabetically (True) or by input order (False). Default = False
    :param (float) height: the height of the violin plot.
    :param (float) width: the width of the violin plot.
    :param (str) title: the title of the violin plot.

    :raises: (PlotlyError) If 'data' is an empty list or a list holding
        anything but numbers, if a dataframe is passed without
        'data_header', if 'group_header' is used with something that is
        not a pandas dataframe, or if the color arguments do not fit the
        colorscale mode (dictionary colors, fewer than 2 colors or a
        'group_stats' that is not a dictionary).

    Example 1: Single Violin Plot

    >>> from plotly.figure_factory import create_violin
    >>> import plotly.graph_objs as graph_objects

    >>> import numpy as np
    >>> from scipy import stats

    >>> # create list of random values
    >>> data_list = np.random.randn(100)

    >>> # create violin fig
    >>> fig = create_violin(data_list, colors='#604d9e')

    >>> # plot
    >>> fig.show()

    Example 2: Multiple Violin Plots with Qualitative Coloring

    >>> from plotly.figure_factory import create_violin
    >>> import plotly.graph_objs as graph_objects

    >>> import numpy as np
    >>> import pandas as pd
    >>> from scipy import stats

    >>> # create dataframe
    >>> np.random.seed(619517)
    >>> Nr=250
    >>> y = np.random.randn(Nr)
    >>> gr = np.random.choice(list("ABCDE"), Nr)
    >>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    >>> for i, letter in enumerate("ABCDE"):
    ...     y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    >>> df = pd.DataFrame(dict(Score=y, Group=gr))

    >>> # create violin fig
    >>> fig = create_violin(df, data_header='Score', group_header='Group',
    ...                    sort=True, height=600, width=1000)

    >>> # plot
    >>> fig.show()

    Example 3: Violin Plots with Colorscale

    >>> from plotly.figure_factory import create_violin
    >>> import plotly.graph_objs as graph_objects

    >>> import numpy as np
    >>> import pandas as pd
    >>> from scipy import stats

    >>> # create dataframe
    >>> np.random.seed(619517)
    >>> Nr=250
    >>> y = np.random.randn(Nr)
    >>> gr = np.random.choice(list("ABCDE"), Nr)
    >>> norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    >>> for i, letter in enumerate("ABCDE"):
    ...     y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    >>> df = pd.DataFrame(dict(Score=y, Group=gr))

    >>> # define header params
    >>> data_header = 'Score'
    >>> group_header = 'Group'

    >>> # make groupby object with pandas
    >>> group_stats = {}
    >>> groupby_data = df.groupby([group_header])

    >>> for group in "ABCDE":
    ...     data_from_group = groupby_data.get_group(group)[data_header]
    ...     # take a stat of the grouped data
    ...     stat = np.median(data_from_group)
    ...     # add to dictionary
    ...     group_stats[group] = stat

    >>> # create violin fig
    >>> fig = create_violin(df, data_header='Score', group_header='Group',
    ...                     height=600, width=1000, use_colorscale=True,
    ...                     group_stats=group_stats)

    >>> # plot
    >>> fig.show()
    """

    # Validate colors
    if isinstance(colors, dict):
        valid_colors = clrs.validate_colors_dict(colors, "rgb")
    else:
        valid_colors = clrs.validate_colors(colors, "rgb")

    # `pd` is None when pandas is not installed, so guard every use of it
    is_dataframe = pd is not None and isinstance(data, pd.core.frame.DataFrame)

    # validate data and choose plot type
    if group_header is None:
        if isinstance(data, list):
            if len(data) <= 0:
                raise exceptions.PlotlyError(
                    "If data is a list, it must be "
                    "nonempty and contain either "
                    "numbers or dictionaries."
                )

            if not all(isinstance(element, Number) for element in data):
                raise exceptions.PlotlyError(
                    "If data is a list, it must contain only numbers."
                )

        if is_dataframe:
            if data_header is None:
                raise exceptions.PlotlyError(
                    "data_header must be the "
                    "column name with the "
                    "desired numeric data for "
                    "the violin plot."
                )

            data = data[data_header].values.tolist()

        # call the plotting functions
        plot_data, plot_xrange = violinplot(
            data, fillcolor=valid_colors[0], rugplot=rugplot
        )

        layout = graph_objs.Layout(
            title=title,
            autosize=False,
            font=graph_objs.layout.Font(size=11),
            height=height,
            showlegend=False,
            width=width,
            xaxis=make_XAxis("", plot_xrange),
            yaxis=make_YAxis(""),
            hovermode="closest",
        )
        # a single violin is shown without y-axis line, ticks or labels
        layout["yaxis"].update(dict(showline=False, showticklabels=False, ticks=""))

        fig = graph_objs.Figure(data=plot_data, layout=layout)

        return fig

    # grouped violins: a dataframe and a data column are mandatory
    if not is_dataframe:
        raise exceptions.PlotlyError(
            "Error. You must use a pandas "
            "DataFrame if you are using a "
            "group header."
        )

    if data_header is None:
        raise exceptions.PlotlyError(
            "data_header must be the column "
            "name with the desired numeric "
            "data for the violin plot."
        )

    if use_colorscale is False:
        if isinstance(valid_colors, dict):
            # validate colors dict choice below
            fig = violin_dict(
                data,
                data_header,
                group_header,
                valid_colors,
                use_colorscale,
                group_stats,
                rugplot,
                sort,
                height,
                width,
                title,
            )
            return fig

        fig = violin_no_colorscale(
            data,
            data_header,
            group_header,
            valid_colors,
            use_colorscale,
            group_stats,
            rugplot,
            sort,
            height,
            width,
            title,
        )
        return fig

    if isinstance(valid_colors, dict):
        raise exceptions.PlotlyError(
            "The colors param cannot be "
            "a dictionary if you are "
            "using a colorscale."
        )

    if len(valid_colors) < 2:
        raise exceptions.PlotlyError(
            "colors must be a list with "
            "at least 2 colors. A "
            "Plotly scale is allowed."
        )

    if not isinstance(group_stats, dict):
        raise exceptions.PlotlyError(
            "Your group_stats param must be a dictionary."
        )

    fig = violin_colorscale(
        data,
        data_header,
        group_header,
        valid_colors,
        use_colorscale,
        group_stats,
        rugplot,
        sort,
        height,
        width,
        title,
    )
    return fig
def is_sequence(obj):
    """
    Tells whether obj is a Sequence other than a string.
    """
    return isinstance(obj, Sequence) and not isinstance(obj, str)


def validate_index(index_vals):
    """
    Validates if a list contains all numbers or all strings

    The type of the first item decides which check is applied; a list
    whose first item is neither a number nor a string is accepted as is,
    and so is an empty list.

    :raises: (PlotlyError) If there are any two items in the list whose
        types differ
    """
    from numbers import Number

    # nothing to compare in an empty column (and no first item to look at)
    if len(index_vals) == 0:
        return

    if isinstance(index_vals[0], Number):
        if not all(isinstance(item, Number) for item in index_vals):
            raise exceptions.PlotlyError(
                "Error in indexing column. "
                "Make sure all entries of each "
                "column are all numbers or "
                "all strings."
            )

    elif isinstance(index_vals[0], str):
        if not all(isinstance(item, str) for item in index_vals):
            raise exceptions.PlotlyError(
                "Error in indexing column. "
                "Make sure all entries of each "
                "column are all numbers or "
                "all strings."
            )


def validate_dataframe(array):
    """
    Validates all strings or numbers in each dataframe column

    Every vector of `array` is checked on its own, based on the type of
    its first item; empty vectors are skipped.

    :raises: (PlotlyError) If there are any two items in any list whose
        types differ
    """
    from numbers import Number

    for vector in array:
        # an empty column has no first item and nothing to validate
        if len(vector) == 0:
            continue
        if isinstance(vector[0], Number):
            if not all(isinstance(item, Number) for item in vector):
                raise exceptions.PlotlyError(
                    "Error in dataframe. "
                    "Make sure all entries of "
                    "each column are either "
                    "numbers or strings."
                )
        elif isinstance(vector[0], str):
            if not all(isinstance(item, str) for item in vector):
                raise exceptions.PlotlyError(
                    "Error in dataframe. "
                    "Make sure all entries of "
                    "each column are either "
                    "numbers or strings."
                )


def validate_equal_length(*args):
    """
    Validates that data lists or ndarrays are the same length.

    Calling it without arguments is accepted (there is nothing to compare).

    :raises: (PlotlyError) If any data lists are not the same length.
    """
    if not args:
        return
    length = len(args[0])
    if any(len(lst) != length for lst in args):
        raise exceptions.PlotlyError(
            "Oops! Your data lists or ndarrays should be the same length."
        )


def validate_positive_scalars(**kwargs):
    """
    Validates that all values given in key/val pairs are positive.

    Accepts kwargs to improve Exception messages.

    :raises: (ValueError) If any value is <= 0.
    :raises: (PlotlyError) If any value cannot be compared with 0.
    """
    for key, val in kwargs.items():
        try:
            if val <= 0:
                raise ValueError("{} must be > 0, got {}".format(key, val))
        except TypeError:
            raise exceptions.PlotlyError("{} must be a number, got {}".format(key, val))


def flatten(array):
    """
    Uses list comprehension to flatten array

    :param (array): An iterable to flatten
    :raises (PlotlyError): If iterable is not nested.
    :rtype (list): The flattened list.
    """
    try:
        return [item for sublist in array for item in sublist]
    except TypeError:
        raise exceptions.PlotlyError(
            "Your data array could not be "
            "flattened! Make sure your data is "
            "entered as lists or ndarrays!"
        )
def endpts_to_intervals(endpts):
    """
    Returns a list of intervals for categorical colormaps

    Accepts a list or tuple of sequentially increasing numbers and returns
    a list representation of the mathematical intervals with these numbers
    as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]]

    :raises: (PlotlyError) If input is not a list or tuple
    :raises: (PlotlyError) If the input is empty
    :raises: (PlotlyError) If the input contains a string
    :raises: (PlotlyError) If any number does not increase after the
        previous one in the sequence
    """
    message = (
        "The intervals_endpts argument must "
        "be a list or tuple of a sequence "
        "of increasing numbers."
    )
    # Check if endpts is a (non-empty) list or tuple; this has to come
    # first because everything below relies on len() and indexing
    if not isinstance(endpts, (tuple, list)) or len(endpts) == 0:
        raise exceptions.PlotlyError(message)
    # Check if endpts contains only numbers
    if any(isinstance(item, str) for item in endpts):
        raise exceptions.PlotlyError(message)
    # Check if numbers in endpts are increasing
    for low, high in zip(endpts, endpts[1:]):
        if low >= high:
            raise exceptions.PlotlyError(message)

    # (-inf, first], every pair of neighbours, [last, +inf)
    intervals = [[float("-inf"), endpts[0]]]
    intervals.extend([low, high] for low, high in zip(endpts, endpts[1:]))
    intervals.append([endpts[-1], float("inf")])
    return intervals


def annotation_dict_for_label(
    text,
    lane,
    num_of_lanes,
    subplot_spacing,
    row_col="col",
    flipped=True,
    right_side=True,
    text_color="#0f0f0f",
):
    """
    Returns annotation dict for label of n labels of a 1xn or nx1 subplot.

    :param (str) text: the text for a label.
    :param (int) lane: the label number for text. From 1 to n inclusive.
    :param (int) num_of_lanes: the number 'n' of rows or columns in subplot.
    :param (float) subplot_spacing: the value for the horizontal_spacing and
        vertical_spacing params in your plotly.tools.make_subplots() call.
    :param (str) row_col: choose whether labels are placed along rows or
        columns. Must be 'row' or 'col'.
    :param (bool) flipped: flips text by 90 degrees. Text is printed
        horizontally if set to True and row_col='row', or if False and
        row_col='col'.
    :param (bool) right_side: only applicable if row_col is set to 'row'
        and flipped is True.
    :param (str) text_color: color of the text.

    :raises: (PlotlyError) If row_col is neither 'row' nor 'col'.
    :rtype (dict): annotation positioned in paper coordinates.
    """
    if row_col not in ("col", "row"):
        # without this check the position variables below stay unbound
        raise exceptions.PlotlyError("row_col must be either 'col' or 'row'.")

    # width (or height) of one lane in paper coordinates
    temp = (1 - (num_of_lanes - 1) * subplot_spacing) / (num_of_lanes)
    # centre of the requested lane along the axis the lanes are laid out on
    lane_center = (lane - 1) * (temp + subplot_spacing) + 0.5 * temp

    if not flipped:
        xanchor = "center"
        yanchor = "middle"
        if row_col == "col":
            x = lane_center
            y = 1.03
            textangle = 0
        else:
            y = lane_center
            x = 1.03
            textangle = 90
    else:
        if row_col == "col":
            xanchor = "center"
            yanchor = "bottom"
            x = lane_center
            y = 1.0
            textangle = 270
        else:
            yanchor = "middle"
            y = lane_center
            if right_side:
                x = 1.0
                xanchor = "left"
            else:
                x = -0.01
                xanchor = "right"
            textangle = 0

    annotation_dict = dict(
        textangle=textangle,
        xanchor=xanchor,
        yanchor=yanchor,
        x=x,
        y=y,
        showarrow=False,
        xref="paper",
        yref="paper",
        text=text,
        font=dict(size=13, color=text_color),
    )
    return annotation_dict


def list_of_options(iterable, conj="and", period=True):
    """
    Returns an English listing of objects separated by commas ','

    For example, ['foo', 'bar', 'baz'] becomes 'foo, bar and baz.'
    if the conjunction 'and' is selected; pass period=False to leave out
    the trailing period.

    :param iterable: the items to list; any iterable with at least 2 items.
    :param (str) conj: the word placed between the last two items.
    :param (bool) period: whether a '.' is appended.
    :raises: (PlotlyError) If fewer than 2 items are given.
    """
    # materialise once so that generators and other unsized iterables work
    items = tuple(iterable)
    if len(items) < 2:
        raise exceptions.PlotlyError(
            "Your list or tuple must contain at least 2 items."
        )
    template = (len(items) - 2) * "{}, " + "{} " + conj + " {}" + period * "."
    return template.format(*items)
