I would like to plot two dfs with two different colors. For each df, I would need to add two markers. Here is what I have tried:
for stats_file in stats_files:
data = Graph(stats_file)
Graph.compute(data)
data.servers_df.plot(x="time", y="percentage", linewidth=1, kind='line')
plt.plot(data.first_measurement['time'], data.first_measurement['percentage'], 'o-', color='orange')
plt.plot(data.second_measurement['time'], data.second_measurement['percentage'], 'o-', color='green')
plt.show()
Using this piece of code, I get the servers_df plotted with markers, but on separate graphs.
How I can have both graphs in a single one to compare them better?
Thanks.
TL;DR
Your call to data.servers_df.plot() always creates a new plot, and plt.plot() plots on the latest plot that was created. The solution is to create one dedicated Axes object and plot everything onto it.
Preface
I assumed your variables are the following
data.servers_df: Dataframe with two float columns "time" and "percentage"
data.first_measurement: A dictionary with keys "time" and "percentage", each of which is a list of floats
data.second_measurement: A dictionary with keys "time" and "percentage", each of which is a list of floats
I skipped generating stat_files as you did not show what Graph() does, and just created a list of dummy data instead.
If data.first_measurement and data.second_measurement are also dataframes, let me know and there is an even nicer solution.
Theory - Behind the curtains
Each matplotlib plot (line, bar, etc.) lives on a matplotlib.axes.Axes element. These are like regular axes of a coordinate system. Now two things happen here:
When you use plt.plot(), no axes are specified, so matplotlib looks up the current axes element (in the background). If there is none, it creates an empty one, uses it, and sets it as the default. The second call to plt.plot() then finds these axes and uses them.
DataFrame.plot() on the other hand, always creates a new axes element if none is given to it (possible through the ax argument)
So in your code, data.servers_df.plot() first creates an axes element behind the curtains (which is then the default), and the two following plt.plot() calls get the default axes and plot onto it - which is why you get two plots instead of one.
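The difference between the two behaviours can be checked by counting open figures. This is only a minimal sketch: the DataFrame `df` and its values are placeholders I made up, not your data.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere; omit for interactive use
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({"time": range(5), "percentage": [1, 3, 2, 5, 4]})  # placeholder data

# Without ax=..., every DataFrame.plot() call opens its own figure.
df.plot(x="time", y="percentage")
df.plot(x="time", y="percentage")
figures_without_ax = len(plt.get_fignums())

plt.close("all")

# With a shared Axes passed in, both lines land in the same figure.
fig, ax = plt.subplots()
returned_ax = df.plot(x="time", y="percentage", ax=ax)
df.plot(x="time", y="percentage", ax=ax)
figures_with_ax = len(plt.get_fignums())

print(figures_without_ax, figures_with_ax, returned_ax is ax)
```

The Axes returned by DataFrame.plot() is the same object that was passed in through ax, so either handle can be used for further ax.plot() calls.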
Solution
The following solution first creates a dedicated matplotlib.axes.Axes using plt.subplots(). This axes element is then used to draw all lines onto. Note especially the ax=ax in data.servers_df.plot(). Note also that I changed the display of your markers from o- to o (as we don't want to display a line (-) but only markers (o)).
Mock data can be found below
fig, ax = plt.subplots()  # Here we create the axes that all data will plot onto
for i, data in enumerate(stat_files):
    y_column = f'percentage_{i}'  # Make the columns identifiable
    data.servers_df \
        .rename(columns={'percentage': y_column}) \
        .plot(x='time', y=y_column, linewidth=1, kind='line', ax=ax)
    ax.plot(data.first_measurement['time'], data.first_measurement['percentage'], 'o', color='orange')
    ax.plot(data.second_measurement['time'], data.second_measurement['percentage'], 'o', color='green')
plt.show()
Mock data
import random
import pandas as pd
import matplotlib.pyplot as plt
# Generation of dummy data
random.seed(1)
NUMBER_OF_DATA_FILES = 2
X_LENGTH = 10
class Data:
    def __init__(self):
        self.servers_df = pd.DataFrame(
            {
                'time': range(X_LENGTH),
                'percentage': [random.randint(0, 10) for _ in range(X_LENGTH)]
            }
        )
        self.first_measurement = {
            'time': self.servers_df['time'].values[:X_LENGTH // 2],
            'percentage': self.servers_df['percentage'].values[:X_LENGTH // 2]
        }
        self.second_measurement = {
            'time': self.servers_df['time'].values[X_LENGTH // 2:],
            'percentage': self.servers_df['percentage'].values[X_LENGTH // 2:]
        }
stat_files = [Data() for _ in range(NUMBER_OF_DATA_FILES)]
DataFrame.plot() by default returns a matplotlib.axes.Axes object. You should then plot the other two plots on this object:
for stats_file in stats_files:
data = Graph(stats_file)
Graph.compute(data)
ax = data.servers_df.plot(x="time", y="percentage", linewidth=1, kind='line')
ax.plot(data.first_measurement['time'], data.first_measurement['percentage'], 'o-', color='orange')
ax.plot(data.second_measurement['time'], data.second_measurement['percentage'], 'o-', color='green')
plt.show()
If you want to plot them one on top of the others with different colors you can do something like this:
colors = ['C0', 'C1', 'C2']  # matplotlib default color palette
# assuming that len(stats_files) = 3
# if not you need to specify as many colors as necessary
ax = plt.subplot(111)
for stats_file, c in zip(stats_files, colors):
    data = Graph(stats_file)
    Graph.compute(data)
    # color=c gives each servers_df line its own color
    data.servers_df.plot(x="time", y="percentage", linewidth=1, kind='line', ax=ax, color=c)
    ax.plot(data.first_measurement['time'], data.first_measurement['percentage'], 'o-', color='orange')
    ax.plot(data.second_measurement['time'], data.second_measurement['percentage'], 'o-', color='green')
plt.show()
This just changes the color of the servers_df.plot. If you want to change the color of the other two, you can apply the same logic: create a list of the colors you want them to take at each iteration, iterate over that list, and pass the color value to the color param at each iteration.
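As a rough illustration of that per-iteration color logic, the sketch below zips one color list for the lines and one for the markers. The frames, the color names, and the choice to mark only the first and last point are all assumptions for the example, standing in for data.servers_df and the measurement dictionaries.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import to_rgba

# Placeholder frames standing in for each data.servers_df
frames = [
    pd.DataFrame({"time": range(5), "percentage": [1, 3, 2, 5, 4]}),
    pd.DataFrame({"time": range(5), "percentage": [4, 2, 5, 1, 3]}),
]
line_colors = ["navy", "darkviolet"]   # one color per dataframe line
marker_colors = ["orange", "green"]    # one color per set of markers

fig, ax = plt.subplots()
for df, line_color, marker_color in zip(frames, line_colors, marker_colors):
    df.plot(x="time", y="percentage", linewidth=1, ax=ax, color=line_color, legend=False)
    ends = df.iloc[[0, -1]]  # hypothetical "measurements": first and last row
    ax.plot(ends["time"], ends["percentage"], "o", color=marker_color)

# Colors of the drawn artists, in drawing order
drawn = [to_rgba(line.get_color()) for line in ax.lines]
```

Call plt.show() at the end to display the figure.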
You can create an Axes object for plotting in the first place, for example
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
df_one = pd.DataFrame({'a':np.linspace(1,10,10),'b':np.linspace(1,10,10)})
df_two = pd.DataFrame({'a':np.random.randint(0,20,10),'b':np.random.randint(0,5,10)})
dfs = [df_one,df_two]
fig,ax = plt.subplots(figsize=(8,6))
colors = ['navy','darkviolet']
markers = ['x','o']
for ind,item in enumerate(dfs):
ax.plot(item['a'],item['b'],c=colors[ind],marker=markers[ind])
As you can see, the two dataframes are plotted on the same ax with different colors and markers.
You need to create the plot before.
Afterwards, you can explicitly refer to this plot while plotting the graphs.
df.plot(..., ax=ax) or ax.plot(x, y)
import matplotlib.pyplot as plt
(fig, ax) = plt.subplots(figsize=(20,5))
for stats_file in stats_files:
    data = Graph(stats_file)
    Graph.compute(data)
    data.servers_df.plot(x="time", y="percentage", linewidth=1, kind='line', ax=ax)
    ax.plot(data.first_measurement['time'], data.first_measurement['percentage'], 'o-', color='orange')
    ax.plot(data.second_measurement['time'], data.second_measurement['percentage'], 'o-', color='green')
plt.show()
Related
I am plotting the same type of information, but for different countries, with multiple subplots with Matplotlib. That is, I have nine plots on a 3x3 grid, all with the same for lines (of course, different values per line).
However, I have not figured out how to put a single legend (since all nine subplots have the same lines) on the figure just once.
How do I do that?
There is also a nice function get_legend_handles_labels() you can call on the last axis (if you iterate over them) that would collect everything you need from label= arguments:
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center')
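A minimal sketch of this approach on a 3x3 grid, assuming every subplot draws the same labelled lines (the curves and label names here are made up for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 20)
fig, axes = plt.subplots(3, 3, sharex=True, sharey=True)
for ax in axes.flat:
    ax.plot(x, x, label="linear")
    ax.plot(x, x**2, label="quadratic")

# All subplots carry the same labels, so the last Axes is enough.
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(handles, labels, loc="upper center", ncol=2)
```

Because no ax.legend() call is made, the individual subplots stay free of legends and only the figure-level one is drawn.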
figlegend may be what you're looking for: matplotlib.pyplot.figlegend
An example is at Figure legend demo.
Another example:
plt.figlegend(lines, labels, loc = 'lower center', ncol=5, labelspacing=0.)
Or:
fig.legend(lines, labels, loc = (0.5, 0), ncol=5)
TL;DR
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
fig.legend(lines, labels)
I have noticed that none of the other answers displays an image with a single legend referencing many curves in different subplots, so I have to show you one... to make you curious...
Now, if I've teased you enough, here is the code
from numpy import linspace
import matplotlib.pyplot as plt
# each Axes has a brand new prop_cycle, so to have differently
# colored curves in different Axes, we need our own prop_cycle
# Note: we CALL the axes.prop_cycle to get an itertools.cycle
color_cycle = plt.rcParams['axes.prop_cycle']()
# I need some curves to plot
x = linspace(0, 1, 51)
functs = [x*(1-x), x**2*(1-x),
0.25-x*(1-x), 0.25-x**2*(1-x)]
labels = ['$x-x²$', '$x²-x³$',
'$\\frac{1}{4} - (x-x²)$', '$\\frac{1}{4} - (x²-x³)$']
# the plot,
fig, (a1,a2) = plt.subplots(2)
for ax, f, l, cc in zip((a1,a1,a2,a2), functs, labels, color_cycle):
    ax.plot(x, f, label=l, **cc)
    ax.set_aspect(2) # superfluous, but nice
# So far, nothing special except the managed prop_cycle. Now the trick:
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
# Finally, the legend (that maybe you'll customize differently)
fig.legend(lines, labels, loc='upper center', ncol=4)
plt.show()
If you want to stick with the official Matplotlib API, this is
perfect, otherwise see note no.1 below (there is a private
method...)
The two lines
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
deserve an explanation, see note 2 below.
I tried the method proposed by the most up-voted and accepted answer,
# fig.legend(lines, labels, loc='upper center', ncol=4)
fig.legend(*a2.get_legend_handles_labels(),
loc='upper center', ncol=4)
and this is what I've got
Note 1
If you don't mind using a private method of the matplotlib.legend module ... it's really much much much easier
from matplotlib.legend import _get_legend_handles_labels
...
fig.legend(*_get_legend_handles_labels(fig.axes), ...)
Note 2
I have encapsulated the two tricky lines in a function, just four lines of code, but heavily commented
def fig_legend(fig, **kwdargs):
    # Generate a sequence of tuples, each contains
    # - a list of handles (lohand) and
    # - a list of labels (lolbl)
    tuples_lohand_lolbl = (ax.get_legend_handles_labels() for ax in fig.axes)
    # E.g., a figure with two axes, ax0 with two curves, ax1 with one curve
    # yields: ([ax0h0, ax0h1], [ax0l0, ax0l1]) and ([ax1h0], [ax1l0])
    # The legend needs a list of handles and a list of labels,
    # so our first step is to transpose our data,
    # generating two tuples of lists of homogeneous stuff (tolohs), i.e.,
    # we yield ([ax0h0, ax0h1], [ax1h0]) and ([ax0l0, ax0l1], [ax1l0])
    tolohs = zip(*tuples_lohand_lolbl)
    # Finally, we need to concatenate the individual lists in the two
    # lists of lists: [ax0h0, ax0h1, ax1h0] and [ax0l0, ax0l1, ax1l0]
    # a possible solution is to sum the sublists - we use unpacking
    handles, labels = (sum(list_of_lists, []) for list_of_lists in tolohs)
    # Call fig.legend with the keyword arguments, return the legend object
    return fig.legend(handles, labels, **kwdargs)
I recognize that sum(list_of_lists, []) is a really inefficient method to flatten a list of lists, but ① I love its compactness, ② there are usually only a few curves in a few subplots and ③ Matplotlib and efficiency? ;-)
For the automatic positioning of a single legend in a figure with many axes, like those obtained with subplots(), the following solution works really well:
plt.legend(lines, labels, loc = 'lower center', bbox_to_anchor = (0, -0.1, 1, 1),
bbox_transform = plt.gcf().transFigure)
With bbox_to_anchor and bbox_transform=plt.gcf().transFigure, you are defining a new bounding box the size of your figure to serve as the reference for loc. Using (0, -0.1, 1, 1) moves this bounding box slightly downwards to prevent the legend from being placed over other artists.
OBS: Use this solution after you use fig.set_size_inches() and before you use fig.tight_layout()
You just have to ask for the legend once, outside of your loop.
For example, in this case I have 4 subplots, with the same lines, and a single legend.
from matplotlib.pyplot import *
ficheiros = ['120318.nc', '120319.nc', '120320.nc', '120321.nc']
fig = figure()
fig.suptitle('concentration profile analysis')
for a in range(len(ficheiros)):
    # dados is here defined
    level = dados.variables['level'][:]
    ax = fig.add_subplot(2,2,a+1)
    xticks(range(8), ['0h','3h','6h','9h','12h','15h','18h','21h'])
    ax.set_xlabel('time (hours)')
    # raw string so that \mu is passed to mathtext unchanged
    ax.set_ylabel(r'CONC ($\mu g. m^{-3}$)')
    for index in range(len(level)):
        conc = dados.variables['CONC'][4:12,index] * 1e9
        ax.plot(conc,label=str(level[index])+'m')
    dados.close()
# called once, after the loop:
ax.legend(bbox_to_anchor=(1.05, 0), loc='lower left', borderaxespad=0.)
# it will place the legend on the outer right-hand side of the last axes
show()
If you are using subplots with bar charts, with a different colour for each bar, it may be faster to create the artefacts yourself using mpatches.
Say you have four bars with different colours as r, m, c, and k, you can set the legend as follows:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
labels = ['Red Bar', 'Magenta Bar', 'Cyan Bar', 'Black Bar']
#####################################
# Insert code for the subplots here #
#####################################
# Now, create an artist for each color
red_patch = mpatches.Patch(facecolor='r', edgecolor='#000000') # This will create a red bar with black borders, you can leave out edgecolor if you do not want the borders
black_patch = mpatches.Patch(facecolor='k', edgecolor='#000000')
magenta_patch = mpatches.Patch(facecolor='m', edgecolor='#000000')
cyan_patch = mpatches.Patch(facecolor='c', edgecolor='#000000')
fig.legend(handles = [red_patch, magenta_patch, cyan_patch, black_patch], labels=labels,
loc="center right",
borderaxespad=0.1)
plt.subplots_adjust(right=0.85) # Adjust the subplot to the right for the legend
To build on top of gboffi's and Ben Usman's answer:
In a situation where one has different lines in different subplots with the same color and label, one can do something along the lines of:
labels_handles = {
label: handle for ax in fig.axes for handle, label in zip(*ax.get_legend_handles_labels())
}
fig.legend(
labels_handles.values(),
labels_handles.keys(),
loc = "upper center",
bbox_to_anchor = (0.5, 0),
bbox_transform = plt.gcf().transFigure,
)
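To see what the dictionary comprehension does, here is a small self-contained sketch. The two subplots and their labels are invented for the example; the point is that a label shared between subplots ("linear") ends up in the figure legend only once.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 20)
fig, axes = plt.subplots(1, 2)
axes[0].plot(x, x, color="C0", label="linear")
axes[0].plot(x, x**2, color="C1", label="quadratic")
axes[1].plot(x, 2 * x, color="C0", label="linear")  # same label and color again
axes[1].plot(x, x**3, color="C2", label="cubic")

# Keyed by label, so a repeated label keeps a single entry
# (the handle from the last subplot that uses it).
labels_handles = {
    label: handle
    for ax in fig.axes
    for handle, label in zip(*ax.get_legend_handles_labels())
}
legend = fig.legend(
    list(labels_handles.values()),
    list(labels_handles.keys()),
    loc="upper center",
    ncol=3,
)
```

This only gives a sensible legend when lines with the same label also share the same style, since just one handle per label survives.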
Using Matplotlib 2.2.2, this can be achieved using the gridspec feature.
In the example below, the aim is to have four subplots arranged in a 2x2 fashion with the legend shown at the bottom. A 'faux' axis is created at the bottom to place the legend in a fixed spot. The 'faux' axis is then turned off so only the legend shows. Result:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Gridspec demo
fig = plt.figure()
fig.set_size_inches(8, 9)
fig.set_dpi(100)
rows = 17 # The larger the number here, the smaller the spacing around the legend
start1 = 0
end1 = int((rows-1)/2)
start2 = end1
end2 = int(rows-1)
gspec = gridspec.GridSpec(ncols=4, nrows=rows)
axes = []
axes.append(fig.add_subplot(gspec[start1:end1, 0:2]))
axes.append(fig.add_subplot(gspec[start2:end2, 0:2]))
axes.append(fig.add_subplot(gspec[start1:end1, 2:4]))
axes.append(fig.add_subplot(gspec[start2:end2, 2:4]))
axes.append(fig.add_subplot(gspec[end2, 0:4]))
line, = axes[0].plot([0, 1], [0, 1], 'b') # Add some data
axes[-1].legend((line,), ('Test',), loc='center') # Create legend on bottommost axis
axes[-1].set_axis_off() # Don't show the bottom-most axis
fig.tight_layout()
plt.show()
This answer is a complement to user707650's answer on the legend position.
My first try on user707650's solution failed due to overlaps of the legend and the subplot's title.
In fact, the overlaps are caused by fig.tight_layout(), which changes the subplots' layout without considering the figure legend. However, fig.tight_layout() is necessary.
In order to avoid the overlaps, we can tell fig.tight_layout() to leave spaces for the figure's legend by fig.tight_layout(rect=(0,0,1,0.9)).
Description of tight_layout() parameters.
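A short sketch of the rect idea, with made-up curves and panel titles: the figure legend sits in the top strip, and tight_layout() is told to keep the subplots within the lower 90% of the figure.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 20)
fig, axes = plt.subplots(2, 2)
for ax in axes.flat:
    ax.plot(x, x, label="linear")
    ax.plot(x, x**2, label="quadratic")
    ax.set_title("panel")  # placeholder title

handles, labels = axes.flat[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=2)

# rect = (left, bottom, right, top) in figure coordinates:
# the subplots, including their titles, must fit below y = 0.9.
fig.tight_layout(rect=(0, 0, 1, 0.9))

highest_axes_top = max(ax.get_position().y1 for ax in axes.flat)
```

The value 0.9 is a guess that suits a one-row legend; a taller legend would need a smaller top value.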
All of the previous answers are way over my head, at this state of my coding journey, so I just added another Matplotlib aspect called patches:
import matplotlib.patches as mpatches
first_leg = mpatches.Patch(color='red', label='1st plot')
second_leg = mpatches.Patch(color='blue', label='2nd plot')
third_leg = mpatches.Patch(color='green', label='3rd plot')
plt.legend(handles=[first_leg, second_leg, third_leg])
The patches aspect put all the data I needed on my final plot (it was a line plot that combined three different line plots all in the same cell in Jupyter Notebook).
Result
(I changed the names from what I named my own legend.)
I am trying to figure out how to create a scatter plot in matplotlib with two different y-axes.
Right now I have one and need to add a second with the index column values on y.
points1 = plt.scatter(r3_load["TimeUTC"], r3_load["r3_load_MW"],
c=r3_load["r3_load_MW"], s=50, cmap="rainbow", alpha=1) #set style options
plt.rcParams['figure.figsize'] = [20,10]
#plt.colorbar(points)
plt.title("timeUTC vs Load")
#plt.xlim(0, 400)
#plt.ylim(0, 300)
plt.xlabel('timeUTC')
plt.ylabel('Load_MW')
cbar = plt.colorbar(points1)
cbar.set_label('Load')
The result I expect is like this:
So the second scatter set should be for TimeUTC vs index. Colors are not the subject ;) Also, in Excel the y-axes are on different sides, but that doesn't matter.
I appreciate your help! Thanks, Paulina
Continuing after the suggestions in the comments.
There are two ways of using matplotlib.
Via the matplotlib.pyplot interface, like you were doing in your original code snippet with plt.
The object-oriented way. This is the suggested way to use matplotlib, especially when you need more customisation like in your case. In your code, ax1 is an Axes instance.
From an Axes instance, you can plot your data using the Axes.plot and Axes.scatter methods, very similar to what you did through the pyplot interface. This means you can write an Axes.scatter call instead of .plot and use the same parameters as in your original code:
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.scatter(r3_load["TimeUTC"], r3_load["r3_load_MW"],
c=r3_load["r3_load_MW"], s=50, cmap="rainbow", alpha=1)
ax2.plot(r3_dda249["TimeUTC"], r3_dda249.index, c='b', linestyle='-')
ax1.set_xlabel('TimeUTC')
ax1.set_ylabel('r3_load_MW', color='g')
ax2.set_ylabel('index', color='b')
plt.show()
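Since r3_load and r3_dda249 are not shown in the question, here is the same twinx() pattern as a runnable sketch on synthetic data. The arrays `t`, `load`, and `idx` are placeholders I made up.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(10)                                     # stands in for TimeUTC
load = np.array([3, 5, 4, 8, 7, 9, 6, 10, 12, 11])    # stands in for r3_load_MW
idx = np.arange(10)                                   # stands in for the dataframe index

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()  # second Axes sharing the x-axis, with its own y-axis on the right

ax1.scatter(t, load, c=load, s=50, cmap="rainbow", alpha=1)
ax2.plot(t, idx, c="b", linestyle="-")

ax1.set_xlabel("TimeUTC")
ax1.set_ylabel("r3_load_MW", color="g")
ax2.set_ylabel("index", color="b")
```

Replacing ax2.plot with ax2.scatter gives a second scatter set instead of a line, with the same arguments as on ax1.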
I have a list of dataframes named merged_dfs that I am looping through to get the correlation and plot subplots of heatmap correlation matrix using seaborn.
I want to customize the colorbar tick labels, but I am having trouble figuring out how to do it with my example.
Currently, my colorbar scale values from top to bottom are
[1,0.5,0,-0.5,-1]
I want to keep these values, but change the tick labels to be
[1,0.5,0,0.5,1]
for my diverging color bar.
Here is the code and my attempt:
fig, ax = plt.subplots(nrows=6, ncols=2, figsize=(20,20))
for i, (title,merging) in enumerate (zip(new_name_data,merged_dfs)):
graph = merging.corr()
colormap = sns.diverging_palette(250, 250, as_cmap=True)
a = sns.heatmap(graph.abs(), cmap=colormap, vmin=-1,vmax=1,center=0,annot = graph, ax=ax.flat[i])
cbar = fig.colorbar(a)
cbar.set_ticklabels(["1","0.5","0","0.5","1"])
fig.delaxes(ax[5,1])
plt.show()
plt.close()
I keep getting this error:
AttributeError: 'AxesSubplot' object has no attribute 'get_array'
Several things are going wrong:
fig.colorbar(...) would create a new colorbar, by default appended to the last subplot that was created.
sns.heatmap returns an ax (indicates a subplot). This is very different to matplotlib functions, e.g. plt.imshow(), which would return the graphical element that was plotted.
You can suppress the heatmap's colorbar (cbar=False), and then create it newly with the parameters you want.
fig.colorbar(...) needs a parameter ax=... when the figure contains more than one subplot.
Instead of creating a new colorbar, you can add the colorbar parameters to sns.heatmap via cbar_kws=.... The colorbar itself can be found via ax.collections[0].colorbar. (ax.collections[0] is where matplotlib stored the graphical object that contains the heatmap.)
Looping with an explicit index is usually discouraged in Python. It's generally more readable, easier to maintain and less error-prone to include everything in the zip command.
As your vmin is now -1, taking the absolute value for the coloring seems to be a mistake.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
merged_dfs = [pd.DataFrame(data=np.random.rand(5, 7), columns=[*'ABCDEFG']) for _ in range(5)]
new_name_data = [f'Dataset {i + 1}' for i in range(len(merged_dfs))]
ticks = [-1, -0.5, 0, 0.5, 1]  # tick positions on the colorbar
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 7))
for title, merging, ax in zip(new_name_data, merged_dfs, axes.flat):
    graph = merging.corr()
    colormap = sns.diverging_palette(250, 250, as_cmap=True)
    sns.heatmap(graph, cmap=colormap, vmin=-1, vmax=1, center=0, annot=True, ax=ax, cbar_kws={'ticks': ticks})
    # keep the tick positions, but label them with their absolute values
    ax.collections[0].colorbar.set_ticklabels([abs(t) for t in ticks])
fig.delaxes(axes.flat[-1])
fig.tight_layout()
plt.show()
I have written the following minimal Python code in order to plot various functions of x on the same X-axis.
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
xlabel='$X$'; ylabel='$Y$'
### Set tick features
plt.tick_params(axis='both',which='major',width=2,length=10,labelsize=18)
plt.tick_params(axis='both',which='minor',width=2,length=5)
#plt.set_axis_bgcolor('grey') # Doesn't work if I uncomment!
lines = ["-","--","-.",":"]
Nlayer=4
f, axarr = plt.subplots(Nlayer, sharex=True)
for a in range(1,Nlayer+1):
X = np.linspace(0,10,100)
Y = X**a
index = a-1 + np.int((a-1)/Nlayer)
axarr[a-1].plot(X, Y, linewidth=2.0+index, color=cycle[a], linestyle = lines[index], label='Layer = {}'.format(a))
axarr[a-1].legend(loc='upper right', prop={'size':6})
#plt.legend()
# Axes labels
plt.xlabel(xlabel, fontsize=20)
plt.ylabel(ylabel, fontsize=20)
plt.show()
However, the plots don't join together on the X-axis and I failed to get a common Y-axis label. It actually labels for the last plot (see attached figure). I also get a blank plot additionally which I couldn't get rid of.
I am using Python3.
The following code will produce the expected output :
without blank plot which was created because of the two plt.tick_params calls before creating the actual fig
with the gridspec_kw argument of subplots that allows you to control the space between rows and cols of subplots environment in order to join the different layer plots
with a unique, centered common ylabel using fig.text with relative positioning and the rotation argument (the same is done for the xlabel to get a homogeneous final result). Note that this can also be done by repositioning the ylabel with ax.yaxis.set_label_coords() after a usual call like ax.set_ylabel().
import numpy as np
import matplotlib.pyplot as plt
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
xlabel='$X$'; ylabel='$Y$'
lines = ["-","--","-.",":"]
Nlayer = 4
fig, axarr = plt.subplots(Nlayer, sharex='col',gridspec_kw={'hspace': 0, 'wspace': 0})
X = np.linspace(0,10,100)
for i,ax in enumerate(axarr):
    Y = X**(i+1)
    ax.plot(X, Y, linewidth=2.0+i, color=cycle[i], linestyle = lines[i], label='Layer = {}'.format(i+1))
    ax.legend(loc='upper right', prop={'size':6})
with axes labels, first option :
fig.text(0.5, 0.01, xlabel, va='center')
fig.text(0.01, 0.5, ylabel, va='center', rotation='vertical')
or alternatively :
# ax is here, the one of the last Nlayer plotted, i.e. Nlayer=4
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
# change y positioning to be in the horizontal center of all Nlayer, i.e. dynamically Nlayer/2
ax.yaxis.set_label_coords(-0.1,Nlayer/2)
which gives :
I also simplified your for loop by using enumerate to have an automatic counter i when looping over axarr.
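As a third option, not part of the answer above: Matplotlib 3.4 and later provide fig.supxlabel() and fig.supylabel(), which place figure-level axis labels without manual coordinates. A sketch, assuming such a version is installed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

Nlayer = 4
fig, axarr = plt.subplots(Nlayer, sharex="col", gridspec_kw={"hspace": 0})
X = np.linspace(0, 10, 100)
for i, ax in enumerate(axarr):
    ax.plot(X, X ** (i + 1), label="Layer = {}".format(i + 1))
    ax.legend(loc="upper right", prop={"size": 6})

# Figure-level labels (requires Matplotlib >= 3.4)
x_text = fig.supxlabel("$X$")
y_text = fig.supylabel("$Y$")

# With hspace=0 the bottom of one subplot coincides with the top of the next
gap = axarr[0].get_position().y0 - axarr[1].get_position().y1
```

Both calls return the Text object, so font size and position can still be adjusted afterwards.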
I would like to know why depending on how you call plt.plot() on an ax this ax may or may not be able to be modified downstream. Is this a bug in matplotlib or am I misunderstanding something? An example which illustrates the issue is shown below.
I am attempting to modify the legend location downstream of a plotting function call, similar to what is discussed here. For whatever reason this seems to depend on how I call ax.plot. Here are two examples illustrating the issue
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def myplot(df, ax, loop_cols):
    if loop_cols:
        for col in df.columns:
            ax.plot(df.loc[:, col])
    else:
        ax.plot(df)
    ax.legend(df.columns)
    return ax
This just amounts to calling ax.plot() repeatedly on the pd.Series, or calling it once on the pd.DataFrame. However depending on how this is called, it results in the inability to later modify the legend, as shown below.
df = pd.DataFrame(np.random.randn(100, 3)).cumsum()
fig = plt.figure()
ax = fig.gca()
myplot(df, ax, loop_cols=True)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
Chart legend properly set to right side
fig = plt.figure()
ax = fig.gca()
myplot(df, ax, loop_cols=False)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
Chart legend not set to appropriate location
This is with matplotlib version 2.1.0
I'm not sure why this behaviour is actually happening, but you are calling ax.legend twice, once in the function and once outside it. Altering your function such that the call to ax.legend() inside the function contains all the information solves the problem. This includes passing the legend handles and labels. In the example below I used ax.lines to get the Line2D objects; however, if your plots are more complicated you may have to get the list from the call to plot using lines = ax.plot()
If the properties of the legend change, then you could modify the function to accept parameters that are passed to ax.legend.
def myplot(df, ax, loop_cols):
    if loop_cols:
        for col in df.columns:
            ax.plot(df.loc[:, col])
    else:
        ax.plot(df)
    #ax.legend(df.columns) # modify this line
    ax.legend(ax.lines, df.columns, loc='center left', bbox_to_anchor=(1, 0.5))
    return ax
fig, (ax,ax2) = plt.subplots(1,2,figsize=(6,4))
df = pd.DataFrame(np.random.randn(100, 3)).cumsum()
myplot(df, ax, loop_cols=True)
ax.set_title("loop_cols=True")
#ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) # No need for this
myplot(df, ax2, loop_cols=False)
ax2.set_title("loop_cols=False")
#ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5)) # No need for this
plt.subplots_adjust(left=0.08,right=0.88,wspace=0.55)
plt.show()
Legends are not being modified in the above examples, new legends are being created in both cases. The issue relates to how legend() behaves differently depending on if there are labels on the matplotlib.lines.Line2Ds. The relevant section from the docs is
1. Automatic detection of elements to be shown in the legend
The elements to be added to the legend are automatically determined,
when you do not pass in any extra arguments.
In this case, the labels are taken from the artist. You can specify
them either at artist creation or by calling the
:meth:~.Artist.set_label method on the artist::
line, = ax.plot([1, 2, 3], label='Inline label')
ax.legend()
or::
line, = ax.plot([1, 2, 3])
line.set_label('Label via method')
ax.legend()
Specific lines can be excluded from the automatic legend element
selection by defining a label starting with an underscore. This is
default for all artists, so calling Axes.legend without any
arguments and without setting the labels manually will result in no
legend being drawn.
2. Labeling existing plot elements
To make a legend for lines which already exist on the axes (via plot
for instance), simply call this function with an iterable of strings,
one for each legend item. For example::
ax.plot([1, 2, 3])
ax.legend(['A simple line'])
Note: This way of using is discouraged, because the relation between
plot elements and labels is only implicit by their order and can
easily be mixed up.
In the first case, no labels are set
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame({'a': [1, 5, 3], 'b': [1, 3, -4]})
fig, axes = plt.subplots(1)
lines = axes.plot(df)
print(lines[0].get_label())
print(lines[1].get_label())
_line0
_line1
So calling legend() the first time with labels falls under case 2. When legend is called the second time without labels, it falls under case 1. As you can see, the Legend instances are different, and the second one rightly complains that there are No handles with labels found to put in legend.
l1 = axes.legend(['a', 'b'])
print(repr(l1))
<matplotlib.legend.Legend object at 0x7f05638b7748>
l2 = axes.legend(loc='upper left')
No handles with labels found to put in legend.
print(repr(l2))
<matplotlib.legend.Legend object at 0x7f05638004e0>
In the second case, the lines are properly labelled and therefore the second call to legend() properly infers the labels.
s1 = pd.Series([1, 2, 3], name='a')
s2 = pd.Series([1, 5, 2], name='b')
fig, axes = plt.subplots(1)
line1 = axes.plot(s1)
line2 = axes.plot(s2)
print(line1[0].get_label())
print(line2[0].get_label())
a
b
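Following from this, one way to make the DataFrame case behave like the Series case is to label the lines yourself right after plotting, so that any later ax.legend(...) call can rely on automatic detection. A minimal sketch (the data is the same toy frame as above; the loop over lines and columns is my own addition):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit for interactive use
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({"a": [1, 5, 3], "b": [1, 3, -4]})
fig, ax = plt.subplots()

lines = ax.plot(df)  # one Line2D per column, with underscore-prefixed default labels
auto_labels = [line.get_label() for line in lines]

# Attach the column names to the artists themselves
for line, column in zip(lines, df.columns):
    line.set_label(column)

# A later call without explicit labels now falls under case 1 and finds them
legend = ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
legend_texts = [text.get_text() for text in legend.get_texts()]
```

The exact default label text differs between Matplotlib versions, but it always starts with an underscore, which is what excludes the line from automatic detection.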