Is there a way in Plotly to access colormap colours at any value along its range?
I know I can access the defining colours for a colourscale from
plotly.colors.PLOTLY_SCALES["Viridis"]
but I am unable to find how to access intermediate / interpolated values.
The equivalent in Matplotlib is shown in this question. There is also another question that addresses a similar problem for the colorlover library, but neither offers a nice solution.
Plotly does not appear to have such a method, so I wrote one:
import plotly.colors

def get_continuous_color(colorscale, intermed):
    """
    Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
    color for any value in that range.
    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:
        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]
    Others are just swatches that need to be constructed into a colorscale:
        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)
    :param colorscale: A plotly continuous colorscale defined with RGB string colors.
    :param intermed: value in the range [0, 1]
    :return: color in rgb string format
    :rtype: str
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")
    if intermed <= 0 or len(colorscale) == 1:
        return colorscale[0][1]
    if intermed >= 1:
        return colorscale[-1][1]
    for cutoff, color in colorscale:
        if intermed > cutoff:
            low_cutoff, low_color = cutoff, color
        else:
            high_cutoff, high_color = cutoff, color
            break
    # noinspection PyUnboundLocalVariable
    return plotly.colors.find_intermediate_color(
        lowcolor=low_color, highcolor=high_color,
        intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
        colortype="rgb")
The challenge is that the built-in Plotly colorscales are not consistently exposed. Some are defined as a colorscale already, others as just a list of color swatches that must be converted to a color scale first.
The Viridis colorscale is defined with hex values, which the Plotly color manipulation methods don't like, so it's easiest to construct it from swatches like this:
viridis_colors, _ = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
colorscale = plotly.colors.make_colorscale(viridis_colors)
get_continuous_color(colorscale, intermed=0.25)
# rgb(58.75, 80.75, 138.25)
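To make the interpolation step concrete, here is a minimal pure-Python sketch of the same piecewise-linear lookup, written without Plotly. The helper names parse_color and interpolate_colorscale are made up for this illustration and are not part of Plotly, and the sketch assumes colors are given either as '#rrggbb' hex strings or as 'rgb(r, g, b)' strings.

```python
def parse_color(color):
    """Parse '#rrggbb' or 'rgb(r, g, b)' into a tuple of three floats."""
    color = color.strip()
    if color.startswith("#"):
        return tuple(float(int(color[i:i + 2], 16)) for i in (1, 3, 5))
    if color.startswith("rgb(") and color.endswith(")"):
        return tuple(float(part) for part in color[4:-1].split(","))
    raise ValueError(f"unsupported color format: {color!r}")


def interpolate_colorscale(colorscale, value):
    """Linearly interpolate a [[position, color], ...] colorscale at value."""
    if not colorscale:
        raise ValueError("colorscale must have at least one color")
    # Clamp to the ends of the scale.
    if value <= colorscale[0][0]:
        return parse_color(colorscale[0][1])
    if value >= colorscale[-1][0]:
        return parse_color(colorscale[-1][1])
    # Find the pair of neighbouring stops that brackets the value.
    for (low_pos, low_col), (high_pos, high_col) in zip(colorscale, colorscale[1:]):
        if low_pos <= value <= high_pos:
            t = (value - low_pos) / (high_pos - low_pos)
            low, high = parse_color(low_col), parse_color(high_col)
            return tuple(a + t * (b - a) for a, b in zip(low, high))
    raise ValueError("colorscale positions must be sorted in ascending order")


# Hypothetical three-stop scale mixing both color formats.
scale = [[0.0, "#000000"], [0.5, "rgb(255, 0, 0)"], [1.0, "#ffffff"]]
print(interpolate_colorscale(scale, 0.25))  # (127.5, 0.0, 0.0)
```

Returning a tuple of floats rather than an 'rgb(...)' string is a choice made for this sketch; formatting the result back into a string is a one-liner if a Plotly-style color is needed.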
There is a built-in function, sample_colorscale in plotly.express.colors, that returns the sampled colors directly:
from plotly.express.colors import sample_colorscale
import plotly.graph_objects as go
import numpy as np
x = np.linspace(0, 1, 25)
y = np.ones_like(x)  # constant bar heights, so that only the color varies
c = sample_colorscale('jet', list(x))
fig = go.FigureWidget()
fig.add_trace(
    go.Bar(x=x, y=y, marker_color=c)
)
fig.show()
[Output figure: bars colored with the sampled colors]
This answer extends the already good one provided by Adam. In particular, it deals with the inconsistency of Plotly's color scales.
In Plotly, you specify a built-in color scale by writing colorscale="name_of_the_colorscale". This suggests that Plotly already has a built-in tool that converts the color scale name to an appropriate value and can deal with these inconsistencies. Searching Plotly's source code turns up the useful ColorscaleValidator class. Let's see how to use it:
def get_color(colorscale_name, loc):
    from _plotly_utils.basevalidators import ColorscaleValidator
    # first parameter: Name of the property being validated
    # second parameter: a string, doesn't really matter in our use case
    cv = ColorscaleValidator("colorscale", "")
    # colorscale will be a list of lists: [[loc1, "rgb1"], [loc2, "rgb2"], ...]
    colorscale = cv.validate_coerce(colorscale_name)
    if hasattr(loc, "__iter__"):
        return [get_continuous_color(colorscale, x) for x in loc]
    return get_continuous_color(colorscale, loc)

# Adam's function, extended to convert hex colors to rgb strings
import plotly.colors
from PIL import ImageColor

def get_continuous_color(colorscale, intermed):
    """
    Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
    color for any value in that range.
    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:
        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]
    Others are just swatches that need to be constructed into a colorscale:
        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)
    :param colorscale: A plotly continuous colorscale defined with RGB string colors.
    :param intermed: value in the range [0, 1]
    :return: color in rgb string format
    :rtype: str
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")
    hex_to_rgb = lambda c: "rgb" + str(ImageColor.getcolor(c, "RGB"))
    if intermed <= 0 or len(colorscale) == 1:
        c = colorscale[0][1]
        return c if c[0] != "#" else hex_to_rgb(c)
    if intermed >= 1:
        c = colorscale[-1][1]
        return c if c[0] != "#" else hex_to_rgb(c)
    for cutoff, color in colorscale:
        if intermed > cutoff:
            low_cutoff, low_color = cutoff, color
        else:
            high_cutoff, high_color = cutoff, color
            break
    if (low_color[0] == "#") or (high_color[0] == "#"):
        # some color scale names (such as cividis) returns:
        # [[loc1, "hex1"], [loc2, "hex2"], ...]
        low_color = hex_to_rgb(low_color)
        high_color = hex_to_rgb(high_color)
    return plotly.colors.find_intermediate_color(
        lowcolor=low_color,
        highcolor=high_color,
        intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
        colortype="rgb",
    )
At this point, all you have to do is:
get_color("phase", 0.5)
# 'rgb(123.99999999999999, 112.00000000000001, 236.0)'
import numpy as np
get_color("phase", np.linspace(0, 1, 256))
# ['rgb(167, 119, 12)',
# 'rgb(168.2941176470588, 118.0078431372549, 13.68235294117647)',
# ...
Edit: improvements to deal with special cases.
The official reference explains this. Here is an example:
import plotly.express as px
print(px.colors.sequential.Viridis)
['#440154', '#482878', '#3e4989', '#31688e', '#26828e', '#1f9e89', '#35b779', '#6ece58', '#b5de2b', '#fde725']
print(px.colors.sequential.Viridis[0])
#440154
import plotly.express as px
color_list = list(name_of_color_scale)
# name_of_color_scale could be any in-built colorscale like px.colors.qualitative.D3.
Output:
color_list =
['#1F77B4',
'#FF7F0E',
'#2CA02C',
'#D62728',
'#9467BD',
'#8C564B',
'#E377C2',
'#7F7F7F',
'#BCBD22',
'#17BECF']
Related
I'd like to use two colors, red and blue, but with different concentrations, like below.
I'd like to convert this continuous color scale into a discrete color scale with 10 discrete colors.
https://plotly.com/python/colorscales/#reversing-a-builtin-color-scale
If I print the continuous colorscale, the list has only 2 elements, as shown below. How can I get 10 discrete colors between red and blue with different concentrations? Thanks
colors=px.colors.sequential.Bluered_r
print(colors)
['rgb(255,0,0)', 'rgb(0,0,255)']
UPD
There is a simpler way, using sample_colorscale from this answer:
from plotly.express.colors import sample_colorscale
from sklearn.preprocessing import minmax_scale
colors_ = [1,5,6,7,8]
discrete_colors = sample_colorscale('Bluered', minmax_scale(colors_))
# colors_ - the 5 numbers you are trying to depict with the colorscale
# discrete_colors - list of 5 rgb-coded colors from the Bluered colorscale
# minmax_scale is used because sample_colorscale can only deal with floats in [0, 1]
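Since the sampling function only needs positions in [0, 1], the rescaling is the only preprocessing involved. If you would rather not pull in scikit-learn just for that, a stand-in can be sketched in plain Python. The name rescale_to_unit is hypothetical and not part of any library, and mapping a constant input to all zeros is a choice made for this sketch.

```python
def rescale_to_unit(values):
    """Linearly map values onto [0, 1]; constant input maps to all zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


positions = rescale_to_unit([1, 5, 6, 7, 8])
print(positions[0], positions[-1])  # 0.0 1.0
```

The resulting list of floats can then be used as the sample points in place of the minmax_scale(colors_) call.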
Old answer
As far as I know, plotly doesn't have an explicit function for that.
For a red-blue scale a simple np.linspace-based implementation should work.
import numpy as np
def n_discrete_rgb_colors(color1: str, color2: str, n_colors: int) -> list:
    color1_ = [int(i) for i in color1[4:-1].split(",")]
    color2_ = [int(i) for i in color2[4:-1].split(",")]
    colors_ = np.linspace(start = color1_, stop = color2_, num = n_colors)
    colors = [str(f"rgb{int(i[0]),int(i[1]),int(i[2])}") for i in colors_]
    return colors
color1, color2 = ['rgb(255,0,0)', 'rgb(0,0,255)']
# color1, color2 = px.colors.sequential.Bluered_r
n_discrete_rgb_colors(color1, color2, 10)
Output
['rgb(255, 0, 0)',
'rgb(226, 0, 28)',
'rgb(198, 0, 56)',
'rgb(170, 0, 85)',
'rgb(141, 0, 113)',
'rgb(113, 0, 141)',
'rgb(85, 0, 170)',
'rgb(56, 0, 198)',
'rgb(28, 0, 226)',
'rgb(0, 0, 255)']
Logic of the code is the following:
take two strings for edge colors in 'rgb(x,y,z)' format,
convert them into [x,y,z] lists,
build a linspace,
return this linspace with appropriate formatting
What you are really looking for is described at https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale
The code below shows how to use a list comprehension to construct a discontinuous color scale from a continuous one.
import pandas as pd
import numpy as np
import plotly.express as px
df = pd.DataFrame({c:np.linspace(1,10,100) for c in list("xyc")})
# https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale
cmap = [
(r, c)
for r, c in zip(
np.repeat(np.linspace(0, 1, len(px.colors.sequential.RdBu)+1), 2)[1:],
np.repeat(px.colors.sequential.RdBu,2),
)
]
px.bar(df, x="x", y="y", color="c", color_continuous_scale=cmap)
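The np.repeat / zip construction above is compact but hard to read. As a sketch of what it produces, the same list of (position, color) pairs can be built with an explicit loop in plain Python: each color gets a flat band from i/n to (i+1)/n, so neighbouring colors meet at a hard edge instead of blending. The function name discrete_colorscale is made up for this illustration.

```python
def discrete_colorscale(colors):
    """Turn a list of n colors into a stepped scale of 2n (position, color) pairs."""
    n = len(colors)
    scale = []
    for i, color in enumerate(colors):
        # The same color at both ends of its band gives a flat step.
        scale.append((i / n, color))
        scale.append(((i + 1) / n, color))
    return scale


print(discrete_colorscale(["rgb(255,0,0)", "rgb(0,0,255)"]))
# [(0.0, 'rgb(255,0,0)'), (0.5, 'rgb(255,0,0)'), (0.5, 'rgb(0,0,255)'), (1.0, 'rgb(0,0,255)')]
```

The returned list has the same shape as cmap above, so it could be passed as color_continuous_scale in the same way.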
I want to plot a dendrogram for hierarchical clustering using plotly and show only a small subset of the plot, because with a large number of samples the plot can be very dense at the bottom.
I have created the plot using the plotly wrapper function create_dendrogram with the code below:
from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff
fig = ff.create_dendrogram(test_df, linkagefun=lambda x: linkage(test_df, 'average', metric='euclidean'))
fig.update_layout(autosize=True, hovermode='closest')
fig.update_xaxes(mirror=False, showgrid=True, showline=False, showticklabels=False)
fig.update_yaxes(mirror=False, showgrid=True, showline=True)
fig.show()
And below is the plot using matplotlib, which scipy uses by default, truncated to 4 levels for ease of interpretation:
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
x = linkage(test_df, method='average')
dendrogram(x, truncate_mode='level', p=4)
plt.show()
As you can see, truncation is very useful for interpreting a large number of samples. How can I achieve this in plotly?
There does not seem to be a straightforward way to do this with ff.create_dendrogram(). That does not mean it's impossible, though. But I would at least consider the brilliant functionality that Dash Clustergram has to offer. If you insist on sticking to ff.create_dendrogram(), this is going to get a bit messier than Plotly users have rightfully grown accustomed to. You haven't provided a data sample, so let's use the Plotly Basic Dendrogram example instead:
Plot 1
Code 1
import plotly.figure_factory as ff
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
f = fig.full_figure_for_development(warn=False)
fig.show()
The good news is that the exact same snippet will produce the following truncated plot after we've taken a few steps that I'll explain in the details below.
Plot 2
The details
If anyone who got this far in my answer knows a better way to do the following, then please share.
1. ff.create_dendrogram() is a wrapper for scipy.cluster.hierarchy.dendrogram
You can call help(ff.create_dendrogram) and learn that:
[...]This is a thin wrapper around scipy.cluster.hierarchy.dendrogram.
From the available arguments you can also see that none seem to handle anything related to truncating:
create_dendrogram(X, orientation='bottom', labels=None,
colorscale=None, distfun=None, linkagefun=<function at
0x0000016F09D4CEE0>, hovertext=None, color_threshold=None)
2. Take a closer look at scipy.cluster.hierarchy.dendrogram
Comparing with the signature of the scipy function, we can see that ff.create_dendrogram() leaves out some central arguments:
scipy.cluster.hierarchy.dendrogram(Z, p=30, truncate_mode=None, color_threshold=None, get_leaves=True, orientation='top', labels=None, count_sort=False, distance_sort=False, show_leaf_counts=True, no_plot=False, no_labels=False, leaf_font_size=None, leaf_rotation=None, leaf_label_func=None, show_contracted=False, link_color_func=None, ax=None, above_threshold_color='C0')
truncate_mode should be exactly what we're looking for. So, now we know that scipy probably has all we need to build the foundation for a truncated dendrogram, but what's next?
3. Find where scipy.cluster.hierarchy.dendrogram is hiding in ff.create_dendrogram(X)
ff.create_dendrogram.__code__ will reveal where the source code exists in your system. In my case this is:
"C:\Users\vestland\Miniconda3\envs\dashy\lib\site-packages\plotly\figure_factory\_dendrogram.py"
So if you would like you can take a closer look at the complete source in your corresponding folder. If you do, you'll see one particularly interesting section where some attributes that we have listed above are taken care of:
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
"""
Calculates all the elements needed for plotting a dendrogram.
.
.
.
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
)
Here we are at the very core of the problem. The first step to a complete answer to your question is simply to add truncate_mode and p to the sch.dendrogram call, like this:
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
truncate_mode = 'level',
p = 2
)
And here's how you do that:
4. Monkey patching
In Python, the term monkey patch refers to dynamic modification of a class or module at runtime: a piece of code that extends or modifies other code while the program is running. Here is the essence of how to do exactly that in our case:
import plotly.figure_factory._dendrogram as original_dendrogram
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces
Here, modified_dendrogram_traces is a complete copy of the original get_dendrogram_traces() with the amendments I've already mentioned, plus a few imports that are normally run when you call import plotly.figure_factory as ff and would otherwise be missing inside the function.
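The plain assignment above replaces the method for the rest of the Python session. If the patch should only apply to a few calls, one option is to wrap it in a context manager that restores the original afterwards. The following is a minimal sketch using only the standard library; Plotter, truncated_traces and patched_attribute are stand-in names invented for this example, not Plotly objects.

```python
from contextlib import contextmanager


@contextmanager
def patched_attribute(owner, name, replacement):
    """Temporarily replace owner.name with replacement, then restore it."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        # Always put the original back, even if the body raised.
        setattr(owner, name, original)


class Plotter:
    """Stand-in for a library class whose method we want to change."""

    def traces(self, data):
        return list(data)


def truncated_traces(self, data, limit=2):
    # Replacement method: same call signature, but keeps only `limit` items.
    return list(data)[:limit]


with patched_attribute(Plotter, "traces", truncated_traces):
    inside = Plotter().traces([1, 2, 3, 4])   # uses the patched method
outside = Plotter().traces([1, 2, 3, 4])      # original behaviour is back
print(inside, outside)  # [1, 2] [1, 2, 3, 4]
```

The same pattern would apply to a class and method name of your choice, with the figure created inside the with block.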
Enough details for now. Below is the whole thing. If this is something you can use, we could perhaps make it a bit more dynamic than hardcoding truncate_mode = 'level' and p = 2.
Complete code:
from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff
import plotly.figure_factory._dendrogram as original_dendrogram
import numpy as np
def modified_dendrogram_traces(
    self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
    """
    Calculates all the elements needed for plotting a dendrogram.
    :param (ndarray) X: Matrix of observations as array of arrays
    :param (list) colorscale: Color scale for dendrogram tree clusters
    :param (function) distfun: Function to compute the pairwise distance
                               from the observations
    :param (function) linkagefun: Function to compute the linkage matrix
                                  from the pairwise distances
    :param (list) hovertext: List of hovertext for constituent traces of dendrogram
    :rtype (tuple): Contains all the traces in the following order:
        (a) trace_list: List of Plotly trace objects for dendrogram tree
        (b) icoord: All X points of the dendrogram tree as array of arrays
            with length 4
        (c) dcoord: All Y points of the dendrogram tree as array of arrays
            with length 4
        (d) ordered_labels: leaf labels in the order they are going to
            appear on the plot
        (e) P['leaves']: left-to-right traversal of the leaves
    """
    # Imports that _dendrogram.py normally performs at module level
    from plotly import optional_imports
    np = optional_imports.get_module("numpy")
    sch = optional_imports.get_module("scipy.cluster.hierarchy")
    d = distfun(X)
    Z = linkagefun(d)
    P = sch.dendrogram(
        Z,
        orientation=self.orientation,
        labels=self.labels,
        no_plot=True,
        color_threshold=color_threshold,
        truncate_mode="level",  # the two arguments added to the original call
        p=2,
    )
    icoord = np.array(P["icoord"])
    dcoord = np.array(P["dcoord"])
    ordered_labels = np.array(P["ivl"])
    color_list = np.array(P["color_list"])
    colors = self.get_color_dict(colorscale)
    trace_list = []
    for i in range(len(icoord)):
        # xs and ys are arrays of 4 points that make up the '∩' shapes
        # of the dendrogram tree
        if self.orientation in ["top", "bottom"]:
            xs = icoord[i]
        else:
            xs = dcoord[i]
        if self.orientation in ["top", "bottom"]:
            ys = dcoord[i]
        else:
            ys = icoord[i]
        color_key = color_list[i]
        hovertext_label = None
        if hovertext:
            hovertext_label = hovertext[i]
        trace = dict(
            type="scatter",
            x=np.multiply(self.sign[self.xaxis], xs),
            y=np.multiply(self.sign[self.yaxis], ys),
            mode="lines",
            marker=dict(color=colors[color_key]),
            text=hovertext_label,
            hoverinfo="text",
        )
        try:
            x_index = int(self.xaxis[-1])
        except ValueError:
            x_index = ""
        try:
            y_index = int(self.yaxis[-1])
        except ValueError:
            y_index = ""
        # f-strings work whether the index is an int or an empty string
        trace["xaxis"] = f"x{x_index}"
        trace["yaxis"] = f"y{y_index}"
        trace_list.append(trace)
    return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
f = fig.full_figure_for_development(warn=False)
fig.show()
To make it more dynamic, you can pass **kwargs to the create_dendrogram() function. If you check the source code, you will see that kwargs has to be passed on in several other places as well, in both the _Dendrogram class and the get_dendrogram_traces() function.
If you don't want to modify _dendrogram.py in the installed package directory, I advise you to copy the whole file into a new file (say, modified_dendogram.py) in your current directory.
Then simply import that local file using from modified_dendogram import create_dendrogram.
Now you can use all the arguments that scipy.cluster.hierarchy.dendrogram supports.
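Stripped of the Plotly specifics, the forwarding pattern looks like the sketch below: the public function collects extra keyword arguments, hands them to the helper class as an ordinary dict, and the helper unpacks them into the innermost call. All names here (inner, Wrapper, create) are hypothetical stand-ins, not the real Plotly or scipy functions.

```python
def inner(values, reverse=False, limit=None):
    """Stand-in for the wrapped library call that accepts the extra options."""
    result = sorted(values, reverse=reverse)
    return result if limit is None else result[:limit]


class Wrapper:
    """Stand-in for a helper class that stores options and calls inner()."""

    def __init__(self, values, kwargs=None):
        # Guard against None so the ** unpacking below never fails.
        self.result = inner(values, **(kwargs or {}))


def create(values, **kwargs):
    """Public entry point: collects extra options and passes them down."""
    return Wrapper(values, kwargs=kwargs).result


print(create([3, 1, 2]))                         # [1, 2, 3]
print(create([3, 1, 2], reverse=True, limit=2))  # [3, 2]
```

Passing the options down as a single dict keeps the intermediate signatures stable: only the outermost function and the innermost call need to know which options exist.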
modified_dendogram.py:
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from collections import OrderedDict
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
scp = optional_imports.get_module("scipy")
sch = optional_imports.get_module("scipy.cluster.hierarchy")
scs = optional_imports.get_module("scipy.spatial")
def create_dendrogram(
X,
orientation="bottom",
labels=None,
colorscale=None,
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
**kwargs
):
"""
Function that returns a dendrogram Plotly figure object. This is a thin
wrapper around scipy.cluster.hierarchy.dendrogram.
See also https://dash.plot.ly/dash-bio/clustergram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
:param (list) labels: List of axis category labels(observation labels)
:param (list) colorscale: Optional colorscale for the dendrogram tree.
Requires 8 colors to be specified, the 7th of
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
and 6th are used twice as often as the others.
Given a shorter list, the missing values are
replaced with defaults and with a longer list the
extra values are ignored.
:param (function) distfun: Function to compute the pairwise distance from
the observations
:param (function) linkagefun: Function to compute the linkage matrix from
the pairwise distances
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
clusters
:param (double) color_threshold: Value at which the separation of clusters will be made
Example 1: Simple bottom oriented dendrogram
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(10,10)
>>> fig = create_dendrogram(X)
>>> fig.show()
Example 2: Dendrogram to put on the left of the heatmap
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(5,5)
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
>>> dendro.show()
Example 3: Dendrogram with Pandas
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> import pandas as pd
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
>>> fig = create_dendrogram(df, labels=Index)
>>> fig.show()
"""
if not scp or not scs or not sch:
raise ImportError(
"FigureFactory.create_dendrogram requires scipy, \
scipy.spatial and scipy.hierarchy"
)
s = X.shape
if len(s) != 2:
raise exceptions.PlotlyError("X should be 2-dimensional array.")
if distfun is None:
distfun = scs.distance.pdist
dendrogram = _Dendrogram(
X,
orientation,
labels,
colorscale,
distfun=distfun,
linkagefun=linkagefun,
hovertext=hovertext,
color_threshold=color_threshold,
kwargs=kwargs
)
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
class _Dendrogram(object):
"""Refer to FigureFactory.create_dendrogram() for docstring."""
def __init__(
self,
X,
orientation="bottom",
labels=None,
colorscale=None,
width=np.inf,
height=np.inf,
xaxis="xaxis",
yaxis="yaxis",
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
kwargs=None
):
self.orientation = orientation
self.labels = labels
self.xaxis = xaxis
self.yaxis = yaxis
self.data = []
self.leaves = []
self.sign = {self.xaxis: 1, self.yaxis: 1}
self.layout = {self.xaxis: {}, self.yaxis: {}}
if self.orientation in ["left", "bottom"]:
self.sign[self.xaxis] = 1
else:
self.sign[self.xaxis] = -1
if self.orientation in ["right", "bottom"]:
self.sign[self.yaxis] = 1
else:
self.sign[self.yaxis] = -1
if distfun is None:
distfun = scs.distance.pdist
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
X, colorscale, distfun, linkagefun, hovertext, color_threshold, kwargs
)
self.labels = ordered_labels
self.leaves = leaves
yvals_flat = yvals.flatten()
xvals_flat = xvals.flatten()
self.zero_vals = []
for i in range(len(yvals_flat)):
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
self.zero_vals.append(xvals_flat[i])
if len(self.zero_vals) > len(yvals) + 1:
# If the length of zero_vals is larger than the length of yvals,
# it means that there are wrong vals because of the identicial samples.
# Three and more identicial samples will make the yvals of spliting
# center into 0 and it will accidentally take it as leaves.
l_border = int(min(self.zero_vals))
r_border = int(max(self.zero_vals))
correct_leaves_pos = range(
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
)
# Regenerating the leaves pos from the self.zero_vals with equally intervals.
self.zero_vals = [v for v in correct_leaves_pos]
self.zero_vals.sort()
self.layout = self.set_figure_layout(width, height)
self.data = dd_traces
def get_color_dict(self, colorscale):
"""
Returns colorscale used for dendrogram tree clusters.
:param (list) colorscale: Colors to use for the plot in rgb format.
:rtype (dict): A dict of default colors mapped to the user colorscale.
"""
# These are the color codes returned for dendrograms
# We're replacing them with nicer colors
# This list is the colors that can be used by dendrogram, which were
# determined as the combination of the default above_threshold_color and
# the default color palette (see scipy/cluster/hierarchy.py)
d = {
"r": "red",
"g": "green",
"b": "blue",
"c": "cyan",
"m": "magenta",
"y": "yellow",
"k": "black",
# TODO: 'w' doesn't seem to be in the default color
# palette in scipy/cluster/hierarchy.py
"w": "white",
}
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
if colorscale is None:
rgb_colorscale = [
"rgb(0,116,217)", # blue
"rgb(35,205,205)", # cyan
"rgb(61,153,112)", # green
"rgb(40,35,35)", # black
"rgb(133,20,75)", # magenta
"rgb(255,65,54)", # red
"rgb(255,255,255)", # white
"rgb(255,220,0)", # yellow
]
else:
rgb_colorscale = colorscale
for i in range(len(default_colors.keys())):
k = list(default_colors.keys())[i] # PY3 won't index keys
if i < len(rgb_colorscale):
default_colors[k] = rgb_colorscale[i]
# add support for cyclic format colors as introduced in scipy===1.5.0
# before this, the colors were named 'r', 'b', 'y' etc., now they are
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
# scipy version, we try as much as possible to map the new colors to the
# old colors
# this mapping was found by inpecting scipy/cluster/hierarchy.py (see
# comment above).
new_old_color_map = [
("C0", "b"),
("C1", "g"),
("C2", "r"),
("C3", "c"),
("C4", "m"),
("C5", "y"),
("C6", "k"),
("C7", "g"),
("C8", "r"),
("C9", "c"),
]
for nc, oc in new_old_color_map:
try:
default_colors[nc] = default_colors[oc]
except KeyError:
# it could happen that the old color isn't found (if a custom
# colorscale was specified), in this case we set it to an
# arbitrary default.
default_colors[nc] = "rgb(0,116,217)"
return default_colors
def set_axis_layout(self, axis_key):
"""
Sets and returns default axis object for dendrogram figure.
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
:rtype (dict): An axis_key dictionary with set parameters.
"""
axis_defaults = {
"type": "linear",
"ticks": "outside",
"mirror": "allticks",
"rangemode": "tozero",
"showticklabels": True,
"zeroline": False,
"showgrid": False,
"showline": True,
}
if len(self.labels) != 0:
axis_key_labels = self.xaxis
if self.orientation in ["left", "right"]:
axis_key_labels = self.yaxis
if axis_key_labels not in self.layout:
self.layout[axis_key_labels] = {}
self.layout[axis_key_labels]["tickvals"] = [
zv * self.sign[axis_key] for zv in self.zero_vals
]
self.layout[axis_key_labels]["ticktext"] = self.labels
self.layout[axis_key_labels]["tickmode"] = "array"
self.layout[axis_key].update(axis_defaults)
return self.layout[axis_key]
def set_figure_layout(self, width, height):
"""
Sets and returns default layout object for dendrogram figure.
"""
self.layout.update(
{
"showlegend": False,
"autosize": False,
"hovermode": "closest",
"width": width,
"height": height,
}
)
self.set_axis_layout(self.xaxis)
self.set_axis_layout(self.yaxis)
return self.layout
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold, kwargs={}
):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
kwargs = kwargs or {}  # tolerate kwargs=None when _Dendrogram is used directly
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
**kwargs
)
icoord = np.array(P["icoord"])
dcoord = np.array(P["dcoord"])
ordered_labels = np.array(P["ivl"])
color_list = np.array(P["color_list"])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ["top", "bottom"]:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ["top", "bottom"]:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
hovertext_label = None
if hovertext:
hovertext_label = hovertext[i]
trace = dict(
type="scatter",
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode="lines",
marker=dict(color=colors[color_key]),
text=hovertext_label,
hoverinfo="text",
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ""
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ""
trace["xaxis"] = f"x{x_index}"
trace["yaxis"] = f"y{y_index}"
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
Example:
from modified_dendogram import create_dendrogram
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = create_dendrogram(X)
fig.update_layout(width=800, height=500)
fig.show()
from modified_dendogram import create_dendrogram
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = create_dendrogram(X, truncate_mode="level", p=1)
fig.update_layout(width=800, height=500)
fig.show()
EDIT: I figured out that the problem always occurs if one tries to plot to two different lists of figures. Does that mean that one cannot plot to different figure lists in the same loop? See the latest code for a much simpler sample of the problem.
I am trying to analyze a complex set of data which basically consists of measurements of electric devices under different conditions. Hence, the code is a bit more complex, but I tried to strip it down to a working example; it is still pretty long, however. So let me explain what you see: there are 3 classes, with Transistor representing an electronic device. Its attribute Y represents the measurement data, consisting of 2 sets of measurements. Each Transistor belongs to a group (2 in this example). And some groups belong to the same series (in this example, one series that includes both groups).
The aim is now to plot all measurement data for each Transistor (not shown), then also to plot all data belonging to the same group in one plot each, and all data of the same series in one plot. In order to program this efficiently without a lot of loops, my idea was to use the object-oriented nature of matplotlib: I will have figures and subplots for each level of plotting (initialized in initGrpPlt and initSeriesPlt), which are then filled with only one loop over all Transistors (in MainPlt: toGPlt and toSPlt). In the end, the figures only need to be printed / saved to a file / whatever (PltGrp and PltSeries).
The problem: even though I specify where to plot, Python plots the series plots into the group plots. You can check this yourself by running the code with and without the line 'toSPlt(trans,j)'. I have no clue why Python does this, because in the function toSPlt I explicitly say that Python should use the subplots from the series subplot list. Would anyone have an idea why this happens and how to solve the problem in an elegant way?
Read the code from the bottom to the top, that should help with understanding.
Kind regards
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

maxNrVdrain = 2
X = np.linspace(-np.pi, np.pi, 256,endpoint=True)
A = [[1*np.cos(X),2*np.cos(X),3*np.cos(X),4*np.cos(X)],[1*np.tan(X),2*np.tan(X),3*np.tan(X),4*np.tan(X)]]
B = [[2* np.sin(X),4* np.sin(X),6* np.sin(X),8* np.sin(X)],[2*np.cos(X),4*np.cos(X),6*np.cos(X),8*np.cos(X)]]

class Transistor(object):
    _TransRegistry = []

    def __init__(self,y1,y2):
        self._TransRegistry.append(self)
        self.X = X
        self.Y = [y1,y2]
        self.group = ''

class Groups():
    _GroupRegistry = []

    def __init__(self,trans):
        self._GroupRegistry.append(self)
        self.transistors = [trans]
        self.figlist = []
        self.axlist = []

class Series():
    _SeriesRegistry = []

    def __init__(self,group):
        self._SeriesRegistry.append(self)
        self.groups = [group]
        self.figlist = []
        self.axlist = []

def initGrpPlt():
    for group in Groups._GroupRegistry:
        for j in range(maxNrVdrain):
            group.figlist.append(plt.figure(j))
            group.axlist.append(group.figlist[j].add_subplot(111))
    return

def initSeriesPlt():
    for series in Series._SeriesRegistry:
        for j in range(maxNrVdrain):
            series.figlist.append(plt.figure(j))
            series.axlist.append(series.figlist[j].add_subplot(111))
    return

def toGPlt(trans,j):
    colour = cm.rainbow(np.linspace(0, 1, 4))
    group = trans.group
    group.axlist[j].plot(trans.X,trans.Y[j], color=colour[group.transistors.index(trans)], linewidth=1.5, linestyle="-")
    return

def toSPlt(trans,j):
    colour = cm.rainbow(np.linspace(0, 1, 2))
    series = Series._SeriesRegistry[0]
    group = trans.group
    if group.transistors.index(trans) == 0:
        series.axlist[j].plot(trans.X,trans.Y[j],color=colour[series.groups.index(group)], linewidth=1.5, linestyle="-", label = 'T = nan, RH = nan' )
    else:
        series.axlist[j].plot(trans.X,trans.Y[j],color=colour[series.groups.index(group)], linewidth=1.5, linestyle="-")
    return

def PltGrp(group,j):
    ax = group.axlist[j]
    ax.set_title('Test Grp')
    return

def PltSeries(series,j):
    ax = series.axlist[j]
    ax.legend(loc='upper right', frameon=False)
    ax.set_title('Test Series')
    return

def MainPlt():
    initGrpPlt()
    initSeriesPlt()
    for trans in Transistor._TransRegistry:
        for j in range(maxNrVdrain):
            toGPlt(trans,j)
            toSPlt(trans,j)#plots to group plot for some reason
    for j in range(maxNrVdrain):
        for group in Groups._GroupRegistry:
            PltGrp(group,j)
    plt.show()
    return

def Init():
    for j in range(4):
        trans = Transistor(A[0][j],A[1][j])
        if j == 0:
            Groups(trans)
        else:
            Groups._GroupRegistry[0].transistors.append(trans)
        trans.group = Groups._GroupRegistry[0]
    Series(Groups._GroupRegistry[0])
    for j in range(4):
        trans = Transistor(B[0][j],B[1][j])
        if j == 0:
            Groups(trans)
        else:
            Groups._GroupRegistry[1].transistors.append(trans)
        trans.group = Groups._GroupRegistry[1]
    Series._SeriesRegistry[0].groups.append(Groups._GroupRegistry[1])
    return

def main():
    Init()
    MainPlt()
    return

main()
Latest example that does not work:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

X = np.linspace(-np.pi, np.pi, 256,endpoint=True)
Y1 = np.cos(X)
Y2 = np.sin(X)
figlist1 = []
figlist2 = []
axlist1 = []
axlist2 = []
for j in range(4):
    figlist1.append(plt.figure(j))
    axlist1.append(figlist1[j].add_subplot(111))
    figlist2.append(plt.figure(j))#this should be a new set of figures!
    axlist2.append(figlist2[j].add_subplot(111))
    colour = cm.rainbow(np.linspace(0, 1, 4))
    axlist1[j].plot(X,j*Y1, color=colour[j], linewidth=1.5, linestyle="-")
    axlist1[j].set_title('Test Grp 1')
    colour = cm.rainbow(np.linspace(0, 1, 4))
    axlist2[j].plot(X,j*Y2, color=colour[int(j/2)], linewidth=1.5, linestyle="-")
    axlist2[j].set_title('Test Grp 2')
plt.show()
OK, a stupid mistake once you think about what happens in the background, but maybe someone has a similar problem and is unable to see the cause, as I was at first. So here is the solution:
The problem is that the names of the list entries, like figlist1[j], do not define the figure; they are just references to the actual figure object. plt.figure(j) only creates a new figure if no figure with the number j exists yet; otherwise it returns the existing one. Hence, in a loop where multiple figures are to be initialized, one needs to make sure that j is different for each figure (or call plt.figure() without a number), otherwise the second list ends up pointing at the same figures as the first. Hope that helps! Cheers.
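A minimal sketch of the fix, assuming it is acceptable to let matplotlib pick the figure numbers itself. The variable names reused, distinct and the Agg backend line are only there so the sketch runs without a display; they are not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (assumption: no display needed for this sketch)
import matplotlib.pyplot as plt
import numpy as np

X = np.linspace(-np.pi, np.pi, 256)

# Asking for the same figure number twice hands back the same Figure object.
same_a = plt.figure(0)
same_b = plt.figure(0)
reused = same_a is same_b
plt.close(same_a)

# Calling plt.figure() without a number creates a fresh figure each time,
# so the two lists really hold two separate sets of figures.
figlist1, axlist1, figlist2, axlist2 = [], [], [], []
for j in range(4):
    fig1 = plt.figure()
    figlist1.append(fig1)
    axlist1.append(fig1.add_subplot(111))
    fig2 = plt.figure()
    figlist2.append(fig2)
    axlist2.append(fig2.add_subplot(111))
    axlist1[j].plot(X, j * np.cos(X))
    axlist1[j].set_title('Test Grp 1')
    axlist2[j].plot(X, j * np.sin(X))
    axlist2[j].set_title('Test Grp 2')

distinct = all(f1 is not f2 for f1, f2 in zip(figlist1, figlist2))
# With an interactive backend, plt.show() would now open 8 separate windows.
```

An alternative that keeps explicit numbers is to offset the second set, e.g. plt.figure(j + 4) in the example above, as long as the offset is at least the number of figures in the first set.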
How can a figure using a rainbow colormap, such as figure 1, be converted so that the same data are displayed using a different color map, such as a perceptually uniform sequential map?
Assume that the underlying data from which the original image was generated are not accessible and the image itself must be recolored using only information within the image.
Background information: rainbow color maps tend to produce visual artifacts. See the cyan line near z = -1.15 m? It looks like there's a sharp edge there. But look at the colorbar itself! Even the color bar has an edge there. There's another fake edge in the yellow band that goes vertically near R = 1.45 m. The horizontal yellow stripe may be a real edge in the underlying data, although it's difficult to distinguish that case from a rainbow artifact.
More information:
http://ieeexplore.ieee.org/abstract/document/4118486/
http://matplotlib.org/users/colormaps.html
Here is my best solution so far:
import numpy as np
import os
import matplotlib
import copy
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread, imsave

def_colorbar_loc = [[909, 22], [953 - 20, 959]]
def_working_loc = [[95, 189], [857, 708]]


def recolor_image(
    filename='image.png',
    colorbar_loc=def_colorbar_loc,
    working_loc=def_working_loc,
    colorbar_orientation='auto',
    colorbar_direction=-1,
    new_cmap='viridis',
    normalize_before_compare=False,
    max_rgb='auto',
    threshold=0.4,
    saturation_threshold=0.25,
    compare_hue=True,
    show_plot=True,
    debug=False,
):
    """
    This script reads in an image file (like .png), reads the image's color bar (you have to tell it where), interprets
    the color map used in the image to convert colors to values, then recolors those values with a new color map and
    regenerates the figure. Useful for fixing figures that were made with rainbow color maps.

    Parameters
    -----------
    :param filename: Full path and filename of the image file.
    :param colorbar_loc: Location of color bar, which will be used to analyze the image and convert colors into values.
        Pixels w/ 0,0 at top left corner: [[left, top], [right, bottom]]
    :param working_loc: Location of the area to recolor. You don't have to recolor the whole image.
        Pixels w/ 0,0 at top left corner: [[left, top], [right, bottom]], set to [[0, 0], [-1, -1]] to do everything.
    :param colorbar_orientation: Set to 'x', 'y', or 'auto' to specify whether color map is horizontal, vertical,
        or should be determined based on the dimensions of the colorbar_loc
    :param colorbar_direction: Controls direction of ascending value
        +1: colorbar goes from top to bottom or left to right.
        -1: colorbar goes from bottom to top or right to left.
    :param new_cmap: String describing the new color map to use in the recolored image.
    :param normalize_before_compare: Divide r, g, and b each by (r+g+b) before comparing.
    :param max_rgb: Do the values of r, g, and b range from 0 to 1 or from 0 to 255? Set to 1, 255, or 'auto'.
    :param threshold: Sum of absolute differences in r, g, b values must be less than threshold to be valid
        (0 = perfect, 3 = impossibly bad). Higher numbers = less chance of missing pixels but more chance of recoloring
        plot axes, etc.
    :param saturation_threshold: Minimum color saturation below which no replacement will take place
    :param compare_hue: Use differences in HSV instead of RGB to determine with which index each pixel should be
        associated.
    :param show_plot: T/F: Open a plot to explain what is going on. Also helpful for checking your aim on the colorbar
        coordinates and debugging.
    :param debug: T/F: Print debugging information.
    """

    def printd(string_in):
        """
        Prints debugging statements
        :param string_in: String to print only if debug is on.
        :return: None
        """
        if debug:
            print(string_in)
        return

    print('Recoloring image: {:} ...'.format(filename))
    # Determine tag name and load original file into the tree
    fn1 = filename.split(os.sep)[-1]  # Filename without path
    fn2 = fn1.split(os.extsep)[0]  # Filename without extension (so new filename can be built later)
    ext = fn1.split(os.extsep)[-1]  # File extension
    path = os.sep.join(filename.split(os.sep)[0:-1])  # Path; used later to save results.
    a = imread(filename).astype(float)
    printd(f'Read image; shape = {np.shape(a)}')
    if max_rgb == 'auto':
        # Determine if values of R, G, and B range from 0 to 1 or from 0 to 255
        if a.max() > 1:
            max_rgb = 255.0
        else:
            max_rgb = 1.0
    # Normalize a so RGB values go from 0 to 1 and are floats.
    a /= max_rgb
    # Extract the colorbar
    x = np.array([colorbar_loc[0][0], colorbar_loc[1][0]])
    y = np.array([colorbar_loc[0][1], colorbar_loc[1][1]])
    cb = a[y[0]:y[1], x[0]:x[1]]
    # Take just the working area, not the whole image
    xw = np.array([working_loc[0][0], working_loc[1][0]])
    yw = np.array([working_loc[0][1], working_loc[1][1]])
    a1 = a[yw[0]:yw[1], xw[0]:xw[1]]
    # Pick color bar orientation
    if colorbar_orientation == 'auto':
        if np.diff(x) > np.diff(y):
            colorbar_orientation = 'x'
        else:
            colorbar_orientation = 'y'
        printd('Auto selected colorbar_orientation')
    printd('Colorbar orientation is {:}'.format(colorbar_orientation))
    # Analyze the colorbar
    if colorbar_orientation == 'y':
        cb = np.nanmean(cb, axis=1)
    else:
        cb = np.nanmean(cb, axis=0)
    if colorbar_direction < 0:
        cb = cb[::-1]
    # Compress colorbar to only count unique colors
    # If the array gets too big, it will fill memory and crash python: https://github.com/numpy/numpy/issues/14136
    dcb = np.append(1, np.sum(abs(np.diff(cb[:, 0:3], axis=0)), axis=1))
    cb = cb[dcb > 0]
    # Find and mask special colors that should not be recolored
    n1a = np.sum(a1[:, :, 0:3], axis=2)
    replacement_mask = np.ones(np.shape(n1a), bool)
    for col in [0, 3]:  # Black and white will come out as 0 and 3.
        mask_update = n1a != col
        if mask_update.max() == 0:
            print('Warning: masking to protect special colors prevented all changes to the image!')
        else:
            printd('Good: Special color mask {:} allowed at least some changes'.format(col))
        replacement_mask *= mask_update
        if replacement_mask.max() == 0:
            print('Warning: replacement mask will prevent all changes to the image! '
                  '(Reached this point during special color protection)')
        printd('Sum(replacement_mask) = {:} (after considering special color {:})'
               .format(np.sum(np.atleast_1d(replacement_mask)), col))
    # Also apply limits to total r+g+b
    replacement_mask *= n1a > 0.75
    replacement_mask *= n1a < 2.5
    if replacement_mask.max() == 0:
        print('Warning: replacement mask will prevent all changes to the image! '
              '(Reached this point during total r+g+b limits)')
    printd('Sum(replacement_mask) = {:} (after considering r+g+b upper threshold)'
           .format(np.sum(np.atleast_1d(replacement_mask))))
    if saturation_threshold > 0:
        hsv1 = matplotlib.colors.rgb_to_hsv(a1[:, :, 0:3])
        sat = hsv1[:, :, 1]
        printd('Saturation ranges from {:} <= sat <= {:}'.format(sat.min(), sat.max()))
        sat_mask = sat > saturation_threshold
        if sat_mask.max() == 0:
            print('Warning: saturation mask will prevent all changes to the image!')
        else:
            printd('Good: Saturation mask will allow at least some changes')
        replacement_mask *= sat_mask
        if replacement_mask.max() == 0:
            print('Warning: replacement mask will prevent all changes to the image! '
                  '(Reached this point during saturation threshold)')
    printd(f'shape(a1) = {np.shape(a1)}')
    printd(f'shape(cb) = {np.shape(cb)}')
    # Find where on the colorbar each pixel sits
    if compare_hue:
        # Difference in hue
        hsv1 = matplotlib.colors.rgb_to_hsv(a1[:, :, 0:3])
        hsv_cb = matplotlib.colors.rgb_to_hsv(cb[:, 0:3])
        d2 = abs(hsv1[:, :, :, np.newaxis] - hsv_cb.T[np.newaxis, np.newaxis, :, :])
        # d2 = d2[:, :, 0, :]  # Take hue only
        d2 = np.sum(d2, axis=2)
        printd(' shape(d2) = {:} (hue version)'.format(np.shape(d2)))
    else:
        # Difference in RGB
        if normalize_before_compare:
            # Difference of normalized RGB arrays
            n1 = n1a[:, :, np.newaxis]
            n2 = np.sum(cb[:, 0:3], axis=1)[:, np.newaxis]
            w1 = n1 == 0
            w2 = n2 == 0
            n1[w1] = 1
            n2[w2] = 1
            d = (a1/n1)[:, :, 0:3, np.newaxis] - (cb/n2).T[np.newaxis, np.newaxis, 0:3, :]
        else:
            # Difference of non-normalized RGB arrays
            d = (a1[:, :, 0:3, np.newaxis] - cb.T[np.newaxis, np.newaxis, 0:3, :])
        printd(f'Shape(d) = {np.shape(d)}')
        d2 = np.sum(np.abs(d[:, :, 0:3, :]), axis=2)  # 0:3 excludes the alpha channel from this calculation
    printd('Processed colorbar')
    index = d2.argmin(axis=2)
    md2 = d2.min(axis=2)
    index_valid = md2 < threshold
    if index_valid.max() == 0:
        print('Warning: minimum difference is greater than threshold: all changes rejected!')
    else:
        printd('Good: Minimum difference filter is lower than threshold for at least one pixel.')
    printd('Sum(index_valid) = {:} (before *= replacement_mask)'.format(np.sum(np.atleast_1d(index_valid))))
    printd('Sum(replacement_mask) = {:} (final, before combining w/ index_valid)'
           .format(np.sum(np.atleast_1d(replacement_mask))))
    index_valid *= replacement_mask
    if index_valid.max() == 0:
        print('Warning: index_valid mask prevents all changes to the image after combination w/ replacement_mask.')
    else:
        printd('Good: Mask will allow at least one pixel to change.')
    printd('Sum(index_valid) = {:}'.format(np.sum(np.atleast_1d(index_valid))))
    value = index/(len(cb)-1.0)
    printd('Index ranges from {:} to {:}'.format(index.min(), index.max()))
    # Make a new image with replaced colors
    b = matplotlib.cm.ScalarMappable(cmap=new_cmap).to_rgba(value)  # Remap everything
    printd('shape(b) = {:}, min(b) = {:}, max(b) = {:}'.format(np.shape(b), b.min(), b.max()))
    c = copy.copy(a1)  # Copy original
    c[index_valid] = b[index_valid]  # Transfer only pixels where color was close to colormap
    # Transfer working area to full image
    c2 = copy.copy(a)  # Copy original full image
    c2[yw[0]:yw[1], xw[0]:xw[1], :] = c  # Replace working area
    c2[:, :, 3] = a[:, :, 3]  # Preserve original alpha channel (the image is assumed to be RGBA)
    # Save the image in the same path as the original but with _recolored added to the filename.
    # os.path.join avoids a leading separator when filename has no directory part.
    new_filename = os.path.join(path, '{:}_recolored{:}{:}'.format(fn2, os.extsep, ext))
    imsave(new_filename, c2)
    print('Done recoloring. Result saved to {:} .'.format(new_filename))
    if show_plot:
        # Setup figure for showing things to the user
        f, axs = plt.subplots(2, 3)
        axo = axs[0, 0]  # Axes for original figure
        axoc = axs[0, 1]  # Axes for original color bar
        axf = axs[0, 2]  # Axes for final figure
        axm = axs[1, 1]  # Axes for mask
        axre = axs[1, 2]  # Axes for recolored section only (it might not be the whole figure)
        axraw = axs[1, 0]  # Axes for raw recoloring result before masking
        for ax in axs.flatten():
            ax.set_xlabel('x pixel')
            ax.set_ylabel('y pixel')
        axo.set_title('Original image w/ colorbar ID overlay')
        axoc.set_title('Color progression from original colorbar')
        axm.set_title('Mask')
        axre.set_title('Recolored section')
        axraw.set_title('Raw recolor result (no masking)')
        axf.set_title('Final image')
        axoc.set_xlabel('Index')
        axoc.set_ylabel('Value')
        # Show the user where they placed the color bar and working location
        axo.imshow(a)
        xx = x[np.array([0, 0, 1, 1, 0])]
        yy = y[np.array([0, 1, 1, 0, 0])]
        axo.plot(xx, yy, '+-', label='colorbar')
        xxw = xw[np.array([0, 0, 1, 1, 0])]
        yyw = yw[np.array([0, 1, 1, 0, 0])]
        axo.plot(xxw, yyw, '+-', label='target')
        tots = np.sum(cb[:, 0:3], axis=1)
        if normalize_before_compare:
            # Normalized version
            axoc.plot(cb[:, 0] / tots, 'r', label='r/(r+g+b)', lw=2)
            axoc.plot(cb[:, 1] / tots, 'g', label='g/(r+g+b)', lw=2)
            axoc.plot(cb[:, 2] / tots, 'b', label='b/(r+g+b)', lw=2)
            axoc.set_ylabel('Normalized value')
        else:
            axoc.plot(cb[:, 0], 'r', label='r', lw=2)
            axoc.plot(cb[:, 1], 'g', label='g', lw=2)
            axoc.plot(cb[:, 2], 'b', label='b', lw=2)
            axoc.plot(cb[:, 3], color='gray', linestyle='--', label='$\\alpha$')
            axoc.plot(tots, 'k', label='r+g+b')
        # Display the new colors with no mask, the mask, and the recolored section
        axraw.imshow(b)
        axm.imshow(index_valid)
        axre.imshow(c)
        # Display the final result
        axf.imshow(c2)
        # Finishing touches on plots
        axo.legend(loc=0).set_draggable(True)
        axoc.legend(loc=0).set_draggable(True)
        plt.show()
    return
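The central step of the function, finding where on the colorbar each pixel sits, can be looked at in isolation. The following is only a sketch on synthetic data, using the non-normalized RGB comparison (the compare_hue=False branch); the helper name colors_to_values and the 5-entry blue-to-red colorbar are made up for illustration and are not part of the function above.

```python
import numpy as np


def colors_to_values(pixels, colorbar):
    """
    Map each RGB pixel to the normalized position of its closest colorbar entry.
    :param pixels: array of shape (H, W, 3) with floats in [0, 1]
    :param colorbar: array of shape (N, 3) with floats in [0, 1], ordered by ascending value
    :return: (values in [0, 1] with shape (H, W), distance to the closest entry with shape (H, W))
    """
    # Sum of absolute channel differences between every pixel and every colorbar entry: shape (H, W, N)
    diff = np.abs(pixels[:, :, np.newaxis, :] - colorbar[np.newaxis, np.newaxis, :, :]).sum(axis=3)
    index = diff.argmin(axis=2)  # Closest colorbar entry for each pixel
    return index / (len(colorbar) - 1.0), diff.min(axis=2)


# Synthetic colorbar: linear ramp from blue (value 0) to red (value 1)
ramp = np.linspace(0.0, 1.0, 5)
colorbar = np.stack([ramp, np.zeros(5), 1.0 - ramp], axis=1)

# 2x2 test image: three exact colorbar colors and one slightly off color
pixels = np.array([
    [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
    [[0.5, 0.0, 0.5], [0.26, 0.0, 0.74]],
])
values, distance = colors_to_values(pixels, colorbar)
print(values)    # Position of each pixel along the colorbar
print(distance)  # Compare against a threshold to reject pixels that are not on the colorbar
```

The resulting values array is what gets handed to the new color map, and the distance array plays the role of md2 when deciding which pixels are close enough to the colorbar to be replaced.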
I am writing a script that is to combine multiple histograms into a stack plot. The script is to be able to handle a fairly arbitrary number of histograms. I want to be able to color the histograms in the stack plot according to a defined color palette and also to be able to extend that color palette when it is insufficient to color the number of histograms that the script is to deal with.
I have created functions to handle getting the mean of colors and had in mind to try extending a defined palette by mixing colors in some automated way and then extending the palette by adding those mixed colors, but I'm not sure how to do this in a structured, sensible way. I request guidance and suggestions on how to extend a palette in an automated way.
style1 = [
    "#FC0000",
    "#FFAE3A",
    "#00AC00",
    "#6665EC",
    "#A9A9A9"
]

def clamp(x):
    # Round to an integer and restrict to the valid channel range 0..255;
    # the integer conversion is needed because the means below are floats.
    return max(0, min(int(round(x)), 255))

def RGB_to_HEX(RGB_tuple):
    # This function returns a HEX string given an RGB tuple.
    r = RGB_tuple[0]
    g = RGB_tuple[1]
    b = RGB_tuple[2]
    return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g), clamp(b))

def HEX_to_RGB(HEX_string):
    # This function returns an RGB tuple given a HEX string.
    HEX = HEX_string.lstrip('#')
    HEX_length = len(HEX)
    return tuple(
        int(HEX[i:i + HEX_length // 3], 16) for i in range(
            0,
            HEX_length,
            HEX_length // 3
        )
    )

def mean_color(colorsInHEX):
    # This function returns a HEX string that represents the mean color of a
    # list of colors represented by HEX strings.
    colorsInRGB = []
    for colorInHEX in colorsInHEX:
        colorsInRGB.append(HEX_to_RGB(colorInHEX))
    sum_r = 0
    sum_g = 0
    sum_b = 0
    for colorInRGB in colorsInRGB:
        sum_r += colorInRGB[0]
        sum_g += colorInRGB[1]
        sum_b += colorInRGB[2]
    mean_r = sum_r / len(colorsInRGB)
    mean_g = sum_g / len(colorsInRGB)
    mean_b = sum_b / len(colorsInRGB)
    return RGB_to_HEX((mean_r, mean_g, mean_b))

def extend_palette(
    colors = None, # a list of HEX string colors
    numberOfColorsNeeded = None # number of colors to which list should be extended
    ):
    # magic happens here
    return colors_extended

print(
    extend_palette(
        colors = style1,
        numberOfColorsNeeded = 10
    )
)
print(
    extend_palette(
        colors = style1,
        numberOfColorsNeeded = 50
    )
)
You can use colorir for that.
from colorir import *

pal = Palette.load("spectral") # Load a categorical palette, a full list can be found in the docs
grad = Grad(pal) # Object to automatically "mix" the colors
# Now to generate a dynamic list of colors based on the number of inputs:
for i in range(15):
    # Get a bigger list by interpolating the colors if necessary
    if i > len(pal):
        colors = grad.n_colors(i)
    else:
        colors = pal.colors
    # Use the 'colors' list for your plots
    # plt.hist(..., color=colors)
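If an extra dependency is not wanted, the same idea can be sketched with the standard library only. This is one possible approach, not the only sensible one: it treats the palette as a gradient and samples it at evenly spaced positions, interpolating linearly in RGB. The helper names and the snake_case parameter number_of_colors_needed are illustrative choices. Note that colors mixed in RGB between neighbouring palette entries can look muddy or hard to tell apart when many are requested.

```python
style1 = ["#FC0000", "#FFAE3A", "#00AC00", "#6665EC", "#A9A9A9"]


def hex_to_rgb(hex_string):
    # "#RRGGBB" -> (r, g, b) with integer channels 0..255
    h = hex_string.lstrip("#")
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


def rgb_to_hex(rgb):
    # (r, g, b) with possibly float channels -> "#rrggbb"
    return "#{:02x}{:02x}{:02x}".format(*(max(0, min(255, int(round(c)))) for c in rgb))


def extend_palette(colors, number_of_colors_needed):
    """Return number_of_colors_needed HEX colors based on the given palette."""
    if not colors:
        raise ValueError("colors must contain at least one color")
    n = number_of_colors_needed
    if n <= len(colors):
        # The palette is already big enough: just take the first n colors.
        return list(colors[:n])
    if len(colors) == 1:
        # Nothing to interpolate between.
        return [colors[0]] * n
    rgb = [hex_to_rgb(c) for c in colors]
    extended = []
    for k in range(n):
        # Position along the palette, from 0 (first color) to len(rgb) - 1 (last color)
        pos = k * (len(rgb) - 1) / (n - 1)
        lo = min(int(pos), len(rgb) - 2)  # Index of the color just below pos
        frac = pos - lo                   # How far pos is towards the next color
        mixed = tuple(a + (b - a) * frac for a, b in zip(rgb[lo], rgb[lo + 1]))
        extended.append(rgb_to_hex(mixed))
    return extended


print(extend_palette(style1, 10))
print(len(extend_palette(style1, 50)))
```

The first and last colors of the original palette are always kept; the other original colors only reappear exactly when the requested number happens to place a sample on them.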