I want to plot a dendrogram for hierarchical clustering using Plotly, and show only a small subset of it, since with a large number of samples the plot can become very dense at the bottom.
I have created the plot using the Plotly wrapper function create_dendrogram with the code below:
from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff
fig = ff.create_dendrogram(test_df, linkagefun=lambda x: linkage(test_df, 'average', metric='euclidean'))
fig.update_layout(autosize=True, hovermode='closest')
fig.update_xaxes(mirror=False, showgrid=True, showline=False, showticklabels=False)
fig.update_yaxes(mirror=False, showgrid=True, showline=True)
fig.show()
And below is the plot produced with matplotlib, which scipy uses by default, truncated to 4 levels for ease of interpretation:
from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt
x = linkage(test_df,method='average')
dendrogram(x,truncate_mode='level',p=4)
plt.show()
As you can see, the truncation is very useful for interpreting a large number of samples. How can I achieve this in Plotly?
There does not seem to be a straightforward way to do this with ff.create_dendrogram(). That does not mean it's impossible, though. But I would at least consider the brilliant functionality that Dash Clustergram has to offer. If you insist on sticking to ff.create_dendrogram(), this is going to get a bit messier than Plotly users have rightfully grown accustomed to. You haven't provided a data sample, so let's use the Plotly Basic Dendrogram example instead:
Plot 1
Code 1
import plotly.figure_factory as ff
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
f = fig.full_figure_for_development(warn=False)
fig.show()
The good news is that the exact same snippet will produce the following truncated plot after we've taken a few steps that I'll explain in the details below.
Plot 2
The details
If anyone who got this far in my answer knows a better way to do the following, then please share.
1. ff.create_dendrogram() is a wrapper for scipy.cluster.hierarchy.dendrogram
You can call help(ff.create_dendrogram) and learn that:
[...]This is a thin wrapper around scipy.cluster.hierarchy.dendrogram.
From the available arguments you can also see that none seem to handle anything related to truncating:
create_dendrogram(X, orientation='bottom', labels=None,
colorscale=None, distfun=None, linkagefun=<function at
0x0000016F09D4CEE0>, hovertext=None, color_threshold=None)
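If you prefer to check this programmatically rather than reading the help() output, a quick standard-library sketch is:
import inspect
import plotly.figure_factory as ff

# Print the wrapper's accepted arguments; note there is no truncate_mode or p
print(inspect.signature(ff.create_dendrogram))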
2. Take a closer look at scipy.cluster.hierarchy.dendrogram
Comparing it to the source, we can see that some central arguments were left out when the function was wrapped in ff.create_dendrogram(X):
scipy.cluster.hierarchy.dendrogram(Z, p=30, truncate_mode=None, color_threshold=None, get_leaves=True, orientation='top', labels=None, count_sort=False, distance_sort=False, show_leaf_counts=True, no_plot=False, no_labels=False, leaf_font_size=None, leaf_rotation=None, leaf_label_func=None, show_contracted=False, link_color_func=None, ax=None, above_threshold_color='C0')
truncate_mode should be exactly what we're looking for. So, now we know that scipy probably has all we need to build the foundation for a truncated dendrogram, but what's next?
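To convince ourselves that scipy alone can already do the truncation, here is a small sketch (using random data, since none was provided) that calls scipy.cluster.hierarchy.dendrogram directly with no_plot=True and compares the full and truncated results:
from scipy.cluster.hierarchy import dendrogram, linkage
import numpy as np

np.random.seed(1)
X = np.random.rand(15, 12)      # 15 samples, 12 dimensions
Z = linkage(X, 'complete')      # same default linkage as ff.create_dendrogram

# no_plot=True returns the plotting dictionary instead of drawing anything
P_full = dendrogram(Z, no_plot=True)
P_trunc = dendrogram(Z, truncate_mode='level', p=2, no_plot=True)

print(len(P_full['icoord']), len(P_trunc['icoord']))  # fewer '∩' shapes after truncation
print(P_trunc['ivl'])  # leaf labels; '(n)' marks a collapsed cluster of n samples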
3. Find where scipy.cluster.hierarchy.dendrogram is hiding in ff.create_dendrogram(X)
ff.create_dendrogram.__code__ will reveal where the source code exists in your system. In my case this is:
"C:\Users\vestland\Miniconda3\envs\dashy\lib\site-packages\plotly\figure_factory\_dendrogram.py"
So, if you like, you can take a closer look at the complete source in the corresponding folder on your system. If you do, you'll see one particularly interesting section where some of the arguments we listed above are handled:
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
"""
Calculates all the elements needed for plotting a dendrogram.
.
.
.
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
)
Here we are at the very core of the problem. And the first step to a complete answer to your question is simply to include truncate_mode and p in the call to sch.dendrogram() that produces P, like this:
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
truncate_mode = 'level',
p = 2
)
And here's how you do that:
4. Monkey patching
In Python, the term monkey patch refers to a dynamic modification of a class or module at runtime; in other words, a monkey patch is a piece of Python code that extends or modifies other code while the program is running. And here's the essence of how you can do exactly that in our case:
import plotly.figure_factory._dendrogram as original_dendrogram
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces
Here, modified_dendrogram_traces is the complete function definition of modified_dendrogram_traces() with the amendments I've already mentioned, plus a few imports that would otherwise be missing because they are normally run when you call import plotly.figure_factory as ff.
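If monkey patching is new to you, here is a tiny, self-contained toy example of the mechanism (nothing Plotly-specific, just to illustrate replacing a method on a class at runtime):
class Greeter:
    def greet(self):
        return "hello"

def shout(self):
    # replacement method; it still takes self like any other method
    return "HELLO"

print(Greeter().greet())   # -> hello
Greeter.greet = shout      # patch the class at runtime
print(Greeter().greet())   # -> HELLO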
Enough details for now. Below is the whole thing. If this is something you can use, we could perhaps make it a bit more dynamic than hardcoding truncate_mode = 'level' and p = 2.
Complete code:
from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff
import plotly.figure_factory._dendrogram as original_dendrogram
import numpy as np
def modified_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
import plotly
from plotly import exceptions, optional_imports
np = optional_imports.get_module("numpy")
scp = optional_imports.get_module("scipy")
sch = optional_imports.get_module("scipy.cluster.hierarchy")
scs = optional_imports.get_module("scipy.spatial")
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
truncate_mode = 'level',
p = 2
)
icoord = scp.array(P["icoord"])
dcoord = scp.array(P["dcoord"])
ordered_labels = scp.array(P["ivl"])
color_list = scp.array(P["color_list"])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ["top", "bottom"]:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ["top", "bottom"]:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
hovertext_label = None
if hovertext:
hovertext_label = hovertext[i]
trace = dict(
type="scatter",
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode="lines",
marker=dict(color=colors[color_key]),
text=hovertext_label,
hoverinfo="text",
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ""
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ""
trace["xaxis"] = "x" + x_index
trace["yaxis"] = "y" + y_index
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
f = fig.full_figure_for_development(warn=False)
fig.show()
To make it more dynamic, you can pass **kwargs to the create_dendrogram() function. If you check the source code, you need to pass **kwargs in multiple other places, both in the _Dendrogram class and in the get_dendrogram_traces() function.
If you don't want to mess with _dendrogram.py, which is located in the default directory, I advise you to copy the whole file and create a new file (let's say modified_dendogram.py) in your current directory.
Then simply import that local file using from modified_dendogram import create_dendrogram.
Now you can use all the arguments that scipy.cluster.hierarchy.dendrogram supports.
modified_dendogram.py:
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from collections import OrderedDict
from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs
# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
scp = optional_imports.get_module("scipy")
sch = optional_imports.get_module("scipy.cluster.hierarchy")
scs = optional_imports.get_module("scipy.spatial")
def create_dendrogram(
X,
orientation="bottom",
labels=None,
colorscale=None,
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
**kwargs
):
"""
Function that returns a dendrogram Plotly figure object. This is a thin
wrapper around scipy.cluster.hierarchy.dendrogram.
See also https://dash.plot.ly/dash-bio/clustergram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
:param (list) labels: List of axis category labels(observation labels)
:param (list) colorscale: Optional colorscale for the dendrogram tree.
Requires 8 colors to be specified, the 7th of
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
and 6th are used twice as often as the others.
Given a shorter list, the missing values are
replaced with defaults and with a longer list the
extra values are ignored.
:param (function) distfun: Function to compute the pairwise distance from
the observations
:param (function) linkagefun: Function to compute the linkage matrix from
the pairwise distances
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
clusters
:param (double) color_threshold: Value at which the separation of clusters will be made
Example 1: Simple bottom oriented dendrogram
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(10,10)
>>> fig = create_dendrogram(X)
>>> fig.show()
Example 2: Dendrogram to put on the left of the heatmap
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> X = np.random.rand(5,5)
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
>>> dendro.show()
Example 3: Dendrogram with Pandas
>>> from plotly.figure_factory import create_dendrogram
>>> import numpy as np
>>> import pandas as pd
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
>>> fig = create_dendrogram(df, labels=Index)
>>> fig.show()
"""
if not scp or not scs or not sch:
raise ImportError(
"FigureFactory.create_dendrogram requires scipy, \
scipy.spatial and scipy.hierarchy"
)
s = X.shape
if len(s) != 2:
exceptions.PlotlyError("X should be 2-dimensional array.")
if distfun is None:
distfun = scs.distance.pdist
dendrogram = _Dendrogram(
X,
orientation,
labels,
colorscale,
distfun=distfun,
linkagefun=linkagefun,
hovertext=hovertext,
color_threshold=color_threshold,
kwargs=kwargs
)
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
class _Dendrogram(object):
"""Refer to FigureFactory.create_dendrogram() for docstring."""
def __init__(
self,
X,
orientation="bottom",
labels=None,
colorscale=None,
width=np.inf,
height=np.inf,
xaxis="xaxis",
yaxis="yaxis",
distfun=None,
linkagefun=lambda x: sch.linkage(x, "complete"),
hovertext=None,
color_threshold=None,
kwargs=None
):
self.orientation = orientation
self.labels = labels
self.xaxis = xaxis
self.yaxis = yaxis
self.data = []
self.leaves = []
self.sign = {self.xaxis: 1, self.yaxis: 1}
self.layout = {self.xaxis: {}, self.yaxis: {}}
if self.orientation in ["left", "bottom"]:
self.sign[self.xaxis] = 1
else:
self.sign[self.xaxis] = -1
if self.orientation in ["right", "bottom"]:
self.sign[self.yaxis] = 1
else:
self.sign[self.yaxis] = -1
if distfun is None:
distfun = scs.distance.pdist
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
X, colorscale, distfun, linkagefun, hovertext, color_threshold, kwargs
)
self.labels = ordered_labels
self.leaves = leaves
yvals_flat = yvals.flatten()
xvals_flat = xvals.flatten()
self.zero_vals = []
for i in range(len(yvals_flat)):
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
self.zero_vals.append(xvals_flat[i])
if len(self.zero_vals) > len(yvals) + 1:
            # If the length of zero_vals is larger than the length of yvals,
            # it means there are wrong values because of identical samples.
            # Three or more identical samples will make the yvals of the
            # splitting center become 0 and accidentally be taken as leaves.
l_border = int(min(self.zero_vals))
r_border = int(max(self.zero_vals))
correct_leaves_pos = range(
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
)
            # Regenerating the leaves positions from self.zero_vals with equal intervals.
self.zero_vals = [v for v in correct_leaves_pos]
self.zero_vals.sort()
self.layout = self.set_figure_layout(width, height)
self.data = dd_traces
def get_color_dict(self, colorscale):
"""
Returns colorscale used for dendrogram tree clusters.
:param (list) colorscale: Colors to use for the plot in rgb format.
:rtype (dict): A dict of default colors mapped to the user colorscale.
"""
# These are the color codes returned for dendrograms
# We're replacing them with nicer colors
# This list is the colors that can be used by dendrogram, which were
# determined as the combination of the default above_threshold_color and
# the default color palette (see scipy/cluster/hierarchy.py)
d = {
"r": "red",
"g": "green",
"b": "blue",
"c": "cyan",
"m": "magenta",
"y": "yellow",
"k": "black",
# TODO: 'w' doesn't seem to be in the default color
# palette in scipy/cluster/hierarchy.py
"w": "white",
}
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
if colorscale is None:
rgb_colorscale = [
"rgb(0,116,217)", # blue
"rgb(35,205,205)", # cyan
"rgb(61,153,112)", # green
"rgb(40,35,35)", # black
"rgb(133,20,75)", # magenta
"rgb(255,65,54)", # red
"rgb(255,255,255)", # white
"rgb(255,220,0)", # yellow
]
else:
rgb_colorscale = colorscale
for i in range(len(default_colors.keys())):
k = list(default_colors.keys())[i] # PY3 won't index keys
if i < len(rgb_colorscale):
default_colors[k] = rgb_colorscale[i]
# add support for cyclic format colors as introduced in scipy===1.5.0
# before this, the colors were named 'r', 'b', 'y' etc., now they are
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
# scipy version, we try as much as possible to map the new colors to the
# old colors
        # this mapping was found by inspecting scipy/cluster/hierarchy.py (see
# comment above).
new_old_color_map = [
("C0", "b"),
("C1", "g"),
("C2", "r"),
("C3", "c"),
("C4", "m"),
("C5", "y"),
("C6", "k"),
("C7", "g"),
("C8", "r"),
("C9", "c"),
]
for nc, oc in new_old_color_map:
try:
default_colors[nc] = default_colors[oc]
except KeyError:
# it could happen that the old color isn't found (if a custom
# colorscale was specified), in this case we set it to an
# arbitrary default.
default_colors[nc] = "rgb(0,116,217)"
return default_colors
def set_axis_layout(self, axis_key):
"""
Sets and returns default axis object for dendrogram figure.
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
:rtype (dict): An axis_key dictionary with set parameters.
"""
axis_defaults = {
"type": "linear",
"ticks": "outside",
"mirror": "allticks",
"rangemode": "tozero",
"showticklabels": True,
"zeroline": False,
"showgrid": False,
"showline": True,
}
if len(self.labels) != 0:
axis_key_labels = self.xaxis
if self.orientation in ["left", "right"]:
axis_key_labels = self.yaxis
if axis_key_labels not in self.layout:
self.layout[axis_key_labels] = {}
self.layout[axis_key_labels]["tickvals"] = [
zv * self.sign[axis_key] for zv in self.zero_vals
]
self.layout[axis_key_labels]["ticktext"] = self.labels
self.layout[axis_key_labels]["tickmode"] = "array"
self.layout[axis_key].update(axis_defaults)
return self.layout[axis_key]
def set_figure_layout(self, width, height):
"""
Sets and returns default layout object for dendrogram figure.
"""
self.layout.update(
{
"showlegend": False,
"autosize": False,
"hovermode": "closest",
"width": width,
"height": height,
}
)
self.set_axis_layout(self.xaxis)
self.set_axis_layout(self.yaxis)
return self.layout
def get_dendrogram_traces(
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold, kwargs={}
):
"""
Calculates all the elements needed for plotting a dendrogram.
:param (ndarray) X: Matrix of observations as array of arrays
:param (list) colorscale: Color scale for dendrogram tree clusters
:param (function) distfun: Function to compute the pairwise distance
from the observations
:param (function) linkagefun: Function to compute the linkage matrix
from the pairwise distances
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
:rtype (tuple): Contains all the traces in the following order:
(a) trace_list: List of Plotly trace objects for dendrogram tree
(b) icoord: All X points of the dendrogram tree as array of arrays
with length 4
(c) dcoord: All Y points of the dendrogram tree as array of arrays
with length 4
(d) ordered_labels: leaf labels in the order they are going to
appear on the plot
(e) P['leaves']: left-to-right traversal of the leaves
"""
d = distfun(X)
Z = linkagefun(d)
P = sch.dendrogram(
Z,
orientation=self.orientation,
labels=self.labels,
no_plot=True,
color_threshold=color_threshold,
**kwargs
)
icoord = scp.array(P["icoord"])
dcoord = scp.array(P["dcoord"])
ordered_labels = scp.array(P["ivl"])
color_list = scp.array(P["color_list"])
colors = self.get_color_dict(colorscale)
trace_list = []
for i in range(len(icoord)):
# xs and ys are arrays of 4 points that make up the '∩' shapes
# of the dendrogram tree
if self.orientation in ["top", "bottom"]:
xs = icoord[i]
else:
xs = dcoord[i]
if self.orientation in ["top", "bottom"]:
ys = dcoord[i]
else:
ys = icoord[i]
color_key = color_list[i]
hovertext_label = None
if hovertext:
hovertext_label = hovertext[i]
trace = dict(
type="scatter",
x=np.multiply(self.sign[self.xaxis], xs),
y=np.multiply(self.sign[self.yaxis], ys),
mode="lines",
marker=dict(color=colors[color_key]),
text=hovertext_label,
hoverinfo="text",
)
try:
x_index = int(self.xaxis[-1])
except ValueError:
x_index = ""
try:
y_index = int(self.yaxis[-1])
except ValueError:
y_index = ""
trace["xaxis"] = "x" + x_index
trace["yaxis"] = "y" + y_index
trace_list.append(trace)
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
Example:
from modified_dendogram import create_dendrogram
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = create_dendrogram(X)
fig.update_layout(width=800, height=500)
fig.show()
from utils.modified_dendogram import create_dendrogram
import numpy as np
np.random.seed(1)
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = create_dendrogram(X, truncate_mode="level", p=1)
fig.update_layout(width=800, height=500)
fig.show()
I'm working on a visualization project with networkx and Plotly. Is there a way to create a 3D graph that resembles a human brain in networkx and then visualize it with Plotly (so it will be interactive)?
The idea is to have the nodes on the outside (or only show the nodes if that's easier) and to color a set of them differently, like in the image above.
To start, this code is heavily borrowed from Matteo Mancini, which he describes here and which he has released under the MIT license.
In the original code, networkx is not used, so it's clear you don't actually need networkx to accomplish your goal. If this is not a strict requirement, I would consider using his original code and reworking it to fit your input data.
Since you listed networkx as a requirement, I simply reworked his code to take a networkx Graph object with certain node attributes such as 'color' and 'coord' to be used for those marker characteristics in the final plotly scatter. I just chose the first ten points in the dataset to color red, which is why they aren't grouped.
The full copy-pasteable code is below. The screenshot here obviously isn't interactive, but you can try the demo here on Google Colab.
To download files if in Jupyter notebook on Linux/Mac:
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/lh.pial.obj
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/icbm_fiber_mat.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/fs_region_centers_68_sort.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/freesurfer_regions_68_sort_full.txt
Otherwise: download the required files here.
Code:
import numpy as np
import plotly.graph_objects as go
import networkx as nx # New dependency
def obj_data_to_mesh3d(odata):
# odata is the string read from an obj file
vertices = []
faces = []
lines = odata.splitlines()
for line in lines:
slist = line.split()
if slist:
if slist[0] == 'v':
vertex = np.array(slist[1:], dtype=float)
vertices.append(vertex)
elif slist[0] == 'f':
face = []
for k in range(1, len(slist)):
face.append([int(s) for s in slist[k].replace('//','/').split('/')])
                if len(face) > 3: # triangulate the n-polygonal face, n>3
faces.extend([[face[0][0]-1, face[k][0]-1, face[k+1][0]-1] for k in range(1, len(face)-1)])
else:
faces.append([face[j][0]-1 for j in range(len(face))])
else: pass
return np.array(vertices), np.array(faces)
with open("lh.pial.obj", "r") as f:
obj_data = f.read()
[vertices, faces] = obj_data_to_mesh3d(obj_data)
vert_x, vert_y, vert_z = vertices[:,:3].T
face_i, face_j, face_k = faces.T
cmat = np.loadtxt('icbm_fiber_mat.txt')
nodes = np.loadtxt('fs_region_centers_68_sort.txt')
labels=[]
with open("freesurfer_regions_68_sort_full.txt", "r") as f:
for line in f:
labels.append(line.strip('\n'))
# Instantiate Graph and add nodes (with their coordinates)
G = nx.Graph()
for idx, node in enumerate(nodes):
G.add_node(idx, coord=node)
# Add made-up colors for the nodes as node attribute
colors_data = {node: ('gray' if node > 10 else 'red') for node in G.nodes}
nx.set_node_attributes(G, colors_data, name="color")
# Add edges
[source, target] = np.nonzero(np.triu(cmat)>0.01)
edges = list(zip(source, target))
G.add_edges_from(edges)
# Get node coordinates from node attribute
nodes_x = [data['coord'][0] for node, data in G.nodes(data=True)]
nodes_y = [data['coord'][1] for node, data in G.nodes(data=True)]
nodes_z = [data['coord'][2] for node, data in G.nodes(data=True)]
edge_x = []
edge_y = []
edge_z = []
for s, t in edges:
edge_x += [nodes_x[s], nodes_x[t]]
edge_y += [nodes_y[s], nodes_y[t]]
edge_z += [nodes_z[s], nodes_z[t]]
# Get node colors from node attribute
node_colors = [data['color'] for node, data in G.nodes(data=True)]
fig = go.Figure()
# Changed color and opacity kwargs
fig.add_trace(go.Mesh3d(x=vert_x, y=vert_y, z=vert_z, i=face_i, j=face_j, k=face_k,
color='gray', opacity=0.1, name='', showscale=False, hoverinfo='none'))
fig.add_trace(go.Scatter3d(x=nodes_x, y=nodes_y, z=nodes_z, text=labels,
mode='markers', hoverinfo='text', name='Nodes',
marker=dict(
size=5, # Changed node size...
color=node_colors # ...and color
)
))
fig.add_trace(go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
mode='lines', hoverinfo='none', name='Edges',
opacity=0.3, # Added opacity kwarg
line=dict(color='pink') # Added line color
))
fig.update_layout(
scene=dict(
xaxis=dict(showticklabels=False, visible=False),
yaxis=dict(showticklabels=False, visible=False),
zaxis=dict(showticklabels=False, visible=False),
),
width=800, height=600
)
fig.show()
Based on the clarified requirements, I took a new approach:
Download accurate brain mesh data from BrainNet Viewer github repo;
Plot a random graph with 3D coordinates using the Kamada-Kawai cost function in three dimensions, centered in a sphere containing the brain mesh;
Radially expand the node positions away from the center of the brain mesh and then shift them back to the closest vertex actually on the brain mesh;
Color some nodes red based on an arbitrary distance criterion from a randomly selected mesh vertex;
Fiddle with a bunch of plotting parameters to make it look decent.
There is a clearly delineated spot to add in different graph data as well as change the logic by which the node colors are decided. The key parameters to play with so that things look decent after introducing new graph data are:
scale_factor: This changes how much the original Kamada-Kawai calculated coordinates are translated radially away from the center of the brain mesh before they are snapped back to its surface. Larger values will make more nodes snap to the outer surface of the brain. Smaller values will leave more nodes positioned on the surfaces between the two hemispheres.
opacity of the lines in the edge trace: Graphs with more edges will quickly clutter up the field of view and make the overall brain shape less visible. This speaks to my biggest dissatisfaction with this overall approach: edges which appear outside of the mesh surface make it harder to see the overall shape of the mesh, especially between the temporal lobes.
My other big caveat here is that no attempt has been made to check whether any nodes positioned on the brain surface happen to coincide or have any sort of equal spacing.
Here is a screenshot and the live demo on Colab. Full copy-pasteable code below.
There are a whole bunch of asides that could be discussed here, but for brevity I will only note two:
Folks interested in this topic but feeling overwhelmed by programming details should absolutely check out BrainNet Viewer;
There are plenty of other brain meshes in the BrainNet Viewer github repo that could be used. Even better, if you have any mesh which can be formatted or reworked to be compatible with this approach, you could at least try wrapping a set of nodes around any other non-brain and somewhat round-ish mesh representing any other object.
import plotly.graph_objects as go
import numpy as np
import networkx as nx
import math
def mesh_properties(mesh_coords):
"""Calculate center and radius of sphere minimally containing a 3-D mesh
Parameters
----------
mesh_coords : tuple
3-tuple with x-, y-, and z-coordinates (respectively) of 3-D mesh vertices
"""
radii = []
center = []
for coords in mesh_coords:
c_max = max(c for c in coords)
c_min = min(c for c in coords)
center.append((c_max + c_min) / 2)
radius = (c_max - c_min) / 2
radii.append(radius)
return(center, max(radii))
# Download and prepare dataset from BrainNet repo
coords = np.loadtxt(np.DataSource().open('https://raw.githubusercontent.com/mingruixia/BrainNet-Viewer/master/Data/SurfTemplate/BrainMesh_Ch2_smoothed.nv'), skiprows=1, max_rows=53469)
x, y, z = coords.T
triangles = np.loadtxt(np.DataSource().open('https://raw.githubusercontent.com/mingruixia/BrainNet-Viewer/master/Data/SurfTemplate/BrainMesh_Ch2_smoothed.nv'), skiprows=53471, dtype=int)
triangles_zero_offset = triangles - 1
i, j, k = triangles_zero_offset.T
# Generate 3D mesh. Simply replace with 'fig = go.Figure()' or turn opacity to zero if seeing brain mesh is not desired.
fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z,
i=i, j=j, k=k,
color='lightpink', opacity=0.5, name='', showscale=False, hoverinfo='none')])
# Generate networkx graph and initial 3-D positions using Kamada-Kawai path-length cost-function inside sphere containing brain mesh
G = nx.gnp_random_graph(200, 0.02, seed=42) # Replace G with desired graph here
mesh_coords = (x, y, z)
mesh_center, mesh_radius = mesh_properties(mesh_coords)
scale_factor = 5 # Tune this value by hand to have more/fewer points between the brain hemispheres.
pos_3d = nx.kamada_kawai_layout(G, dim=3, center=mesh_center, scale=scale_factor*mesh_radius)
# Calculate final node positions on brain surface
pos_brain = {}
for node, position in pos_3d.items():
squared_dist_matrix = np.sum((coords - position) ** 2, axis=1)
pos_brain[node] = coords[np.argmin(squared_dist_matrix)]
# Prepare networkx graph positions for plotly node and edge traces
nodes_x = [position[0] for position in pos_brain.values()]
nodes_y = [position[1] for position in pos_brain.values()]
nodes_z = [position[2] for position in pos_brain.values()]
edge_x = []
edge_y = []
edge_z = []
for s, t in G.edges():
edge_x += [nodes_x[s], nodes_x[t]]
edge_y += [nodes_y[s], nodes_y[t]]
edge_z += [nodes_z[s], nodes_z[t]]
# Decide some more meaningful logic for coloring certain nodes. Currently the squared distance from the mesh point at index 42.
node_colors = []
for node in G.nodes():
if np.sum((pos_brain[node] - coords[42]) ** 2) < 1000:
node_colors.append('red')
else:
node_colors.append('gray')
# Add node plotly trace
fig.add_trace(go.Scatter3d(x=nodes_x, y=nodes_y, z=nodes_z,
#text=labels,
mode='markers',
#hoverinfo='text',
name='Nodes',
marker=dict(
size=5,
color=node_colors
)
))
# Add edge plotly trace. Comment out or turn opacity to zero if not desired.
fig.add_trace(go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
mode='lines',
hoverinfo='none',
name='Edges',
opacity=0.1,
line=dict(color='gray')
))
# Make axes invisible
fig.update_scenes(xaxis_visible=False,
yaxis_visible=False,
zaxis_visible=False)
# Manually adjust size of figure
fig.update_layout(autosize=False,
width=800,
height=800)
fig.show()
A possible way to do that:
import networkx as nx
import random
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from stl import mesh
# function to convert stl 3d-model to mesh
# Taken from : https://chart-studio.plotly.com/~empet/15276/converting-a-stl-mesh-to-plotly-gomes/#/
def stl2mesh3d(stl_mesh):
    # stl_mesh is read by numpy-stl from a stl file; it is an array of faces/triangles (i.e. three 3d points)
# this function extracts the unique vertices and the lists I, J, K to define a Plotly mesh3d
p, q, r = stl_mesh.vectors.shape #(p, 3, 3)
# the array stl_mesh.vectors.reshape(p*q, r) can contain multiple copies of the same vertex;
# extract unique vertices from all mesh triangles
vertices, ixr = np.unique(stl_mesh.vectors.reshape(p*q, r), return_inverse=True, axis=0)
I = np.take(ixr, [3*k for k in range(p)])
J = np.take(ixr, [3*k+1 for k in range(p)])
K = np.take(ixr, [3*k+2 for k in range(p)])
return vertices, I, J, K
# Let's use a toy "brain" stl file. You can get it from my Dropbox: https://www.dropbox.com/s/lav2opci8vekaep/brain.stl?dl=0
#
# Note: I made it quick and dirty with Blender and it is not supposed to be an accurate representation
# of an actual brain. You can put your own model here.
my_mesh = mesh.Mesh.from_file('brain.stl')
vertices, I, J, K = stl2mesh3d(my_mesh)
x, y, z = vertices.T # x,y,z contain the stl vertices
# Let's generate a random spatial graph:
# Note: spatial graphs have a "pos" (position) attribute
# pos = nx.get_node_attributes(G, "pos")
G = nx.random_geometric_graph(30, 0.3, dim=3) # in dimension 3 --> pos = [x,y,z]
#nx.draw(G)
print('Nb. of nodes: ',G.number_of_nodes(), 'Nb. of edges: ',G.number_of_edges())
# Take G.number_of_nodes() of nodes and attribute them randomly to points in the list of vertices of the STL model:
# That is, we "scatter" the nodes on the brain surface:
Vec3dList=list(np.array(random.sample(list(vertices), G.number_of_nodes())))
for i in range(len(Vec3dList)):
G.nodes[i]['pos']=Vec3dList[i]
# Create nodes and edges graph objects:
# Code from: https://plotly.com/python/network-graphs/ modified to work with 3d graphs
edge_x = []
edge_y = []
edge_z = []
for edge in G.edges():
x0, y0, z0 = G.nodes[edge[0]]['pos']
x1, y1, z1 = G.nodes[edge[1]]['pos']
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
edge_z.append(z0)
edge_z.append(z1)
edge_z.append(None)
edge_trace = go.Scatter3d(
x=edge_x, y=edge_y, z=edge_z,
line=dict(width=2, color='#888'),
hoverinfo='none',
opacity=.3,
mode='lines')
node_x = []
node_y = []
node_z = []
for node in G.nodes():
X, Y, Z = G.nodes[node]['pos']
node_x.append(X)
node_y.append(Y)
node_z.append(Z)
node_trace = go.Scatter3d(
x=node_x, y=node_y,z=node_z,
mode='markers',
hoverinfo='text',
marker=dict(
showscale=True,
# colorscale options
#'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
#'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
#'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
colorscale='YlGnBu',
reversescale=True,
color=[],
size=5,
colorbar=dict(
thickness=15,
title='Node Connections',
xanchor='left',
titleside='right'
),
line_width=10))
node_adjacencies = []
node_text = []
for node, adjacencies in enumerate(G.adjacency()):
node_adjacencies.append(len(adjacencies[1]))
node_text.append('# of connections: '+str(len(adjacencies[1])))
node_trace.marker.color = node_adjacencies
node_trace.text = node_text
colorscale= [[0, '#e5dee5'], [1, '#e5dee5']]
mesh3D = go.Mesh3d(
x=x,
y=y,
z=z,
i=I,
j=J,
k=K,
flatshading=False,
colorscale=colorscale,
intensity=z,
name='Brain',
opacity=0.25,
hoverinfo='none',
showscale=False)
title = "Brain"
layout = go.Layout(paper_bgcolor='rgb(1,1,1)',
title_text=title, title_x=0.5,
font_color='white',
width=800,
height=800,
scene_camera=dict(eye=dict(x=1.25, y=-1.25, z=1)),
scene_xaxis_visible=False,
scene_yaxis_visible=False,
scene_zaxis_visible=False)
fig = go.Figure(data=[mesh3D, edge_trace, node_trace], layout=layout)
fig.data[0].update(lighting=dict(ambient= .2,
diffuse= 1,
fresnel= 1,
specular= 1,
roughness= .1,
facenormalsepsilon=0))
fig.data[0].update(lightposition=dict(x=3000,
y=3000,
z=10000));
fig.show()
Below is the result. As you can see, it is not that great... but maybe you can improve on it.
Best regards
The Vispy library might be useful: https://github.com/vispy/vispy.
I think you can use the following examples:
3D brain mesh viewer
Plot various views of a structural MRI.
Clipping planes with volume and markers
These examples are interactive.
Regards!
I have two sets of satellite data. For both sets, I have the pixel geometry (latitude and longitude of each corner of the pixel). I would like to regrid one set to the other. Thus, my goal is area-weighted regridding from an irregular grid to another irregular grid. I am aware of xESMF, but am unsure if that is the best tool for the job. Perhaps iris area weighting regrid would be appropriate?
I've run into similar things in the past. I'm on Windows, and xESMF wasn't really an option for me.
I've written this package, and added some methods for computing grid to grid weights:
https://github.com/Deltares/numba_celltree
(You can pip install it.)
The data structure can deal with fully unstructured 2D meshes, and expects the data in such a format. See the code below.
You will need to make some changes: most likely your coordinates aren't named x and y. You will also need to update the ugrid2d_topology function somewhat, since I'm assuming regular quadrilateral grids here (but they're irregular when seen in each other's coordinate system).
It's still pretty straightforward: just make sure you have a 2D array of vertices, and a face_node_connectivity array of shape (n_cell, 4) which maps, for every face, its four vertices. See this documentation for a little more background:
https://ugrid-conventions.github.io/ugrid-conventions/
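To make the expected format concrete, here is a minimal made-up example of those two arrays for a tiny 2x2 grid (in your case the vertices would come from the pixel corner latitudes and longitudes; the full code below builds them automatically from a structured DataArray):
import numpy as np

# 9 vertices of a 2x2 grid of quadrilaterals, as (x, y) pairs
vertices = np.array(
    [
        [0.0, 0.0], [1.0, 0.0], [2.0, 0.0],
        [0.0, 1.0], [1.0, 1.0], [2.0, 1.0],
        [0.0, 2.0], [1.0, 2.0], [2.0, 2.0],
    ]
)

# For every face, the indices of its four corner vertices (counterclockwise)
face_node_connectivity = np.array(
    [
        [0, 1, 4, 3],
        [1, 2, 5, 4],
        [3, 4, 7, 6],
        [4, 5, 8, 7],
    ]
)
print(face_node_connectivity.shape)  # (n_cell, 4) -> (4, 4)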
from typing import Union

import numpy as np
import pandas as pd
import pyproj
import xarray as xr
from numba_celltree import CellTree2d
FloatArray = np.ndarray
IntArray = np.ndarray
def _coord(da, dim):
"""
Transform N xarray midpoints into N + 1 vertex edges
"""
delta_dim = "d" + dim # e.g. dx, dy, dz, etc.
# If empty array, return empty
if da[dim].size == 0:
return np.array(())
if delta_dim in da.coords: # equidistant or non-equidistant
dx = da[delta_dim].values
if dx.shape == () or dx.shape == (1,): # scalar -> equidistant
dxs = np.full(da[dim].size, dx)
else: # array -> non-equidistant
dxs = dx
_check_monotonic(dxs, dim)
else: # undefined -> equidistant
if da[dim].size == 1:
raise ValueError(
f"DataArray has size 1 along {dim}, so cellsize must be provided"
" as a coordinate."
)
dxs = np.diff(da[dim].values)
dx = dxs[0]
atolx = abs(1.0e-4 * dx)
if not np.allclose(dxs, dx, atolx):
raise ValueError(
f"DataArray has to be equidistant along {dim}, or cellsizes"
" must be provided as a coordinate."
)
dxs = np.full(da[dim].size, dx)
dxs = np.abs(dxs)
x = da[dim].values
if not da.indexes[dim].is_monotonic_increasing:
x = x[::-1]
dxs = dxs[::-1]
# This assumes the coordinate to be monotonic increasing
x0 = x[0] - 0.5 * dxs[0]
x = np.full(dxs.size + 1, x0)
x[1:] += np.cumsum(dxs)
return x
def _ugrid2d_dataset(
node_x: FloatArray,
node_y: FloatArray,
face_x: FloatArray,
face_y: FloatArray,
face_nodes: IntArray,
) -> xr.Dataset:
ds = xr.Dataset()
ds["mesh2d"] = xr.DataArray(
data=0,
attrs={
"cf_role": "mesh_topology",
"long_name": "Topology data of 2D mesh",
"topology_dimension": 2,
"node_coordinates": "node_x node_y",
"face_node_connectivity": "face_nodes",
"edge_node_connectivity": "edge_nodes",
},
)
ds = ds.assign_coords(
node_x=xr.DataArray(
data=node_x,
dims=["node"],
)
)
ds = ds.assign_coords(
node_y=xr.DataArray(
data=node_y,
dims=["node"],
)
)
ds["face_nodes"] = xr.DataArray(
data=face_nodes,
coords={
"face_x": ("face", face_x),
"face_y": ("face", face_y),
},
dims=["face", "nmax_face"],
attrs={
"cf_role": "face_node_connectivity",
"long_name": "Vertex nodes of mesh faces (counterclockwise)",
"start_index": 0,
"_FillValue": -1,
},
)
ds.attrs = {"Conventions": "CF-1.8 UGRID-1.0"}
return ds
def ugrid2d_topology(data: Union[xr.DataArray, xr.Dataset]) -> xr.Dataset:
"""
Derive the 2D-UGRID quadrilateral mesh topology from a structured DataArray
or Dataset, with (2D-dimensions) "y" and "x".
Parameters
----------
data: Union[xr.DataArray, xr.Dataset]
Structured data from which the "x" and "y" coordinate will be used to
define the UGRID-2D topology.
Returns
-------
ugrid_topology: xr.Dataset
Dataset with the required arrays describing 2D unstructured topology:
node_x, node_y, face_x, face_y, face_nodes (connectivity).
"""
# Transform midpoints into vertices
# These are always returned monotonically increasing
x = data["x"].values
xcoord = _coord(data, "x")
if not data.indexes["x"].is_monotonic_increasing:
xcoord = xcoord[::-1]
y = data["y"].values
ycoord = _coord(data, "y")
if not data.indexes["y"].is_monotonic_increasing:
ycoord = ycoord[::-1]
# Compute all vertices, these are the ugrid nodes
node_y, node_x = (a.ravel() for a in np.meshgrid(ycoord, xcoord, indexing="ij"))
face_y, face_x = (a.ravel() for a in np.meshgrid(y, x, indexing="ij"))
linear_index = np.arange(node_x.size, dtype=np.int32).reshape(
ycoord.size, xcoord.size
)
# Allocate face_node_connectivity
nfaces = (ycoord.size - 1) * (xcoord.size - 1)
face_nodes = np.empty((nfaces, 4))
# Set connectivity in counterclockwise manner
face_nodes[:, 0] = linear_index[:-1, 1:].ravel() # upper right
face_nodes[:, 1] = linear_index[:-1, :-1].ravel() # upper left
face_nodes[:, 2] = linear_index[1:, :-1].ravel() # lower left
face_nodes[:, 3] = linear_index[1:, 1:].ravel() # lower right
# Tie it together
ds = _ugrid2d_dataset(node_x, node_y, face_x, face_y, face_nodes)
return ds
def area_weighted_mean(
da: xr.DataArray,
destination_index: np.ndarray,
source_index: np.ndarray,
weights: np.ndarray,
):
"""
Area weighted mean.
Parameters
----------
da: xr.DataArray
Contains source data.
destination_index: np.ndarray
In which destination the overlap is located.
source_index: np.ndarray
In which source cell the overlap is located.
weights: np.ndarray
Area of each overlap.
Returns
-------
destination_index: np.ndarray
values: np.ndarray
"""
values = da.data.ravel()[source_index]
df = pd.DataFrame(
{"dst": destination_index, "area": weights, "av": weights * values}
)
aggregated = df.groupby("dst").sum("sum", min_count=1)
out = aggregated["av"] / aggregated["area"]
return out.index.values, out.values
class Regridder:
"""
Regridder to reproject and/or regrid rasters. When no ``crs_source`` and
``crs_destination`` are provided, it is assumed that ``source`` and
``destination`` share the same coordinate system.
Note that an area weighted regridding method only makes sense for projected
(Cartesian!) coordinate systems.
Parameters
----------
source: xr.DataArray
Source example. Must have dimensions ("y", "x").
destination: xr.DataArray
Destination example. Must have dimensions ("y", "x").
crs_source: optional, default: None
crs_destination: optional, default: None
"""
def __init__(
self,
source: xr.DataArray,
destination: xr.DataArray,
crs_source=None,
crs_destination=None,
):
src = ugrid2d_topology(source)
dst = ugrid2d_topology(destination)
src_yy = src["node_y"].values
src_xx = src["node_x"].values
if crs_source and crs_destination:
transformer = pyproj.Transformer.from_crs(
crs_from=crs_source, crs_to=crs_destination, always_xy=True
)
src_xx, src_yy = transformer.transform(xx=src_xx, yy=src_yy)
elif crs_source ^ crs_destination:
raise ValueError("Received only one of (crs_source, crs_destination)")
src_vertices = np.column_stack([src_xx, src_yy])
src_faces = src["face_nodes"].values.astype(int)
dst_vertices = np.column_stack((dst["node_x"].values, dst["node_y"].values))
dst_faces = dst["face_nodes"].values
celltree = CellTree2d(src_vertices, src_faces, fill_value=-1)
self.source = source.copy()
self.destination = destination.copy()
(
self.destination_index,
self.source_index,
self.weights,
) = celltree.intersect_faces(
dst_vertices,
dst_faces,
fill_value=-1,
)
def regrid(self, da: xr.DataArray, fill_value=np.nan):
"""
Parameters
----------
da: xr.DataArray
Data to regrid.
fill_value: optional, default: np.nan
Default value of the output grid, e.g. where no overlap occurs.
Returns
-------
regridded: xr.DataArray
Data of da, regridded using an area weighted mean.
"""
src = self.source
if not (np.allclose(da["y"], src["y"]) and np.allclose(da["x"], src["x"])):
raise ValueError("da does not match source")
index, values = area_weighted_mean(
da,
self.destination_index,
self.source_index,
self.weights,
)
data = np.full(self.destination.shape, fill_value)
data.ravel()[index] = values
out = self.destination.copy(data=data)
out.name = da.name
return out
# Example use
da = xr.open_dataarray("gw_abstraction_sum.nc")
like = xr.open_dataarray("example.nc")
regridder = Regridder(
source=da, destination=like, crs_source=4326, crs_destination=3035
)
result = regridder.regrid(da)
result.to_netcdf("area-weighted_sum.nc")
Is there a way in Plotly to access colormap colours at any value along its range?
I know I can access the defining colours for a colourscale from
plotly.colors.PLOTLY_SCALES["Viridis"]
but I am unable to find how to access intermediate / interpolated values.
The equivalent in Matplotlib is shown in this question. There is also another question that addresses a similar issue with the colorlover library, but neither offers a nice solution.
Plotly does not appear to have such a method, so I wrote one:
import plotly.colors
def get_continuous_color(colorscale, intermed):
"""
Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
color for any value in that range.
Plotly doesn't make the colorscales directly accessible in a common format.
Some are ready to use:
colorscale = plotly.colors.PLOTLY_SCALES["Greens"]
Others are just swatches that need to be constructed into a colorscale:
viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)
:param colorscale: A plotly continuous colorscale defined with RGB string colors.
:param intermed: value in the range [0, 1]
:return: color in rgb string format
:rtype: str
"""
if len(colorscale) < 1:
raise ValueError("colorscale must have at least one color")
if intermed <= 0 or len(colorscale) == 1:
return colorscale[0][1]
if intermed >= 1:
return colorscale[-1][1]
for cutoff, color in colorscale:
if intermed > cutoff:
low_cutoff, low_color = cutoff, color
else:
high_cutoff, high_color = cutoff, color
break
# noinspection PyUnboundLocalVariable
return plotly.colors.find_intermediate_color(
lowcolor=low_color, highcolor=high_color,
intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
colortype="rgb")
The challenge is that the built-in Plotly colorscales are not consistently exposed. Some are defined as a colorscale already, others as just a list of color swatches that must be converted to a color scale first.
The Viridis colorscale is defined with hex values, which the Plotly color manipulation methods don't like, so it's easiest to construct it from swatches like this:
viridis_colors, _ = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
colorscale = plotly.colors.make_colorscale(viridis_colors)
get_continuous_color(colorscale, intermed=0.25)
# rgb(58.75, 80.75, 138.25)
There is a built-in method, plotly.express.colors.sample_colorscale, which provides the color samples:
from plotly.express.colors import sample_colorscale
import plotly.graph_objects as go
import numpy as np
x = np.linspace(0, 1, 25)
y = x  # any bar heights will do; we only care about the sampled colors
c = sample_colorscale('jet', list(x))
fig = go.FigureWidget()
fig.add_trace(
go.Bar(x=x, y=y, marker_color=c)
)
fig.show()
See the output figure -> sampled_colors
This answer extends the already good one provided by Adam. In particular, it deals with the inconsistency of Plotly's color scales.
In Plotly, you specify a built-in color scale by writing colorscale="name_of_the_colorscale". This suggests that Plotly already has a built-in tool that somehow converts the color scale to an appropriate value and is capable of dealing with these inconsistencies. By searching Plotly's source code we find the useful ColorscaleValidator class. Let's see how to use it:
def get_color(colorscale_name, loc):
from _plotly_utils.basevalidators import ColorscaleValidator
# first parameter: Name of the property being validated
# second parameter: a string, doesn't really matter in our use case
cv = ColorscaleValidator("colorscale", "")
# colorscale will be a list of lists: [[loc1, "rgb1"], [loc2, "rgb2"], ...]
colorscale = cv.validate_coerce(colorscale_name)
if hasattr(loc, "__iter__"):
return [get_continuous_color(colorscale, x) for x in loc]
return get_continuous_color(colorscale, loc)
# Identical to Adam's answer
import plotly.colors
from PIL import ImageColor
def get_continuous_color(colorscale, intermed):
"""
Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
color for any value in that range.
Plotly doesn't make the colorscales directly accessible in a common format.
Some are ready to use:
colorscale = plotly.colors.PLOTLY_SCALES["Greens"]
Others are just swatches that need to be constructed into a colorscale:
viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)
:param colorscale: A plotly continuous colorscale defined with RGB string colors.
:param intermed: value in the range [0, 1]
:return: color in rgb string format
:rtype: str
"""
if len(colorscale) < 1:
raise ValueError("colorscale must have at least one color")
hex_to_rgb = lambda c: "rgb" + str(ImageColor.getcolor(c, "RGB"))
if intermed <= 0 or len(colorscale) == 1:
c = colorscale[0][1]
return c if c[0] != "#" else hex_to_rgb(c)
if intermed >= 1:
c = colorscale[-1][1]
return c if c[0] != "#" else hex_to_rgb(c)
for cutoff, color in colorscale:
if intermed > cutoff:
low_cutoff, low_color = cutoff, color
else:
high_cutoff, high_color = cutoff, color
break
if (low_color[0] == "#") or (high_color[0] == "#"):
# some color scale names (such as cividis) returns:
# [[loc1, "hex1"], [loc2, "hex2"], ...]
low_color = hex_to_rgb(low_color)
high_color = hex_to_rgb(high_color)
return plotly.colors.find_intermediate_color(
lowcolor=low_color,
highcolor=high_color,
intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
colortype="rgb",
)
At this point, all you have to do is:
get_color("phase", 0.5)
# 'rgb(123.99999999999999, 112.00000000000001, 236.0)'
import numpy as np
get_color("phase", np.linspace(0, 1, 256))
# ['rgb(167, 119, 12)',
# 'rgb(168.2941176470588, 118.0078431372549, 13.68235294117647)',
# ...
Edit: improvements to deal with special cases.
The official reference explains it. Here:
import plotly.express as px
print(px.colors.sequential.Viridis)
['#440154', '#482878', '#3e4989', '#31688e', '#26828e', '#1f9e89', '#35b779', '#6ece58', '#b5de2b', '#fde725']
print(px.colors.sequential.Viridis[0])
#440154
import plotly.express as px
color_list = list(name_of_color_scale)
# name_of_color_scale can be any built-in colorscale, e.g. px.colors.qualitative.D3.
Output:
color_list =
['#1F77B4',
'#FF7F0E',
'#2CA02C',
'#D62728',
'#9467BD',
'#8C564B',
'#E377C2',
'#7F7F7F',
'#BCBD22',
'#17BECF']
I would like to plot parallel lines with different colors. E.g. rather than a single red line of thickness 6, I would like to have two parallel lines of thickness 3, with one red and one blue.
Any thoughts would be appreciated.
Thanks
Even with the smart offsetting (see below), there is still an issue in a view that has sharp angles between consecutive points.
Zoomed view of smart offsetting:
Overlaying lines of varying thickness:
Plotting parallel lines is not an easy task. Using a simple uniform offset will of course not show the desired result. This is shown in the left picture below.
Such a simple offset can be produced in matplotlib as shown in the transformation tutorial.
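For reference, a minimal sketch of such a uniform offset using matplotlib.transforms.offset_copy, in the spirit of that tutorial; every point is shifted by the same vector in display units, which is exactly why it breaks down around curves and corners:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

x = np.linspace(0, 9)
y = np.sin(x) * x / 3.

fig, ax = plt.subplots(figsize=(4, 2.5), dpi=100)
# Shift the second line by a constant 3 points upwards in display coordinates
trans_offset = mtransforms.offset_copy(ax.transData, fig=fig, x=0, y=3, units='points')
ax.plot(x, y, lw=2, color="red")
ax.plot(x, y, lw=2, color="blue", transform=trans_offset)
plt.show()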
Method1
A better solution may be to use the idea sketched on the right side. To calculate the offset of the nth point we can use the normal vector to the line between the n-1st and the n+1st point and use the same distance along this normal vector to calculate the offset point.
The advantage of this method is that we have the same number of points in the original line as in the offset line. The disadvantage is that it is not completely accurate, as can be seen in the picture.
This method is implemented in the function offset in the code below.
In order to make this useful for a matplotlib plot, we need to consider that the linewidth should be independent of the data units. Linewidth is usually given in units of points, and the offset would best be given in the same unit, such that e.g. the requirement from the question ("two parallel lines of width 3") can be met.
The idea is therefore to transform the coordinates from data to display coordinates, using ax.transData.transform. Also the offset in points o can be transformed to the same units: Using the dpi and the standard of ppi=72, the offset in display coordinates is o*dpi/ppi. After the offset in display coordinates has been applied, the inverse transform (ax.transData.inverted().transform) allows a backtransformation.
Now there is another dimension of the problem: How to assure that the offset remains the same independent of the zoom and size of the figure?
This last point can be addressed by recalculating the offset each time a zooming or resizing event has taken place.
Here is how a rainbow curve produced by this method would look.
And here is the code to produce the image.
import numpy as np
import matplotlib.pyplot as plt
dpi = 100
def offset(x,y, o):
""" Offset coordinates given by array x,y by o """
X = np.c_[x,y].T
m = np.array([[0,-1],[1,0]])
R = np.zeros_like(X)
S = X[:,2:]-X[:,:-2]
R[:,1:-1] = np.dot(m, S)
R[:,0] = np.dot(m, X[:,1]-X[:,0])
R[:,-1] = np.dot(m, X[:,-1]-X[:,-2])
On = R/np.sqrt(R[0,:]**2+R[1,:]**2)*o
Out = On+X
return Out[0,:], Out[1,:]
def offset_curve(ax, x,y, o):
""" Offset array x,y in data coordinates
by o in points """
trans = ax.transData.transform
inv = ax.transData.inverted().transform
X = np.c_[x,y]
Xt = trans(X)
xto, yto = offset(Xt[:,0],Xt[:,1],o*dpi/72. )
Xto = np.c_[xto, yto]
Xo = inv(Xto)
return Xo[:,0], Xo[:,1]
# some single points
y = np.array([1,2,2,3,3,0])
x = np.arange(len(y))
#or try a sinus
x = np.linspace(0,9)
y=np.sin(x)*x/3.
fig, ax=plt.subplots(figsize=(4,2.5), dpi=dpi)
cols = ["#fff40b", "#00e103", "#ff9921", "#3a00ef", "#ff2121", "#af00e7"]
lw = 2.
lines = []
for i in range(len(cols)):
l, = plt.plot(x,y, lw=lw, color=cols[i])
lines.append(l)
def plot_rainbow(event=None):
    xr = [None] * 6; yr = [None] * 6  # lists, so items can be assigned by index (Python 3 ranges are immutable)
xr[0],yr[0] = offset_curve(ax, x,y, lw/2.)
xr[1],yr[1] = offset_curve(ax, x,y, -lw/2.)
xr[2],yr[2] = offset_curve(ax, xr[0],yr[0], lw)
xr[3],yr[3] = offset_curve(ax, xr[1],yr[1], -lw)
xr[4],yr[4] = offset_curve(ax, xr[2],yr[2], lw)
xr[5],yr[5] = offset_curve(ax, xr[3],yr[3], -lw)
for i in range(6):
lines[i].set_data(xr[i], yr[i])
plot_rainbow()
fig.canvas.mpl_connect("resize_event", plot_rainbow)
fig.canvas.mpl_connect("button_release_event", plot_rainbow)
plt.savefig(__file__+".png", dpi=dpi)
plt.show()
Method2
To avoid overlapping lines, one has to use a more complicated solution.
One could first offset every point normal to the two line segments it is part of (green points in the picture below). Then calculate the line through those offset points and find their intersection.
A particular case would be when the slopes of two subsequent line segments are equal. This has to be taken care of (eps in the code below).
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
dpi = 100
def intersect(p1, p2, q1, q2, eps=1.e-10):
""" given two lines, first through points pn, second through qn,
find the intersection """
x1 = p1[0]; y1 = p1[1]; x2 = p2[0]; y2 = p2[1]
x3 = q1[0]; y3 = q1[1]; x4 = q2[0]; y4 = q2[1]
nomX = ((x1*y2-y1*x2)*(x3-x4)- (x1-x2)*(x3*y4-y3*x4))
denom = float( (x1-x2)*(y3-y4) - (y1-y2)*(x3-x4) )
nomY = (x1*y2-y1*x2)*(y3-y4) - (y1-y2)*(x3*y4-y3*x4)
if np.abs(denom) < eps:
#print "intersection undefined", p1
return np.array( p1 )
else:
return np.array( [ nomX/denom , nomY/denom ])
def offset(x,y, o, eps=1.e-10):
""" Offset coordinates given by array x,y by o """
X = np.c_[x,y].T
m = np.array([[0,-1],[1,0]])
S = X[:,1:]-X[:,:-1]
R = np.dot(m, S)
norm = np.sqrt(R[0,:]**2+R[1,:]**2) / o
On = R/norm
Outa = On+X[:,1:]
Outb = On+X[:,:-1]
G = np.zeros_like(X)
    for i in range(0, len(X[0,:])-2):  # range, not xrange, for Python 3
p = intersect(Outa[:,i], Outb[:,i], Outa[:,i+1], Outb[:,i+1], eps=eps)
G[:,i+1] = p
G[:,0] = Outb[:,0]
G[:,-1] = Outa[:,-1]
return G[0,:], G[1,:]
def offset_curve(ax, x,y, o, eps=1.e-10):
""" Offset array x,y in data coordinates
by o in points """
trans = ax.transData.transform
inv = ax.transData.inverted().transform
X = np.c_[x,y]
Xt = trans(X)
xto, yto = offset(Xt[:,0],Xt[:,1],o*dpi/72., eps=eps )
Xto = np.c_[xto, yto]
Xo = inv(Xto)
return Xo[:,0], Xo[:,1]
# some single points
y = np.array([1,1,2,0,3,2,1.,4,3]) *1.e9
x = np.arange(len(y))
x[3]=x[4]
#or try a sinus
#x = np.linspace(0,9)
#y=np.sin(x)*x/3.
fig, ax=plt.subplots(figsize=(4,2.5), dpi=dpi)
cols = ["r", "b"]
lw = 11.
lines = []
for i in range(len(cols)):
l, = plt.plot(x,y, lw=lw, color=cols[i], solid_joinstyle="miter")
lines.append(l)
def plot_rainbow(event=None):
    xr = [None] * 2; yr = [None] * 2  # lists, so items can be assigned by index (Python 3 ranges are immutable)
xr[0],yr[0] = offset_curve(ax, x,y, lw/2.)
xr[1],yr[1] = offset_curve(ax, x,y, -lw/2.)
for i in range(2):
lines[i].set_data(xr[i], yr[i])
plot_rainbow()
fig.canvas.mpl_connect("resize_event", plot_rainbow)
fig.canvas.mpl_connect("button_release_event", plot_rainbow)
plt.show()
Note that this method should work well as long as the offset between the lines is smaller than the distance between subsequent points on the line. Otherwise method 1 may be better suited.
The best that I can think of is to take your data, generate a series of small offsets, and use fill_between to make bands of whatever color you like.
I wrote a function to do this. I don't know what shape you're trying to plot, so this may or may not work for you. I tested it on a parabola and got decent results. You can also play around with the list of colors.
import numpy as np
import matplotlib.pyplot as plt

def rainbow_plot(x, y, spacing=0.1):
fig, ax = plt.subplots()
colors = ['red', 'yellow', 'green', 'cyan','blue']
top = max(y)
lines = []
for i in range(len(colors)+1):
newline_data = y - top*spacing*i
lines.append(newline_data)
for i, c in enumerate(colors):
ax.fill_between(x, lines[i], lines[i+1], facecolor=c)
return fig, ax
x = np.linspace(0,1,51)
y = 1-(x-0.5)**2
rainbow_plot(x,y)
I have figured out a method to cluster dispersed point data into a structured 2-D array (like a rasterize function), and I hope there are better ways to achieve that target.
My work
1. Intro
1000 data points with three properties (lon, lat, emission), each representing a factory located at (x, y) that emits a certain amount of CO2 into the atmosphere
grid network: a predefined 2-D array with the shape 20x20
http://i4.tietuku.com/02fbaf32d2f09fff.png
The code reproduced here:
from mpl_toolkits.basemap import Basemap
from scipy.spatial import KDTree
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#### define the map area
xc1,xc2,yc1,yc2 = 113.49805889531724,115.5030664238035,37.39995194888143,38.789235929357105
map = Basemap(llcrnrlon=xc1,llcrnrlat=yc1,urcrnrlon=xc2,urcrnrlat=yc2)
#### reading the point data and scatter plot by their position
df = pd.read_csv("xxxxx.csv")
px,py = map(df.lon, df.lat)
map.scatter(px, py, color = "red", s= 5,zorder =3)
#### predefine the grid networks
lon_grid,lat_grid = np.linspace(xc1,xc2,21), np.linspace(yc1,yc2,21)
lon_x,lat_y = np.meshgrid(lon_grid,lat_grid)
grids = np.zeros(20*20).reshape(20,20)
plt.pcolormesh(lon_x,lat_y,grids,cmap = 'gray', facecolor = 'none',edgecolor = 'k',zorder=3)
2. My target
Find the nearest grid point for each factory
Add the emission data to that grid cell
3. Algorithm realization
3.1 Raster grid
note: 20x20 grid points are distributed in this area, represented by the blue dots.
http://i4.tietuku.com/8548554587b0cb3a.png
3.2 KD-tree
Find the nearest blue dot for each red point
sh = (20*20,2)
grids = np.zeros(20*20*2).reshape(*sh)
sh_emission = (20*20)
grids_em = np.zeros(20*20).reshape(sh_emission)
k = 0
for j in range(0,yy.shape[0],1):
for i in range(0,xx.shape[0],1):
grids[k] = np.array([lon_grid[i],lat_grid[j]])
k+=1
T = KDTree(grids)
x_delta = (lon_grid[2] - lon_grid[1])
y_delta = (lat_grid[2] - lat_grid[1])
R = np.sqrt(x_delta**2 + y_delta**2)
for i in range(0,len(df.lon),1):
idx = T.query_ball_point([df.lon.iloc[i],df.lat.iloc[i]], r=R)
    # sometimes more than one blue dot is found,
    # so I'll calculate the distances between the factory (red point)
    # and all the blue dots that were returned
if (idx > 1):
distance = []
for k in range(0,len(idx),1):
distance.append(np.sqrt((df.lon.iloc[i] - grids[k][0])**2 + (df.lat.iloc[i] - grids[k][1])**2))
pos_index = distance.index(min(distance))
pos = idx[pos_index]
# Only find 1 point
else:
pos = idx
grids_em[pos] += df.so2[i]
4. Result
co2 = grids_em.reshape(20,20)
plt.pcolormesh(lon_x,lat_y,co2,cmap =plt.cm.Spectral_r,zorder=3)
http://i4.tietuku.com/6ded65c4ac301294.png
5. My question
Can someone point out some drawbacks or errors of this method?
Are there algorithms better aligned with my target?
Thanks a lot!
There are many for-loops in your code, and that's not the numpy way.
Make some sample data first:
import numpy as np
import pandas as pd
from scipy.spatial import KDTree
import pylab as pl
xc1, xc2, yc1, yc2 = 113.49805889531724, 115.5030664238035, 37.39995194888143, 38.789235929357105
N = 1000
GSIZE = 20
x, y = np.random.multivariate_normal([(xc1 + xc2)*0.5, (yc1 + yc2)*0.5], [[0.1, 0.02], [0.02, 0.1]], size=N).T
value = np.ones(N)
df_points = pd.DataFrame({"x":x, "y":y, "v":value})
For equally spaced grids you can use hist2d():
pl.hist2d(df_points.x, df_points.y, weights=df_points.v, bins=20, cmap="viridis");
Here is the output:
Here is the code using KDTree:
X, Y = np.mgrid[x.min():x.max():GSIZE*1j, y.min():y.max():GSIZE*1j]
grid = np.c_[X.ravel(), Y.ravel()]
points = np.c_[df_points.x, df_points.y]
tree = KDTree(grid)
dist, indices = tree.query(points)
grid_values = df_points.groupby(indices).v.sum()
df_grid = pd.DataFrame(grid, columns=["x", "y"])
df_grid["v"] = grid_values
fig, ax = pl.subplots(figsize=(10, 8))
ax.plot(df_points.x, df_points.y, "kx", alpha=0.2)
mapper = ax.scatter(df_grid.x, df_grid.y, c=df_grid.v,
cmap="viridis",
linewidths=0,
s=100, marker="o")
pl.colorbar(mapper, ax=ax);
the output is: