Adding seaborn clustermap to figure with other plots - python

I am trying to put the following two plots onto the same figure:
import seaborn as sns; sns.set(color_codes=True)
import matplotlib.pyplot as plt
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
iris = sns.load_dataset("iris")
sns.boxplot(data=iris, orient="h", palette="Set2", ax = ax1)
species = iris.pop("species")
lut = dict(zip(species.unique(), "rbg"))
row_colors = species.map(lut)
sns.clustermap(iris, row_colors=row_colors, ax = ax2)
I understand that clustermap returns a figure, so this doesn't work. However, I still need a way to present these plots next to each other (horizontally). sns.heatmap returns an axes, but it does not support clustering or color annotation.
What is the best way to do this?

Indeed, clustermap, like some other seaborn functions, creates its own figure. There is nothing you can do about that, but as long as everything else you want in the final figure can be drawn inside an axes, as the boxplot can here, the solution is relatively easy.
You can simply work with the figure that clustermap has created for you. The idea is to adjust the gridspec of its axes so that some space is left over for the other axes.
import seaborn as sns; sns.set(color_codes=True)
import matplotlib.pyplot as plt
import matplotlib.gridspec
iris = sns.load_dataset("iris")
species = iris.pop("species")
lut = dict(zip(species.unique(), "rbg"))
row_colors = species.map(lut)
#First create the clustermap figure
g = sns.clustermap(iris, row_colors=row_colors, figsize=(13,8))
# set the gridspec to only cover half of the figure
g.gs.update(left=0.05, right=0.45)
#create new gridspec for the right part
gs2 = matplotlib.gridspec.GridSpec(1,1, left=0.6)
# create axes within this new gridspec
ax2 = g.fig.add_subplot(gs2[0])
# plot boxplot in the new axes
sns.boxplot(data=iris, orient="h", palette="Set2", ax = ax2)
plt.show()
When several figure-level functions need to be combined, the solution is much more complicated, as seen e.g. in this question.

Related

Removing legend from mpl parallel coordinates plot?

I have a parallel coordinates plot with lots of data points so I'm trying to use a continuous colour bar to represent that, which I think I have worked out. However, I haven't been able to remove the default key that is put in when creating the plot, which is very long and hinders readability. Is there a way to remove this table to make the graph much easier to read?
This is the code I'm currently using to generate the parallel coordinates plot:
parallel_coordinates(data[[' male_le', 'female_le', 'diet', 'activity', 'obese_perc', 'median_income']],
                     'median_income', colormap='rainbow', alpha=0.5)
fig, ax = plt.subplots(figsize=(6, 1))
fig.subplots_adjust(bottom=0.5)
cmap = mpl.cm.rainbow
bounds = [0.00,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N,)
plt.colorbar(mpl.cm.ScalarMappable(norm = norm, cmap=cmap),cax = ax, orientation = 'horizontal',
label = 'normalised median income', alpha = 0.5)
plt.show()
Current Output:
I want my legend to be represented as a color bar, like this:
Any help would be greatly appreciated. Thanks.
You can use ax.legend_.remove() to remove the legend.
The cax parameter of plt.colorbar specifies the axes in which the colorbar is drawn. If you leave it out, matplotlib creates a new axes, "stealing" space from the current one (axes objects are conventionally named ax in matplotlib). So leaving out cax here creates the desired colorbar; adding ax=ax isn't strictly necessary, because ax is already the current axes.
The code below uses seaborn's penguin dataset to create a standalone example.
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np
from pandas.plotting import parallel_coordinates
penguins = sns.load_dataset('penguins')
fig, ax = plt.subplots(figsize=(10, 4))
cmap = plt.get_cmap('rainbow')
bounds = np.arange(penguins['body_mass_g'].min(), penguins['body_mass_g'].max() + 200, 200)
norm = mpl.colors.BoundaryNorm(bounds, 256)
penguins = penguins.dropna(subset=['body_mass_g'])
parallel_coordinates(penguins[['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm', 'body_mass_g']],
'body_mass_g', colormap=cmap, alpha=0.5, ax=ax)
ax.legend_.remove()
plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
ax=ax, orientation='horizontal', label='body mass', alpha=0.5)
plt.show()
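To make the difference between ax and cax described above concrete, here is a minimal sketch using plain matplotlib. The variable names (fig_a, cax_b and so on) and the Agg backend line are my own choices for illustration, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib as mpl
import matplotlib.pyplot as plt

sm = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(0, 1), cmap="rainbow")

# Case 1: pass ax=... and matplotlib creates a new axes for the colorbar,
# shrinking ax_a to make room for it.
fig_a, ax_a = plt.subplots()
width_before = ax_a.get_position().width
fig_a.colorbar(sm, ax=ax_a)
width_after = ax_a.get_position().width
n_axes_case_ax = len(fig_a.axes)  # plot axes + the new colorbar axes

# Case 2: pass cax=... and the colorbar is drawn into an axes that already
# exists, so no additional axes is created.
fig_b, (ax_b, cax_b) = plt.subplots(1, 2, gridspec_kw={"width_ratios": [20, 1]})
n_axes_before = len(fig_b.axes)
fig_b.colorbar(sm, cax=cax_b)
n_axes_case_cax = len(fig_b.axes)

print(n_axes_case_ax, n_axes_before, n_axes_case_cax)
print(width_after < width_before)
```

In the question, cax=ax pointed at a separate 6x1 figure, which is why the colorbar ended up apart from the parallel coordinates plot.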

Seaborn heatmaps in subplots - align x-axis

I am trying to plot a figure containing two subplots, a seaborn heatmap and simple matplotlib lines. However, when sharing the x-axis for both plots, they do not align as can be seen in this figure:
It would seem that the problem is similar to this post, but when displaying ax[0].get_xticks() and ax[1].get_xticks() I get the same positions, so I don't know what to change. And in my picture the deviation seems to be more than a 0.5 shift.
What am I doing wrong?
The code I used to plot the figure is the following:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
M_1=np.random.random((15,15))
M_2=np.random.random((15,15))
L_1=np.random.random(15)
L_2=np.random.random(15)
x=range(15)
cmap = sns.color_palette("hot", 100)
sns.set(style="white")
fig, ax = plt.subplots(2, 1, sharex='col', figsize=(10, 12))
ax[0].plot(x,L_1,'-', marker='o',color='tab:orange')
sns.heatmap(M_1, cmap=cmap, vmax=np.max(M_1), center=np.max(M_1)/2., square=False, ax=ax[1])
@Mr-T's comment is spot on. The easiest approach would be to create the axes beforehand instead of letting heatmap() shrink your axes in order to make room for the colorbar.
There is the added complication that the labels for the heatmap are not actually placed at [0,1,...] but are in the middle of each cell at [0.5, 1.5, ...]. So if you want your upper plot to align with the labels at the bottom (and with the center of each cell), you may have to shift your plot by 0.5 units to the right:
M_1=np.random.random((15,15))
M_2=np.random.random((15,15))
L_1=np.random.random(15)
L_2=np.random.random(15)
x=np.arange(15)
cmap = sns.color_palette("hot", 100)
sns.set(style="white")
fig, ax = plt.subplots(2, 2, sharex='col', gridspec_kw={'width_ratios':[100,5]})
ax[0,1].remove() # remove unused upper right axes
ax[0,0].plot(x+0.5,L_1,'-', marker='o',color='tab:orange')
sns.heatmap(M_1, cmap=cmap, vmax=np.max(M_1), center=np.max(M_1)/2., square=False, ax=ax[1,0], cbar_ax=ax[1,1])
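The half-cell offset mentioned above can be checked directly. This is a small sketch, assuming a 3x4 array and xticklabels=True so that every column gets a tick; the array and the variable names are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

fig, ax = plt.subplots()
data = np.arange(12).reshape(3, 4)

# xticklabels=True forces one tick per column
sns.heatmap(data, xticklabels=True, cbar=False, ax=ax)

tick_positions = ax.get_xticks()  # tick locations: the centre of each cell
x_limits = ax.get_xlim()          # axis limits: the outer cell edges

print(tick_positions)
print(x_limits)
```

The ticks sit half a unit to the right of the integer column indices, which is why the line plot in the answer is drawn at x+0.5.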

Set xlim in heatmap with subplots and annotation

I would like to plot several heatmaps side by side, with annotations.
For this, I use subplots and I can plot each heatmap in its axes using the ax kwarg.
The issue arises when I use xlim: it's applied to the heatmap, but not to the annotations:
Here is the code:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
values = np.random.random((7,24)) # create (7,24) shape array
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30,10)) # create 2 columns for subplots
ax1 = sns.heatmap(values, annot=True, ax=ax1) # heatmap with annotation
ax1.set(xlim=(12,22)) # works fine with this line commented
# ax1.set_xlim(12,22)
# ax2 = sns.heatmap(values, annot=True, ax=ax2) # second heatmap
plt.show()
And it gets worse with a second heatmap, because the annotations from the second heatmap are plotted on the first heatmap.
How can I limit the x-axis to (12,22) while using annotations?
matplotlib 2.2.2
seaborn 0.9.0
python 3.6.5
Why not provide the slice of interest in the first place and relabel the x-axis?
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(1234)
values = np.random.random((7,24)) # create (7,24) shape array
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(21,7)) # create 2 columns for subplots
#full heatmap
sns.heatmap(values, annot=True, ax=ax1)
#slice of interest
start=12
stop=22
sns.heatmap(values[:, start:stop+1], annot=True, ax=ax2, xticklabels = np.arange(start, stop+1)) # second heatmap
plt.show()
Sample output
After I posted this issue on the seaborn GitHub, here is the official answer:
matplotlib text objects are not automatically clipped when they are
placed outside of the axes limits; you can turn that on by passing
annot_kws=dict(clip_on=True) to heatmap, though.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
values = np.random.random((7,24)) # create (7,24) shape array
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30,10)) # create 2 columns for subplots
ax1 = sns.heatmap(values, annot=True, ax=ax1, annot_kws=dict(clip_on=True)) # heatmap with annotation
ax1.set(xlim=(12,22)) # works fine with this line commented
# ax1.set_xlim(12,22)
ax2 = sns.heatmap(values, annot=True, ax=ax2, annot_kws=dict(clip_on=True)) # second heatmap
ax2.set(xlim=(12,22))
plt.show()
clip_on=True removes everything that falls outside the axes.
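The clipping behaviour can be seen without seaborn at all. The following is a minimal matplotlib-only sketch; the text strings and coordinates are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 1)

# x=2 lies outside the x-limits. Text added with ax.text is not clipped
# unless clip_on=True is passed, so the first label would still be drawn.
outside_default = ax.text(2, 0.5, "outside, still drawn")
outside_clipped = ax.text(2, 0.4, "outside, hidden", clip_on=True)

print(outside_default.get_clip_on(), outside_clipped.get_clip_on())
```

heatmap forwards annot_kws to the text objects it creates for the annotations, which is why annot_kws=dict(clip_on=True) has the same effect there.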

Seaborn heatmap - colorbar label font size

How do I set the font size of the colorbar label?
ax=sns.heatmap(table, vmin=60, vmax=100, xticklabels=[4,8,16,32,64,128],yticklabels=[2,4,6,8], cmap="PuBu",linewidths=.0,
annot=True,cbar_kws={'label': 'Accuracy %'})
Unfortunately seaborn does not give access to the objects it creates. So one needs to take the detour, using the fact that the colorbar is an axes in the current figure and that it is the last one created, hence
ax = sns.heatmap(...)
cbar_axes = ax.figure.axes[-1]
For this axes, we can set the font size by getting the ylabel and calling its set_size method.
Example, setting the fontsize to 20 points:
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(0)
import seaborn as sns
data = np.random.rand(10, 12)*100
ax = sns.heatmap(data, cbar_kws={'label': 'Accuracy %'})
ax.figure.axes[-1].yaxis.label.set_size(20)
plt.show()
Note that the same can of course be achieved via
ax = sns.heatmap(data)
ax.figure.axes[-1].set_ylabel('Accuracy %', size=20)
without the keyword argument passing.
You could also explicitly pass in the axes objects into heatmap and modify them directly:
grid_spec = {"width_ratios": (.9, .05)}
f, (ax, cbar_ax) = plt.subplots(1,2, gridspec_kw=grid_spec)
sns.heatmap(data, ax=ax, cbar_ax=cbar_ax, cbar_kws={'label': 'Accuracy %'})
cbar_ax.yaxis.label.set_size(20)
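A further option, offered here as a sketch rather than as something the answers above propose, is to reach the colorbar through the heatmap's mesh: matplotlib stores the Colorbar on the mappable it was created for, and the heatmap's QuadMesh is the first collection on the axes. The random data is only a placeholder.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

data = np.random.default_rng(0).random((10, 12)) * 100  # placeholder data

fig, ax = plt.subplots()
sns.heatmap(data, ax=ax, cbar_kws={'label': 'Accuracy %'})

# The QuadMesh drawn by heatmap keeps a reference to its colorbar
cbar = ax.collections[0].colorbar
cbar.ax.yaxis.label.set_size(20)

print(cbar.ax.yaxis.label.get_text(), cbar.ax.yaxis.label.get_size())
```

This avoids relying on the colorbar being the last axes in the figure, which may not hold once more axes are added.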

How can I make a barplot and a lineplot in the same seaborn plot with different Y axes nicely?

I have two different sets of data with a common index, and I want to represent the first one as a barplot and the second one as a lineplot in the same graph. My current approach is similar to the following.
ax = pt.a.plot(alpha = .75, kind = 'bar')
ax2 = ax.twinx()
ax2.plot(ax.get_xticks(), pt.b.values, alpha = .75, color = 'r')
And the result is similar to this
This image is really nice and almost right. My only problem is that ax.twinx() seems to create a new canvas on top of the previous one, and the white lines are clearly seen on top of the barplot.
Is there any way to plot this without including the white lines?
You can use the twinx() method along with seaborn to create a separate y-axis, one for the lineplot and the other for the barplot. To control the style of the plot (the default seaborn theme is darkgrid), you can use the set_style method and specify the preferred theme. Calling matplotlib.rc_file_defaults() followed by set_style(style=None), as in the code below, leaves you with matplotlib's default white background without gridlines. You can also try whitegrid. If you want to customize the gridlines further, you can do it at the axis level using ax2.grid(False).
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
matplotlib.rc_file_defaults()
sns.set_style(style=None, rc=None)  # returns None, so there is nothing to assign
fig, ax1 = plt.subplots(figsize=(12,6))
sns.lineplot(data = df['y_var_1'], marker='o', sort = False, ax=ax1)
ax2 = ax1.twinx()
sns.barplot(data = df, x='x_var', y='y_var_2', alpha=0.5, ax=ax2)
You have to remove the grid lines of the second axis: add ax2.grid(False) to the code. However, the y-ticks of the second axis will not be aligned with the y-ticks of the first y-axis, as here:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(pd.Series(np.random.uniform(0,1,size=10)), color='g')
ax2 = ax1.twinx()
ax2.plot(pd.Series(np.random.uniform(0,17,size=10)), color='r')
ax2.grid(False)
plt.show()
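If the misaligned ticks are a problem, one possible workaround (not part of the answer above, and only a sketch) is to give both y-axes the same number of evenly spaced ticks with matplotlib's LinearLocator. The tick count of 6 and the helper relative_positions are assumptions made for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop for interactive use
import matplotlib.pyplot as plt
from matplotlib.ticker import LinearLocator
import numpy as np

rng = np.random.default_rng(0)
fig, ax1 = plt.subplots()
ax1.plot(rng.uniform(0, 1, size=10), color='g')
ax1.grid(True)
ax2 = ax1.twinx()
ax2.plot(rng.uniform(0, 17, size=10), color='r')
ax2.grid(False)  # only the first axis draws grid lines

# Same number of evenly spaced ticks on both axes, from bottom to top limit
n_ticks = 6
ax1.yaxis.set_major_locator(LinearLocator(n_ticks))
ax2.yaxis.set_major_locator(LinearLocator(n_ticks))

def relative_positions(ax):
    """Tick positions as fractions of the axis height (hypothetical helper)."""
    lo, hi = ax.get_ylim()
    return (ax.get_yticks() - lo) / (hi - lo)

print(relative_positions(ax1))
print(relative_positions(ax2))
```

The trade-off is that the tick values are no longer round numbers, so a tick formatter or manually chosen limits may be needed for a clean result.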
