I'm currently evaluating different Python plotting libraries. Right now I'm trying matplotlib, and I'm quite disappointed with the performance. The following example is modified from the SciPy examples and gives me only ~8 frames per second!
Is there any way to speed this up, or should I pick a different plotting library?
from pylab import *
import time
ion()
fig = figure()
ax1 = fig.add_subplot(611)
ax2 = fig.add_subplot(612)
ax3 = fig.add_subplot(613)
ax4 = fig.add_subplot(614)
ax5 = fig.add_subplot(615)
ax6 = fig.add_subplot(616)
x = arange(0,2*pi,0.01)
y = sin(x)
line1, = ax1.plot(x, y, 'r-')
line2, = ax2.plot(x, y, 'g-')
line3, = ax3.plot(x, y, 'y-')
line4, = ax4.plot(x, y, 'm-')
line5, = ax5.plot(x, y, 'k-')
line6, = ax6.plot(x, y, 'p-')
# turn off interactive plotting - speeds things up by 1 Frame / second
ioff()  # pylab's star import does not define "plt", so call ioff() directly
tstart = time.time() # for profiling
for i in arange(1, 200):
    line1.set_ydata(sin(x+i/10.0)) # update the data
    line2.set_ydata(sin(2*x+i/10.0))
    line3.set_ydata(sin(3*x+i/10.0))
    line4.set_ydata(sin(4*x+i/10.0))
    line5.set_ydata(sin(5*x+i/10.0))
    line6.set_ydata(sin(6*x+i/10.0))
    draw() # redraw the canvas
print('FPS:', 200/(time.time()-tstart))
First off (though this won't change the performance at all), consider cleaning up your code, similar to this:
import matplotlib.pyplot as plt
import numpy as np
import time
x = np.arange(0, 2*np.pi, 0.01)
y = np.sin(x)
fig, axes = plt.subplots(nrows=6)
styles = ['r-', 'g-', 'y-', 'm-', 'k-', 'c-']
lines = [ax.plot(x, y, style)[0] for ax, style in zip(axes, styles)]
fig.show()
tstart = time.time()
for i in range(1, 20):
    for j, line in enumerate(lines, start=1):
        line.set_ydata(np.sin(j*x + i/10.0))
    fig.canvas.draw()
print('FPS:', 20/(time.time()-tstart))
With the above example, I get around 10fps.
Just a quick note: depending on your exact use case, matplotlib may not be a great choice. It's oriented towards publication-quality figures, not real-time display.
However, there are a lot of things you can do to speed this example up.
There are two main reasons why this is as slow as it is.
1) Calling fig.canvas.draw() redraws everything, and that is your bottleneck. In your case, you don't need to redraw things like the axes boundaries, tick labels, etc.
2) In your case, there are a lot of subplots with a lot of tick labels, and these take a long time to draw.
Both these can be fixed by using blitting.
To do blitting efficiently, you'll have to use backend-specific code. In practice, if you're really worried about smooth animations, you're usually embedding matplotlib plots in some sort of GUI toolkit anyway, so this isn't much of an issue.
However, without knowing a bit more about what you're doing, I can't help you there.
Nonetheless, there is a GUI-neutral way of doing it that is still reasonably fast.
import matplotlib.pyplot as plt
import numpy as np
import time
x = np.arange(0, 2*np.pi, 0.1)
y = np.sin(x)
fig, axes = plt.subplots(nrows=6)
fig.show()
# We need to draw the canvas before we start animating...
fig.canvas.draw()
styles = ['r-', 'g-', 'y-', 'm-', 'k-', 'c-']
def plot(ax, style):
    return ax.plot(x, y, style, animated=True)[0]
lines = [plot(ax, style) for ax, style in zip(axes, styles)]
# Let's capture the background of the figure
backgrounds = [fig.canvas.copy_from_bbox(ax.bbox) for ax in axes]
tstart = time.time()
for i in range(1, 2000):
    items = enumerate(zip(lines, axes, backgrounds), start=1)
    for j, (line, ax, background) in items:
        fig.canvas.restore_region(background)
        line.set_ydata(np.sin(j*x + i/10.0))
        ax.draw_artist(line)
        fig.canvas.blit(ax.bbox)
print('FPS:', 2000/(time.time()-tstart))
This gives me ~200fps.
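To check the difference between a full redraw and blitting on your own machine without opening a window, a rough benchmark might look like the following sketch. It assumes the non-interactive Agg backend, where `canvas.blit` does nothing, so it only times rendering; `time_full_vs_blit` is a hypothetical helper name, and the absolute numbers will differ from the ones quoted above.

```python
import time

import matplotlib
matplotlib.use("Agg")  # headless: only rendering is timed, nothing is shown
import matplotlib.pyplot as plt
import numpy as np


def time_full_vs_blit(n_frames=50):
    """Return (full_redraw_fps, blit_fps) for six subplots of sine curves."""
    x = np.arange(0, 2 * np.pi, 0.1)
    fig, axes = plt.subplots(nrows=6)
    lines = [ax.plot(x, np.sin(x))[0] for ax in axes]

    # 1) Full redraw: axes frames, ticks, labels and lines on every frame.
    t0 = time.perf_counter()
    for i in range(n_frames):
        for j, line in enumerate(lines, start=1):
            line.set_ydata(np.sin(j * x + i / 10.0))
        fig.canvas.draw()
    full_fps = n_frames / (time.perf_counter() - t0)

    # 2) Blitting: render the static parts once, cache them per axes,
    #    then only restore the cache and draw the lines on each frame.
    for line in lines:
        line.set_animated(True)  # keep the lines out of the cached background
    fig.canvas.draw()
    backgrounds = [fig.canvas.copy_from_bbox(ax.bbox) for ax in axes]
    t0 = time.perf_counter()
    for i in range(n_frames):
        for j, (line, ax, bg) in enumerate(zip(lines, axes, backgrounds), start=1):
            fig.canvas.restore_region(bg)
            line.set_ydata(np.sin(j * x + i / 10.0))
            ax.draw_artist(line)
            fig.canvas.blit(ax.bbox)  # no-op on Agg; pushes pixels on GUI backends
    blit_fps = n_frames / (time.perf_counter() - t0)

    plt.close(fig)
    return full_fps, blit_fps


if __name__ == "__main__":
    full, blit = time_full_vs_blit()
    print(f"full redraw: {full:.1f} fps, blitting: {blit:.1f} fps")
```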
To make this a bit more convenient, there's an animation module (matplotlib.animation) in recent versions of matplotlib.
As an example:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
x = np.arange(0, 2*np.pi, 0.1)
y = np.sin(x)
fig, axes = plt.subplots(nrows=6)
styles = ['r-', 'g-', 'y-', 'm-', 'k-', 'c-']
def plot(ax, style):
    return ax.plot(x, y, style, animated=True)[0]
lines = [plot(ax, style) for ax, style in zip(axes, styles)]
def animate(i):
    for j, line in enumerate(lines, start=1):
        line.set_ydata(np.sin(j*x + i/10.0))
    return lines
# We'd normally specify a reasonable "interval" here...
ani = animation.FuncAnimation(fig, animate, range(1, 200),
                              interval=0, blit=True)
plt.show()
Matplotlib makes great publication-quality graphics, but is not very well optimized for speed.
There are a variety of Python plotting packages that are designed with speed in mind:
http://vispy.org
http://pyqtgraph.org/
http://docs.enthought.com/chaco/
http://pyqwt.sourceforge.net/
[ edit: pyqwt is no longer maintained; the previous maintainer is recommending pyqtgraph ]
http://code.google.com/p/guiqwt/
To start, Joe Kington's answer provides very good advice using a GUI-neutral approach, and you should definitely take his advice (especially about blitting) and put it into practice. For more information on this approach, read the Matplotlib Cookbook.
However, the non-GUI-neutral (GUI-biased?) approach is key to speeding up the plotting. In other words, the backend is extremely important to plot speed.
Put these two lines before you import anything else from matplotlib:
import matplotlib
matplotlib.use('GTKAgg')
Of course, there are various options to use instead of GTKAgg, but according to the cookbook mentioned before, this was the fastest. See the link about backends for more options.
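Note that the GTK2-based GTKAgg backend is no longer available in current Matplotlib releases (GTK3Agg is its closest successor), so hard-coding one backend name can fail. A minimal sketch, assuming Matplotlib 3.5 or newer (for the "QtAgg" name) and that falling back to the non-interactive Agg backend is acceptable, tries a list of candidates in order; `pick_backend` and the candidate order are illustrative choices, not a recommendation from the cookbook.

```python
import matplotlib
import matplotlib.pyplot as plt


def pick_backend(candidates=("GTK3Agg", "QtAgg", "TkAgg")):
    """Switch pyplot to the first backend in `candidates` that loads.

    Falls back to the non-interactive "Agg" backend when none of them can be
    used, e.g. when the GUI toolkit is not installed or there is no display.
    """
    for name in candidates:
        try:
            plt.switch_backend(name)
            return name
        except (ImportError, ValueError):
            continue  # toolkit missing, no display, or unknown backend name
    plt.switch_backend("Agg")
    return "Agg"


if __name__ == "__main__":
    pick_backend()
    print("Using backend:", matplotlib.get_backend())
```

Call it before creating any figures, since switching backends closes all open figures.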
For the first solution proposed by Joe Kington (.copy_from_bbox, .draw_artist and canvas.blit), I had to capture the backgrounds after the fig.canvas.draw() line; otherwise the background had no effect and I got the same result as you mentioned. Putting it right after fig.show(), as Michael Browne proposed, still does not work.
So just put the background line after the canvas.draw():
[...]
fig.show()
# We need to draw the canvas before we start animating...
fig.canvas.draw()
# Let's capture the background of the figure
backgrounds = [fig.canvas.copy_from_bbox(ax.bbox) for ax in axes]
This may not apply to many of you, but I usually run Linux, so by default I save my matplotlib plots as both PNG and SVG. This works fine under Linux but is unbearably slow on my Windows 7 installations [MiKTeX under Python(x,y) or Anaconda], so I've taken to adding the following code, and things work fine over there again:
import platform # Don't save as SVG if running under Windows.
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
#
# Plot code goes here.
#
fig.savefig('figure_name.png', dpi = 200)
if platform.system() != 'Windows':
    # In my installations of Windows 7, it takes an inordinate amount of time to save
    # graphs as .svg files, so on that platform I've disabled the call that does so.
    # The first run of a script is still a little slow while everything is loaded in,
    # but execution times of subsequent runs are improved immensely.
    fig.savefig('figure_name.svg')
I am new to Python and trying to do what I have been doing in MATLAB for so long. My current challenge is to dynamically update a plot without drawing a new figure in a for or while loop. I am aware there are similar questions and answers, but most of them are too complicated, and I believe it should be easier.
I took the example from here:
https://pythonspot.com/matplotlib-update-plot/
But I can't see the figure: no error, nothing. I added two lines just to see whether I can see the static plot, and I can.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10*np.pi, 100)
y = np.sin(x)
# This is just a test just to see if I can see the plot window
plt.plot(x, y)
plt.show()
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)
line1, = ax.plot(x, y, 'b-')
for phase in np.linspace(0, 10*np.pi, 100):
    line1.set_ydata(np.sin(0.5 * x + phase))
    fig.canvas.draw()
Any idea why I can't see the dynamic plot?
Thank you
Erdem
Try adding plt.pause(0.0001) inside the loop, after calling plt.show(block=False), and a final plt.show() outside the loop. This should work fine with plt.ion(); see also the older answers to "Plot one figure at a time without closing old figure (matplotlib)".
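Putting those suggestions together, a corrected version of the loop from the question might look like this minimal sketch. It assumes an interactive backend (e.g. TkAgg or QtAgg) so that a window actually appears; the function name `run_animation` and its `pause_s` parameter are illustrative, not from the original question.

```python
import matplotlib.pyplot as plt
import numpy as np


def run_animation(n_frames=100, pause_s=0.0001):
    """Update one line in place and let the GUI process events every frame."""
    x = np.linspace(0, 10 * np.pi, 100)
    plt.ion()  # interactive mode: windows open without blocking
    fig, ax = plt.subplots()
    line1, = ax.plot(x, np.sin(x), 'b-')
    plt.show(block=False)  # open the window and return immediately
    for phase in np.linspace(0, 10 * np.pi, n_frames):
        line1.set_ydata(np.sin(0.5 * x + phase))
        # pause() redraws the stale figure and runs the GUI event loop briefly;
        # without it the window never gets a chance to repaint.
        plt.pause(pause_s)
    return x, line1


if __name__ == "__main__":
    run_animation()
    plt.ioff()
    plt.show()  # keep the final frame on screen until the window is closed
```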
I have a Python program that plots the data from a file as a contour plot for each line in that text file. Currently, I have 3 separate contour plots in my interface. Whether I read the data from the file or load it into memory before executing the script, I only get ~6 fps from the contour plots.
I also tried using just one contour plot and normal plots for the rest, but the speed only increased to 7 fps. I don't believe that it is so computationally taxing to draw a few lines. Is there a way to make it substantially faster? Ideally, it would be nice to get at least 30 fps.
The way I draw the contour is that for each line of my data I remove the previous one:
for coll in my_contour[0].collections:
    coll.remove()
and add a new one
my_contour[0] = ax[0].contour(x, y, my_func, [0])
At the beginning of the code, I have plt.ion() to update the plots as I add them.
Any help would be appreciated.
Thanks
Here is an example on how to use a contour plot in an animation. It uses matplotlib.animation.FuncAnimation which makes it easy to turn blitting on and off.
With blit=True it runs at ~64 fps on my machine, without blitting ~55 fps. Note that the interval must of course allow for the fast animation; setting it to interval=10 (milliseconds) would allow for up to 100 fps, but the drawing time limits it to something slower than that.
import matplotlib.pyplot as plt
import matplotlib.animation
from matplotlib.collections import Collection
import numpy as np
import time
x= np.linspace(0,3*np.pi)
X,Y = np.meshgrid(x,x)
f = lambda x,y, alpha, beta :(np.sin(x+alpha)+np.sin(y*(1+np.sin(beta)*.4)+alpha))**2
alpha=np.linspace(0, 2*np.pi, num=34)
levels= 10
cmap=plt.cm.magma
fig, ax=plt.subplots()
props = dict(boxstyle='round', facecolor='wheat')
timelabel = ax.text(0.9,0.9, "", transform=ax.transAxes, ha="right", bbox=props)
t = np.ones(10)*time.time()
p = [ax.contour(X,Y,f(X,Y,0,0), levels, cmap=cmap ) ]
def update(i):
    # Matplotlib >= 3.8: a ContourSet is itself a Collection and is removed
    # directly; older versions keep one collection per level in .collections.
    if isinstance(p[0], Collection):
        p[0].remove()
    else:
        for tp in p[0].collections:
            tp.remove()
    p[0] = ax.contour(X,Y,f(X,Y,alpha[i],alpha[i]), levels, cmap= cmap)
    # keep the last 10 frame timestamps to estimate the frame rate
    t[1:] = t[0:-1]
    t[0] = time.time()
    timelabel.set_text("{:.3f} fps".format(-1./np.diff(t).mean()))
    artists = [p[0]] if isinstance(p[0], Collection) else p[0].collections
    return artists+[timelabel]
ani = matplotlib.animation.FuncAnimation(fig, update, frames=len(alpha),
                                         interval=10, blit=True, repeat=True)
plt.show()
Note that the animated GIF above shows a slower frame rate, since saving the images takes a little longer.
I'm hoping to find a way to optimise the following situation. I have a large contour plot created with matplotlib's imshow. I then want to use this contour plot to create a large number of PNG images, where each image shows a small section of the contour image, obtained by changing the x and y limits and the aspect ratio.
So no plot data is changing in the loop; only the axis limits and the aspect ratio change between each PNG image.
The following MWE creates 70 PNG images in a "figs" folder, demonstrating the simplified idea. About 80% of the runtime is taken up by fig.savefig('figs/'+filename).
I've looked into the following without coming up with an improvement:
An alternative to matplotlib with a focus on speed: I've struggled to find any examples/documentation of contour/surface plots with similar requirements.
Multiprocessing: similar questions I've seen here appear to require fig = plt.figure() and ax.imshow to be called within the loop, since fig and ax can't be pickled. In my case this will be more expensive than any speed gains achieved by implementing multiprocessing.
I'd appreciate any insight or suggestions you might have.
import numpy as np
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import time, os
def make_plot(x, y, fig, ax):
    aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x,y)
    ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig('figs/'+filename)
if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)
fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
# in the real case, imshow is an expensive calculation which can't be put inside the loop
ax.imshow(data, interpolation='nearest')
tstart = time.perf_counter()  # time.clock() was removed in Python 3.8
for i in range(1, 8):
    for j in range(3, 13):
        make_plot(i, j, fig, ax)
print('took {:.2f} seconds'.format(time.perf_counter()-tstart))
Since the limitation in this case is the call to plt.savefig(), it cannot be optimized much. Internally, the figure is rendered from scratch on every call, and that takes a while. Reducing the number of vertices to be drawn might reduce the time a bit.
The time to run your code on my machine (Win 8, i5 with 4 cores at 3.5 GHz) is 2.5 seconds, which does not seem too bad. One can get a little improvement by using multiprocessing.
A note about multiprocessing: it may seem surprising that using pyplot's state machine inside multiprocessing works at all, but it does. Each worker process ends up with its own copy of the figure: with the fork start method it inherits the parent's figure, and with spawn (the default on Windows) the module-level code that creates the figure is re-run when the worker imports the script.
And in this case, since every image is based on the same figure and axes objects, one does not even have to create new figures and axes.
I modified an answer I gave here a while ago for your case, and the total time is roughly halved using multiprocessing with 5 processes on 4 cores. I appended a bar plot which shows the effect of multiprocessing.
import numpy as np
#import matplotlib as mpl
#mpl.use('agg') # use of agg seems to slow things down a bit
import matplotlib.pyplot as plt
import multiprocessing
import time, os
def make_plot(d):
    start = time.time()  # wall-clock time, comparable across processes
    x,y=d
    #using aspect in this way causes a warning for me
    #aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x,y)
    ax = plt.gca()
    #ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.savefig('figs/'+filename)
    stop = time.time()
    return np.array([x,y, start, stop])
if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)
fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
ax.imshow(data, interpolation='nearest')
some_list = []
for i in range(1, 8):
    for j in range(3, 13):
        some_list.append((i,j))
if __name__ == "__main__":
    multiprocessing.freeze_support()
    tstart = time.time()
    print(tstart)
    num_proc = 5
    with multiprocessing.Pool(num_proc) as p:
        nu = p.map(make_plot, some_list)
    tooktime = 'Plotting of {} frames took {:.2f} seconds'
    tooktime = tooktime.format(len(some_list), time.time()-tstart)
    print(tooktime)
    nu = np.array(nu)
    plt.close("all")
    fig, ax = plt.subplots(figsize=(8,5))
    plt.suptitle(tooktime)
    # bars start at each image's start time, measured from tstart
    ax.barh(np.arange(len(some_list)), nu[:,3]-nu[:,2],
            height=np.ones(len(some_list)), left=nu[:,2]-tstart, align="center")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("image number")
    ax.set_ylim([-1,70])
    plt.tight_layout()
    plt.savefig(__file__+".png")
    plt.show()
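If relying on module-level state feels fragile, a variant is to build the figure once per worker with the `initializer` argument of `multiprocessing.Pool`. Only the raw data array is pickled and sent to the workers, which sidesteps the "fig and ax can't be pickled" concern from the question. This is a sketch under assumptions: `init_worker`, `render_view` and `main` are hypothetical names, the random zoom windows stand in for the real limits, and the images go to a temporary directory.

```python
import multiprocessing
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless rendering inside the worker processes
import matplotlib.pyplot as plt
import numpy as np

_fig = None  # per-process figure, built once by init_worker()
_ax = None


def init_worker(data):
    """Build the (expensive) figure once in each worker process."""
    global _fig, _ax
    _fig = plt.figure()
    _ax = _fig.add_axes([0., 0., 1., 1.])
    _ax.imshow(data, interpolation="nearest")


def render_view(task):
    """Zoom the shared image to one window and save it as a PNG."""
    x, y, outdir = task
    rng = np.random.default_rng(x * 100 + y)  # reproducible window per view
    _ax.set_xlim(np.sort(rng.random(2) * x))
    _ax.set_ylim(np.sort(rng.random(2) * y))
    path = os.path.join(outdir, f"{x:d}_{y:d}.png")
    _fig.savefig(path)
    return path


def main(num_proc=4):
    data = np.random.rand(25, 25)
    outdir = tempfile.mkdtemp(prefix="figs_")
    tasks = [(i, j, outdir) for i in range(1, 8) for j in range(3, 13)]
    with multiprocessing.Pool(num_proc, initializer=init_worker,
                              initargs=(data,)) as pool:
        return pool.map(render_view, tasks)


if __name__ == "__main__":
    print(len(main()), "images written")
```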
My code is something (roughly) like this:
UPDATE: I've redone this with some actual mock-up code that reflects my general problem. Also, realized that the colorbar creation is in the actual loop as otherwise there's nothing to map it to. Sorry for the code before, typed it up in frantic desperation at the very end of the workday :).
import numpy as np
import matplotlib.pyplot as plt
import os
#make some mock data
x = np.linspace(1,2, 100)
X, Y = np.meshgrid(x, x)
# same values as the removed plt.mlab.bivariate_normal(X, Y, 1, 1, 0, 0)
Z = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
fig = plt.figure()
ax = plt.axes()
'''
Do some figure-related stuff that takes up a lot of time,
I want to avoid having to do them in the loop over and over again.
They hinge on the presence of fig so I can't make
new figure to save each time or something, I'd have to do
them all over again.
'''
for i in range(1,1000):
    plotted = ax.pcolor(X,Y,Z)  # plt.plot() returns no mappable, so colorbar() would fail
    cbar = plt.colorbar(plotted, ax=ax, orientation = 'horizontal')
    plt.savefig(os.path.expanduser(os.path.join('~/', str(i))))
    plt.draw()
    fig.delaxes(fig.axes[1]) #deletes but whitespace remains
    '''
    Here I need something to remove the colorbar otherwise
    I end up with +1 colorbar on my plot at every iteration.
    I've tried various things to remove it BUT it keeps adding whitespace instead
    so doesn't actually fix anything.
    '''
Has anyone come across this problem before and managed to fix it? Hopefully this is enough to give an idea of the problem; I can post more code if needed, but thought it would be less cluttered if I just gave an overview example.
Thanks.
colorbar() allows you to explicitly set which axes to render into (via the cax argument). You can use this to ensure that the colorbar always appears in the same place and does not steal any space from another axes. Furthermore, you could reset the .mappable attribute of an existing colorbar, rather than redefine it each time.
Example with explicit axes:
import os
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(1,2, 100)
X, Y = np.meshgrid(x, x)
# same values as the removed plt.mlab.bivariate_normal(X, Y, 1, 1, 0, 0)
Z = np.exp(-(X**2 + Y**2) / 2) / (2 * np.pi)
fig = plt.figure()
ax1 = fig.add_axes([0.1,0.1,0.8,0.7])
ax2 = fig.add_axes([0.1,0.85,0.8,0.05])
...
for i in range(1,5):
    plotted = ax1.pcolor(X,Y,Z)
    cbar = plt.colorbar(mappable=plotted, cax=ax2, orientation = 'horizontal')
    #note "cax" instead of "ax"
    plt.savefig(os.path.expanduser(os.path.join('~/', str(i))))
    plt.draw()
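The other suggestion, keeping a single colorbar and pointing it at each new mappable, can be sketched with `Colorbar.update_normal()`, which in current Matplotlib sets the colorbar's `mappable` and redraws it. This is a minimal sketch: `render_frames` and `mock_data` are hypothetical names, the per-frame data is made up, and images are written to a temporary directory instead of `~/`.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np


def mock_data(X, Y, i):
    """Hypothetical per-frame data: a 2-D Gaussian whose variance grows with i."""
    return np.exp(-(X**2 + Y**2) / (2 * i)) / (2 * np.pi * i)


def render_frames(n_frames=4, outdir=None):
    """Save n_frames images that all share one colorbar in a fixed cax."""
    outdir = outdir or tempfile.mkdtemp()
    x = np.linspace(1, 2, 100)
    X, Y = np.meshgrid(x, x)

    fig = plt.figure()
    ax1 = fig.add_axes([0.1, 0.1, 0.8, 0.7])    # data axes
    ax2 = fig.add_axes([0.1, 0.85, 0.8, 0.05])  # fixed colorbar axes

    mesh = ax1.pcolormesh(X, Y, mock_data(X, Y, 1), shading="auto")
    cbar = fig.colorbar(mesh, cax=ax2, orientation="horizontal")

    paths = []
    for i in range(1, n_frames + 1):
        if i > 1:
            mesh.remove()  # drop the previous frame's artist from ax1
            mesh = ax1.pcolormesh(X, Y, mock_data(X, Y, i), shading="auto")
            cbar.update_normal(mesh)  # re-point the existing colorbar
        path = os.path.join(outdir, f"{i}.png")
        fig.savefig(path)
        paths.append(path)
    return fig, cbar, mesh, paths
```

Because the colorbar lives in its own fixed axes and is never recreated, the figure keeps exactly two axes no matter how many frames are saved.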
I had a very similar problem, which I finally managed to solve by defining a colorbar axes in a similar fashion to:
Multiple imshow-subplots, each with colorbar
The advantage compared to mdurant's answer is that it avoids having to define the axes location manually.
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as display
from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib inline
def plot_res(ax, cax):
    plotted = ax.imshow(np.random.rand(10, 10))
    cbar = plt.colorbar(mappable=plotted, cax=cax)
fig, axarr = plt.subplots(2, 2)
cax1 = make_axes_locatable(axarr[0,0]).append_axes("right", size="10%", pad=0.05)
cax2 = make_axes_locatable(axarr[0,1]).append_axes("right", size="10%", pad=0.05)
cax3 = make_axes_locatable(axarr[1,0]).append_axes("right", size="10%", pad=0.05)
cax4 = make_axes_locatable(axarr[1,1]).append_axes("right", size="10%", pad=0.05)
# plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.3, hspace=0.3)
N=10
for j in range(N):
    plot_res(axarr[0,0],cax1)
    plot_res(axarr[0,1],cax2)
    plot_res(axarr[1,0],cax3)
    plot_res(axarr[1,1],cax4)
    display.clear_output(wait=True)
    display.display(plt.gcf())
display.clear_output(wait=True)
I'm familiar with the following questions:
Matplotlib savefig with a legend outside the plot
How to put the legend out of the plot
It seems that the answers in these questions have the luxury of being able to fiddle with the exact shrinking of the axis so that the legend fits.
Shrinking the axes, however, is not an ideal solution, because it makes the data smaller and therefore more difficult to interpret, particularly when the plot is complex and there are lots of things going on ... hence the need for a large legend.
The example of a complex legend in the documentation demonstrates the need for this because the legend in their plot actually completely obscures multiple data points.
http://matplotlib.sourceforge.net/users/legend_guide.html#legend-of-complex-plots
What I would like to be able to do is dynamically expand the size of the figure box to accommodate the expanding figure legend.
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(-2*np.pi, 2*np.pi, 0.1)
fig = plt.figure(1)
ax = fig.add_subplot(111)
ax.plot(x, np.sin(x), label='Sine')
ax.plot(x, np.cos(x), label='Cosine')
ax.plot(x, np.arctan(x), label='Inverse tan')
lgd = ax.legend(loc=9, bbox_to_anchor=(0.5,0))
ax.grid(True)
Notice how the final label 'Inverse tan' is actually outside the figure box (and looks badly cut off, which is not publication quality!)
Finally, I've been told that this is normal behaviour in R and LaTeX, so I'm a little confused why this is so difficult in Python... Is there a historical reason? Is MATLAB equally poor on this matter?
I have the (only slightly) longer version of this code on pastebin http://pastebin.com/grVjc007
Sorry EMS, but I actually just got another response from the matplotlib mailing list (thanks go out to Benjamin Root).
The code I am looking for is adjusting the savefig call to:
fig.savefig('samplefigure', bbox_extra_artists=(lgd,), bbox_inches='tight')
#Note that the bbox_extra_artists must be an iterable
This is apparently similar to calling tight_layout, but instead you allow savefig to consider extra artists in the calculation. This did in fact resize the figure box as desired.
import matplotlib.pyplot as plt
import numpy as np
plt.gcf().clear()
x = np.arange(-2*np.pi, 2*np.pi, 0.1)
fig = plt.figure(1)
ax = fig.add_subplot(111)
ax.plot(x, np.sin(x), label='Sine')
ax.plot(x, np.cos(x), label='Cosine')
ax.plot(x, np.arctan(x), label='Inverse tan')
handles, labels = ax.get_legend_handles_labels()
lgd = ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5,-0.1))
text = ax.text(-0.2,1.05, "Arbitrary text", transform=ax.transAxes)
ax.set_title("Trigonometry")
ax.grid(True)
fig.savefig('samplefigure', bbox_extra_artists=(lgd,text), bbox_inches='tight')
This produces:
[edit] The intent of this question was to completely avoid the use of arbitrary coordinate placements of arbitrary text as was the traditional solution to these problems. Despite this, numerous edits recently have insisted on putting these in, often in ways that led to the code raising an error. I have now fixed the issues and tidied the arbitrary text to show how these are also considered within the bbox_extra_artists algorithm.
Added: I found something that should do the trick right away, but the rest of the code below also offers an alternative.
Use the subplots_adjust() function to move the bottom of the subplot up:
fig.subplots_adjust(bottom=0.2) # <-- Change the 0.2 to work for your plot.
Then play with the offset in the legend bbox_to_anchor part of the legend command, to get the legend box where you want it. Some combination of setting the figsize and using the subplots_adjust(bottom=...) should produce a quality plot for you.
Alternative:
I simply changed the line:
fig = plt.figure(1)
to:
fig = plt.figure(num=1, figsize=(13, 13), dpi=80, facecolor='w', edgecolor='k')
and changed
lgd = ax.legend(loc=9, bbox_to_anchor=(0.5,0))
to
lgd = ax.legend(loc=9, bbox_to_anchor=(0.5,-0.02))
and it shows up fine on my screen (a 24-inch CRT monitor).
Here figsize=(M,N) sets the figure window to be M inches by N inches. Just play with this until it looks right for you. Convert it to a more scalable image format and use GIMP to edit if necessary, or just crop with the LaTeX viewport option when including graphics.
Here is another, very manual solution. You define the size of the axes, and the paddings (including the legend and tick marks) are taken into account accordingly. I hope it is of use to somebody.
Example (the axes sizes are the same!):
Code:
#==================================================
# Plot table
colmap = [(0,0,1) #blue
,(1,0,0) #red
,(0,1,0) #green
,(1,1,0) #yellow
,(1,0,1) #magenta
,(1,0.5,0.5) #pink
,(0.5,0.5,0.5) #gray
,(0.5,0,0) #brown
,(1,0.5,0) #orange
]
import matplotlib.pyplot as plt
import numpy as np
import collections
df = collections.OrderedDict()
df['labels'] = ['GWP100a\n[kgCO2eq]\n\nasedf\nasdf\nadfs','human\n[pts]','ressource\n[pts]']
df['all-petroleum long name'] = [3,5,2]
df['all-electric'] = [5.5, 1, 3]
df['HEV'] = [3.5, 2, 1]
df['PHEV'] = [3.5, 2, 1]
values = list(df.values())  # dict views are not indexable in Python 3
keys = list(df.keys())
numLabels = len(values[0])
numItems = len(df)-1
posX = np.arange(numLabels)+1
width = 1.0/(numItems+1)
fig = plt.figure(figsize=(2,2))
ax = fig.add_subplot(111)
for iiItem in range(1,numItems+1):
    ax.bar(posX+(iiItem-1)*width, values[iiItem], width, color=colmap[iiItem-1], label=keys[iiItem])
ax.set(xticks=posX+width*(0.5*numItems), xticklabels=df['labels'])
#--------------------------------------------------
# Change padding and margins, insert legend
fig.tight_layout() #tight margins
leg = ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0)
plt.draw() #to know size of legend
padLeft = ax.get_position().x0 * fig.get_size_inches()[0]
padBottom = ax.get_position().y0 * fig.get_size_inches()[1]
padTop = ( 1 - ax.get_position().y0 - ax.get_position().height ) * fig.get_size_inches()[1]
padRight = ( 1 - ax.get_position().x0 - ax.get_position().width ) * fig.get_size_inches()[0]
dpi = fig.get_dpi()
padLegend = ax.get_legend().get_frame().get_width() / dpi
widthAx = 3 #inches
heightAx = 3 #inches
widthTot = widthAx+padLeft+padRight+padLegend
heightTot = heightAx+padTop+padBottom
# resize ipython window (optional; relies on a private attribute of the TkAgg canvas)
posScreenX = 1366/2-10 #pixel
posScreenY = 0 #pixel
canvasPadding = 6 #pixel
canvasBottom = 40 #pixel
ipythonWindowSize = '{0}x{1}+{2}+{3}'.format(int(round(widthTot*dpi))+2*canvasPadding
,int(round(heightTot*dpi))+2*canvasPadding+canvasBottom
,posScreenX,posScreenY)
fig.canvas._tkcanvas.master.geometry(ipythonWindowSize)
plt.draw() #to resize ipython window. Has to be done BEFORE figure resizing!
# set figure size and ax position
fig.set_size_inches(widthTot,heightTot)
ax.set_position([padLeft/widthTot, padBottom/heightTot, widthAx/widthTot, heightAx/heightTot])
plt.draw()
plt.show()
#--------------------------------------------------
#==================================================
A very simple approach that worked for me is to make the figure a bit wider:
fig, ax = plt.subplots(1, 1, figsize=(a, b))
Adjust a and b to values such that the legend is included in the figure.
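Instead of guessing a and b, the required width can be computed from the legend's rendered extent. The following sketch assumes that the legend sits to the right of the axes (anchored with bbox_to_anchor=(1.02, 1)), that no layout engine such as tight_layout or constrained layout repositions the axes afterwards, and that the canvas is Agg-based so `fig.canvas.get_renderer()` is available; `widen_for_legend` is a hypothetical helper.

```python
import matplotlib.pyplot as plt
import numpy as np


def widen_for_legend(fig, ax, legend, pad_inches=0.1):
    """Widen `fig` just enough for `legend`, keeping `ax` the same physical size."""
    fig.canvas.draw()  # the legend's extent is only known after a draw
    renderer = fig.canvas.get_renderer()
    legend_right_in = legend.get_window_extent(renderer).x1 / fig.dpi
    old_w, old_h = fig.get_size_inches()
    new_w = max(old_w, legend_right_in + pad_inches)
    if new_w == old_w:
        return fig  # the legend already fits
    # The axes position is stored in figure fractions; rescale it so the axes
    # keep their left offset and width in inches after the figure grows.
    pos = ax.get_position()
    scale = old_w / new_w
    fig.set_size_inches(new_w, old_h)
    ax.set_position([pos.x0 * scale, pos.y0, pos.width * scale, pos.height])
    return fig


if __name__ == "__main__":
    x = np.arange(-2 * np.pi, 2 * np.pi, 0.1)
    fig, ax = plt.subplots()
    ax.plot(x, np.sin(x), label="Sine")
    ax.plot(x, np.cos(x), label="Cosine")
    ax.plot(x, np.arctan(x), label="Inverse tan")
    lgd = ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1))
    widen_for_legend(fig, ax, lgd)
    plt.show()
```

Since the legend is anchored in axes coordinates, keeping the axes' physical size fixed also keeps the legend's right edge at the same distance from the figure's left edge, which is why a single width adjustment is enough.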