Is it possible to run seaborn.clustermap on a previously obtained ClusterGrid object?
For example, I use clustermap to obtain g in the following example:
import seaborn as sns
data = sns.load_dataset("iris")
species = data.pop("species")
g = sns.clustermap(
data,
cmap="mako",
col_cluster=False,
yticklabels=False, figsize=(5, 10),
method='ward',
metric="euclidean"
)
I would like to try different visualization options like different colormaps, figure sizes, how it looks with and without labels etc.
With the iris dataset everything is really fast, but I have a way larger dataset and the clustering part takes a lot of time.
Can I use g to show the heatmap and dendrogram using different options?
The object returned by clustermap is of type ClusterGrid. That object is not really documented in seaborn; however, it is essentially just a container for a few Axes objects. Depending on the kind of manipulations you want to make, you may simply need to access the relevant Axes object or the figure itself:
# change the figure size after the fact
g.fig.set_size_inches((4,4))
# remove the labels of the heatmap
g.ax_heatmap.set_xticklabels([])
The colormap is a little more difficult to access. clustermap uses matplotlib's pcolormesh under the hood. This function returns a collection object (QuadMesh), which is stored in the list of collections of the main axes (g.ax_heatmap.collections). Since, AFAIK, seaborn doesn't plot anything else on that axes, we can get the QuadMesh object by its index [0], and then use any method applicable to that object.
# change the colormap used
g.ax_heatmap.collections[0].set_cmap('seismic')
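As a minimal sketch of the above, assuming a small random DataFrame in place of the real data (the variable names are illustrative): the clustering runs once, and the following lines only touch the existing figure and Axes. The last call goes a step beyond the answer. It assumes you are willing to redraw the figure, and it reuses the row linkage stored on g.dendrogram_row through clustermap's row_linkage parameter, so the expensive clustering is not repeated.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import pandas as pd
import seaborn as sns

# Stand-in data (assumption): 30 rows, 4 numeric columns
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(30, 4)), columns=list("abcd"))

# The expensive step: cluster once
g = sns.clustermap(data, cmap="mako", col_cluster=False,
                   method="ward", metric="euclidean")

# Cheap tweaks on the existing ClusterGrid
g.fig.set_size_inches((4, 4))            # resize the figure
g.ax_heatmap.set_xticklabels([])         # drop the column labels
mesh = g.ax_heatmap.collections[0]       # the QuadMesh drawn by pcolormesh
mesh.set_cmap("seismic")                 # swap the colormap

# Alternative: redraw with new options but reuse the computed linkage
row_linkage = g.dendrogram_row.linkage
g2 = sns.clustermap(data, row_linkage=row_linkage, col_cluster=False,
                    cmap="viridis", figsize=(5, 10))
```

Because g2 is built from the same linkage matrix, its rows come out in the same order as in g.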
I strongly prefer using matplotlib in OOP style:
f, axarr = plt.subplots(2, sharex=True)
axarr[0].plot(...)
axarr[1].plot(...)
This makes it easier to keep track of multiple figures and subplots.
Question: How to use seaborn this way? Or, how to change this example to OOP style? How to tell seaborn plotting functions like lmplot which Figure or Axes it plots to?
It depends a bit on which seaborn function you are using.
The plotting functions in seaborn are broadly divided into two classes:
"Axes-level" functions, including regplot, boxplot, kdeplot, and many others
"Figure-level" functions, including relplot, catplot, displot, pairplot, jointplot and one or two others
The first group is identified by taking an explicit ax argument and returning an Axes object. As this suggests, you can use them in an "object oriented" style by passing your Axes to them:
f, (ax1, ax2) = plt.subplots(2)
# x and y are 1-D arrays of data; recent seaborn versions expect keyword arguments
sns.regplot(x=x, y=y, ax=ax1)
sns.kdeplot(x=x, ax=ax2)
Axes-level functions will only draw onto an Axes and won't otherwise mess with the figure, so they can coexist perfectly happily in an object-oriented matplotlib script.
The second group of functions (Figure-level) are distinguished by the fact that the resulting plot can potentially include several Axes which are always organized in a "meaningful" way. That means that the functions need to have total control over the figure, so it isn't possible to plot, say, an lmplot onto one that already exists. Calling the function always initializes a figure and sets it up for the specific plot it's drawing.
However, once you've called lmplot, it will return an object of the type FacetGrid. This object has some methods for operating on the resulting plot that know a bit about the structure of the plot. It also exposes the underlying figure and array of axes through the FacetGrid.fig and FacetGrid.axes attributes. The jointplot function is very similar, but it uses a JointGrid object. So you can still use these functions in an object-oriented context, but all of your customization has to come after you've called the function.
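The "customize after the call" workflow can be sketched as follows, assuming a made-up DataFrame with columns x, y and group (all hypothetical names):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data (assumption): two groups of 50 points each
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "x": rng.normal(size=100),
    "y": rng.normal(size=100),
    "group": np.repeat(["a", "b"], 50),
})

# Figure-level function: creates its own figure and returns a FacetGrid
g = sns.lmplot(data=df, x="x", y="y", col="group")

# Customization happens afterwards, through the grid's figure and axes
g.fig.suptitle("lmplot, customized after the call")
for ax in g.axes.flat:
    ax.set_xlim(-4, 4)
```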
When we create a new figure:
import matplotlib.pyplot as plt
fig = plt.figure()
we can see in the console an output like:
<Figure size 432x288 with 0 Axes>
so when we use add_axes:
add_axes(rect, projection=None, polar=False, **kwargs)
are we actually defining the x,y axes that encompass the "box" that will bound the figure (the axes in a more mathematical sense) and nothing more? Or is this line of code in fact creating an empty figure with the desired dimensions, into which any data we add later will be fitted? (Or none of the above, maybe?)
That question left me wondering how I can physically understand what the axes are for matplotlib.
Thanks for helping.
From the point of view of the viewer, the axes is the box which will contain the data and which (usually) has an x-axis and y-axis.
From the programmatic point of view, the axes is an object that stores several other objects, such as XAxis and YAxis, and provides methods to create plots. Importantly, it stores a transformation that maps data coordinates to pixel space, which is what allows the data points to be drawn.
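Both points of view can be inspected directly. The sketch below assumes a 6x4 inch figure at 100 dpi and the non-interactive Agg backend, so that the pixel numbers are predictable.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# 6 x 4 inches at 100 dpi gives a 600 x 400 pixel canvas
fig = plt.figure(figsize=(6, 4), dpi=100)

# rect = [left, bottom, width, height] in figure fractions
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# The Axes stores an XAxis and a YAxis object
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)

# transData maps data coordinates to pixel (display) coordinates
center_px = ax.transData.transform((0.5, 0.5))
print(center_px)
```

The data point (0.5, 0.5) sits in the middle of the box, so with this centered rect it maps to the middle of the canvas.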
I'm not really new to matplotlib, and I'm deeply ashamed to admit I have always used it as a tool for getting a solution as quickly and easily as possible. So I know how to get basic plots, subplots and stuff, and have quite a bit of code which gets reused from time to time... but I have no "deep(er) knowledge" of matplotlib.
Recently I thought I should change this and work my way through some tutorials. However, I am still confused about matplotlib's plt, fig(ure) and ax(arr). What is really the difference?
In most cases, for some "quick'n'dirty" plotting, I see people using just pyplot as plt and plotting directly with plt.plot. Since I often have multiple things to plot, I frequently use f, axarr = plt.subplots()... but most of the time you only see code putting data into the axarr and ignoring the figure f.
So, my question is: what is a clean way to work with matplotlib? When should I use plt only, and what is or what should a figure be used for? Should subplots just contain data? Or is it valid and good practice to do everything, like styling, clearing a plot, ..., inside of subplots?
I hope this is not too wide-ranging. Basically I am asking for some advice on the true purposes of plt <-> fig <-> ax(arr) (and when/how to use them properly).
Tutorials would also be welcome. The matplotlib documentation is rather confusing to me. When one searches for something really specific, like rescaling a legend, different plot markers and colors and so on, the official documentation is really precise, but the more general information is not that good in my opinion. Too many different examples, no real explanations of the purposes... it looks more or less like a big listing of all possible API methods and arguments.
pyplot is the 'scripting' level API in matplotlib (its highest-level API for doing a lot with matplotlib). It lets you use matplotlib through a procedural interface, in much the same way as you would in MATLAB. pyplot has a notion of 'current figure' and 'current axes' that all the functions delegate to (as @tacaswell puts it). So, when you use the functions available in the pyplot module, you are plotting to the 'current figure' and 'current axes'.
If you want fine-grained control of where and what you are plotting, then you should use the object-oriented API, using instances of Figure and Axes.
Most functions available in pyplot have an equivalent method on the Axes.
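A small sketch of that delegation: after plt.subplots(), the new figure and axes are the "current" ones, so a pyplot call and the corresponding Axes method end up drawing on the same object.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot style: implicitly targets the current figure and current axes
plt.plot([0, 1, 2], [0, 1, 4])
plt.title("via pyplot")

# object-oriented style: explicitly targets this Axes
ax.plot([0, 1, 2], [4, 1, 0])
ax.set_xlabel("via Axes")

# Both lines landed on the same Axes
print(plt.gcf() is fig, plt.gca() is ax, len(ax.lines))
```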
From the repo anatomy of matplotlib:
The Figure is the top-level container in this hierarchy. It is the overall window/page that everything is drawn on. You can have multiple independent figures and Figures can contain multiple Axes.
But...
Most plotting occurs on an Axes. The axes is effectively the area that we plot data on and any ticks/labels/etc associated with it. Usually we'll set up an Axes with a call to subplot (which places Axes on a regular grid), so in most cases, Axes and Subplot are synonymous.
Each Axes has an XAxis and a YAxis. These contain the ticks, tick locations, labels, etc.
If you want to know the anatomy of a plot you can visit this link.
I think that this tutorial explains well the basic notions of the object hierarchy of matplotlib like Figure and Axes, as well as the notion of current figure and current Axes.
If you want a quick answer: the Figure object is the container that wraps multiple Axes (which are different from an axis), and these in turn contain smaller objects like legends, lines, tick marks, and so on, as shown in this image taken from the matplotlib documentation.
So when we do
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> type(fig)
<class 'matplotlib.figure.Figure'>
>>> type(ax)
<class 'matplotlib.axes._subplots.AxesSubplot'>
We have created a Figure object and an Axes object that is contained in that figure.
pyplot is a MATLAB-like API for those who are familiar with MATLAB and want to make quick and dirty plots
figure is the object-oriented API for those who don't care about MATLAB-style plotting
So you can use either one but perhaps not both together.
Given a matplotlib.axes._subplots.AxesSubplot object, how do I tell what type of plot it contains? Is there a matplotlib feature that will determine this for me? For example...
I commonly plot data with pandas
import pandas as pd
df = pd.DataFrame({'y':range(10)})
line_ax = df.plot()
or
bar_ax = df.plot(kind='bar')
or
barh_ax = df.plot(kind='barh')
The matplotlib axes does not care about which plot it contains and it does not even know about it.
The question would also be how to distinguish "kinds" of plots. What kind of plot is in an axes which contains 2 bars, several markers, 2 lines and 3 arrows?
The kind argument to pandas plot function is simply a flag by which pandas decides which plotting function to call. This is independent of the axes and you may of course also have a plot produced by kind='bar' and kind='scatter' in the same axes.
So the answer is: No there is no general way to determine the kind of plot in an axes, mainly due to the fact that there is no such thing as a "kind of plot".
Of course, depending on what you'd need this type of information for, there are probably alternative ways to accomplish what you need.
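One such alternative, offered only as a heuristic sketch, is to look at which artists an axes holds: lines live in ax.lines, bar rectangles in ax.patches, and scatter points in ax.collections. The helper name summarize_artists is hypothetical, and the example uses plain matplotlib calls rather than pandas.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def summarize_artists(ax):
    """Count the artists on an Axes by container (hypothetical helper)."""
    return {
        "lines": len(ax.lines),              # Line2D objects, e.g. from plot()
        "patches": len(ax.patches),          # Rectangles etc., e.g. from bar()
        "collections": len(ax.collections),  # e.g. from scatter()
    }


fig, (line_ax, bar_ax) = plt.subplots(1, 2)
line_ax.plot(range(10))
bar_ax.bar(range(10), range(10))

print(summarize_artists(line_ax))
print(summarize_artists(bar_ax))
```

This only reports what was drawn. An axes holding bars, lines and markers at once still has no single "kind".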
I'd like to create two independent matplotlib plots within a python script, and potentially jump back and forth between them as I add lines, annotations, etc. to the various plots (for example, perhaps I call a function which adds lines to both plots, and then another function which adds the annotations).
I expect that by working off matplotlib examples I'd be able to figure out some solution that works, but I'd like to know what the preferred and cleanest way of doing this is. I tend to get confused about when I should be doing things like
fig,ax=plt.subplots()
and when I should be doing things like:
fig=plt.figure()
Furthermore, how should I be switching back and forth between plots. If I did something like
fig1,ax1=plt.subplots()
fig2,ax2=plt.subplots()
can I then just refer to these plots by doing something like:
ax1.plt.plot([some stuff])
ax2.plt.plot([otherstuff])
? I ask this because often in the matplotlib examples they don't refer to the plot like this after calling plt.subplot() but instead call commands like
plt.plot([stuff])
where presumably it doesn't matter that they didn't specify ax1 or ax2 because there's only one plot in the example. At the end I'd like to save both plots to file using something like
plt.savefig(....)
although I need, again, to be able to refer to both plots independently. So what's the proper way of implementing this?
If you want to be able to write code that clearly applies commands to distinct axes, you want to use the object oriented interface.
Actually, both of your first two examples are using this interface. The difference is that plt.subplots() will create both a figure object and a grid of axes, while plt.figure() just creates the figure.
The figure object has methods to create axes within it. So, these two blocks of code are equivalent:
fig, ax = plt.subplots()
and
fig = plt.figure()
ax = fig.add_subplot(111)
Generally, the latter approach is only going to be more useful when you want multiple axes within the figure that don't follow a regular grid. So, you could do:
fig = plt.figure()
ax = fig.add_axes([.1, .1, .2, .8])
which will add a tall axes object on the left side of the figure.
Next, how do you get multiple axes to plot on?
The subplots function takes two positional arguments specifying the number of rows and columns in the grid (these default to (1, 1)). So if you want two axes side by side, you would do
fig, axes = plt.subplots(1, 2)
Now axes is going to be a 1-D object array of shape (2,) that is filled with Axes objects (with the default squeeze=True, the length-1 row dimension is dropped). It's often more convenient, for a small grid, to use Python's tuple unpacking and get direct references to the objects:
fig, (ax1, ax2) = plt.subplots(1, 2)
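To see what plt.subplots hands back for different grids, a quick sketch (the numbered variable names are only for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# One row, two columns: squeezed to a 1-D array of two Axes
fig, axes = plt.subplots(1, 2)
print(axes.shape)

# A full grid stays 2-D
fig2, axes2 = plt.subplots(2, 2)
print(axes2.shape)

# squeeze=False always returns a 2-D array, whatever the grid size
fig3, axes3 = plt.subplots(1, 2, squeeze=False)
print(axes3.shape)

# Tuple unpacking gives direct references to the individual Axes
fig4, (ax1, ax2) = plt.subplots(1, 2)
ax1.plot([0, 1], [0, 1])
ax2.plot([0, 1], [1, 0])
```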
Now, what do you do with these objects, and what's the relationship between them and the MATLAB-style procedure interface?
Most functions in the pyplot namespace also exist as methods on either Figure or Axes objects. Matplotlib (and MATLAB) has a concept of the "current" figure and axes. When you call a function like plt.plot, it draws on the current axes (usually the most recently created one). When you call a function like plt.savefig, it saves the current figure.
For simple tasks, this is a bit more direct and usually easier than using the object-oriented interface. However, when you start making more complex plots, e.g. a grid of axes where each one has multiple layers (maybe a scatterplot and a regression line), being able to structure the code around what you are doing rather than where you are doing it has substantial advantages. Generally, plotting code that is written in an object-oriented fashion will scale much better than code written in a procedural fashion.
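Applied to the original question, a sketch with two independent figures could look like this. The helper names add_lines and add_annotations are hypothetical, and the figures are saved into in-memory buffers here only to keep the example self-contained; in a script you would pass a file name to savefig instead.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()


def add_lines(ax_a, ax_b):
    """Hypothetical helper: draw one line on each plot."""
    ax_a.plot([0, 1, 2], [0, 1, 4])
    ax_b.plot([0, 1, 2], [4, 1, 0])


def add_annotations(ax_a, ax_b):
    """Hypothetical helper: annotate a point on each plot."""
    ax_a.annotate("rising", xy=(1, 1))
    ax_b.annotate("falling", xy=(1, 1))


# Jump back and forth freely: each call names the Axes it works on
add_lines(ax1, ax2)
add_annotations(ax1, ax2)

# Save each figure through its own savefig method, not plt.savefig
buf1, buf2 = io.BytesIO(), io.BytesIO()
fig1.savefig(buf1, format="png")
fig2.savefig(buf2, format="png")
```

Note that the plotting method is called directly on the Axes (ax1.plot), not as ax1.plt.plot.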