I need to show some numerical calculation algorithms visually, step by step, as in the figure below: (gif)
Font
How can I do this animation with matplotlib? Is there any way to present these transitions visually, for example matrix transformations, sums, or transpositions, using a loop that shows each transition?
My goal is not to use graphics but the same matrix representation. This is to facilitate the understanding of the algorithms.
Since matrices can be plotted easily with imshow, one could create such table with an imshow plot and adjust the data according to the current animation step.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.animation
#####################
# Array preparation
#####################
#input array
a = np.random.randint(50,150, size=(5,5))
# kernel
kernel = np.array([[ 0,-1, 0], [-1, 5,-1], [ 0,-1, 0]])
# visualization array (2 bigger in each direction)
va = np.zeros((a.shape[0]+2, a.shape[1]+2), dtype=int)
va[1:-1,1:-1] = a
#output array
res = np.zeros_like(a)
#colorarray
va_color = np.zeros((a.shape[0]+2, a.shape[1]+2))
va_color[1:-1,1:-1] = 0.5
#####################
# Create initial plot
#####################
fig = plt.figure(figsize=(8,4))
def add_axes_inches(fig, rect):
    """Add an axes to *fig* from a rectangle given in inches.

    *rect* is indexed as ``[left, bottom, width, height]``. The horizontal
    entries are divided by the figure width and the vertical entries by the
    figure height (both from ``fig.get_size_inches()``), and the resulting
    rectangle is passed to ``fig.add_axes``, whose return value is returned.
    """
    fig_w, fig_h = fig.get_size_inches()
    return fig.add_axes([rect[0] / fig_w, rect[1] / fig_h,
                         rect[2] / fig_w, rect[3] / fig_h])
axwidth = 3.
cellsize = axwidth/va.shape[1]
axheight = cellsize*va.shape[0]
ax_va = add_axes_inches(fig, [cellsize, cellsize, axwidth, axheight])
ax_kernel = add_axes_inches(fig, [cellsize*2+axwidth,
(2+res.shape[0])*cellsize-kernel.shape[0]*cellsize,
kernel.shape[1]*cellsize,
kernel.shape[0]*cellsize])
ax_res = add_axes_inches(fig, [cellsize*3+axwidth+kernel.shape[1]*cellsize,
2*cellsize,
res.shape[1]*cellsize,
res.shape[0]*cellsize])
ax_kernel.set_title("Kernel", size=12)
im_va = ax_va.imshow(va_color, vmin=0., vmax=1.3, cmap="Blues")
for i in range(va.shape[0]):
for j in range(va.shape[1]):
ax_va.text(j,i, va[i,j], va="center", ha="center")
ax_kernel.imshow(np.zeros_like(kernel), vmin=-1, vmax=1, cmap="Pastel1")
for i in range(kernel.shape[0]):
for j in range(kernel.shape[1]):
ax_kernel.text(j,i, kernel[i,j], va="center", ha="center")
im_res = ax_res.imshow(res, vmin=0, vmax=1.3, cmap="Greens")
res_texts = []
for i in range(res.shape[0]):
row = []
for j in range(res.shape[1]):
row.append(ax_res.text(j,i, "", va="center", ha="center"))
res_texts.append(row)
for ax in [ax_va, ax_kernel, ax_res]:
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
ax.yaxis.set_major_locator(mticker.IndexLocator(1,0))
ax.xaxis.set_major_locator(mticker.IndexLocator(1,0))
ax.grid(color="k")
###############
# Animation
###############
def init():
    """Blank every label of the result grid (used as FuncAnimation init_func)."""
    for text in (cell for row in res_texts for cell in row):
        text.set_text("")
def animate(ij):
    """Draw the animation frame for output cell ``ij = (i, j)``.

    Writes the kernel-weighted sum of the matching window of the padded
    array ``va`` into the result label at ``(i, j)``, highlights that window
    in the input image and marks the current cell in the result image.
    """
    row, col = ij
    half = kernel.shape[1] // 2
    # Window in `va`; the +1 accounts for the one-cell border added around `a`.
    rows = slice(1 + row - half, 1 + row + half + 1)
    cols = slice(1 + col - half, 1 + col + half + 1)

    # calculate result
    window_sum = (kernel * va[rows, cols]).sum()
    res_texts[row][col].set_text(window_sum)

    # make colors: work on copies so the base arrays stay untouched
    highlighted = va_color.copy()
    highlighted[rows, cols] = 1.
    im_va.set_array(highlighted)

    marker = res.copy()
    marker[row, col] = 1
    im_res.set_array(marker)
i,j = np.indices(res.shape)
ani = matplotlib.animation.FuncAnimation(fig, animate, init_func=init,
frames=zip(i.flat, j.flat), interval=400)
ani.save("algo.gif", writer="imagemagick")
plt.show()
This example sets up the animation inline in a Jupyter notebook. I suppose there's probably also a way to export as a gif, but I haven't looked into that so far.
Anyway, first thing to do is set up the table. I borrowed heavily from Export a Pandas dataframe as a table image for the render_mpl_table code.
The (adapted) version for this problem is:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import six
width = 8
data = pd.DataFrame([[0]*width,
[0, *np.random.randint(95,105,size=width-2), 0],
[0, *np.random.randint(95,105,size=width-2), 0],
[0, *np.random.randint(95,105,size=width-2), 0]])
def render_mpl_table(data, col_width=3.0, row_height=0.625, font_size=14,
                     row_color="w", edge_color="black", bbox=(0, 0, 1, 1),
                     ax=None, col_labels=data.columns,
                     highlight_color="mediumpurple",
                     highlights=(), **kwargs):
    """Render a DataFrame as a matplotlib table and colour its cells.

    Parameters
    ----------
    data : DataFrame whose ``values`` become the cell text.
    col_width, row_height : per-cell size used to derive the figure size
        when no *ax* is given.
    font_size : font size applied to the whole table.
    row_color : face colour of cells that are neither highlighted nor > 0.
    edge_color : edge colour of every cell.
    bbox : forwarded to ``ax.table``.
    ax : existing axes to draw into; a new figure/axes is created if None.
    col_labels : forwarded as ``colLabels``; pass None for no header row.
        NOTE: the default is evaluated once, at definition time, from the
        module-level ``data``.
    highlight_color : face colour of highlighted cells.
    highlights : collection of table cell keys ``(row, col)`` to highlight.
    **kwargs : forwarded to ``ax.table``.

    Returns
    -------
    (fig, ax, mpl_table)
    """
    if ax is None:
        # Parenthesised so the product can span two lines.
        size = ((np.array(data.shape[::-1]) + np.array([0, 1]))
                * np.array([col_width, row_height]))
        fig, ax = plt.subplots(figsize=size)
        ax.axis('off')
    else:
        # Previously `fig` was unbound on this path and the return failed.
        fig = ax.figure

    mpl_table = ax.table(cellText=data.values, bbox=bbox, colLabels=col_labels,
                         **kwargs)
    mpl_table.auto_set_font_size(False)
    mpl_table.set_fontsize(font_size)

    # With column labels, table row 0 is the header and data starts at row 1,
    # so table keys must be shifted before indexing into `data`.
    header_rows = 0 if col_labels is None else 1
    n_rows, n_cols = data.shape
    for key, cell in mpl_table.get_celld().items():
        cell.set_edgecolor(edge_color)
        data_row, data_col = key[0] - header_rows, key[1]
        is_data_cell = 0 <= data_row < n_rows and 0 <= data_col < n_cols
        if key in highlights:
            cell.set_facecolor(highlight_color)
        elif is_data_cell and data.iat[data_row, data_col] > 0:
            cell.set_facecolor("lightblue")
        else:
            cell.set_facecolor(row_color)
    return fig, ax, mpl_table
fig, ax, mpl_table = render_mpl_table(data, col_width=2.0, col_labels=None,
highlights=[(0,2),(0,3),(1,2),(1,3)])
In this case, the cells to highlight in a different color are given by an array of tuples that specify the row and column.
For the animation, we need to set up a function that draws the table with different highlights:
def update_table(i, *args, **kwargs):
    """Recolour the table for animation frame *i*.

    The frame number is mapped to a (row, column) position using the
    module-level ``width``; the 2x2 group of cells starting there is
    highlighted, cells whose value in ``data`` is > 0 are light blue and all
    other cells are white. Returns a one-element tuple holding the table.
    """
    row, col = divmod(i, width - 1)
    highlights = {(row, col), (row, col + 1),
                  (row + 1, col), (row + 1, col + 1)}
    # Public accessor and plain dict iteration instead of six + `_cells`.
    for key, cell in mpl_table.get_celld().items():
        cell.set_edgecolor("black")
        if key in highlights:
            cell.set_facecolor("mediumpurple")
        elif data.iat[key] > 0:
            cell.set_facecolor("lightblue")
        else:
            cell.set_facecolor("white")
    return (mpl_table,)
This forcibly updates the colors for all cells in the table. The highlights array is computed based on the current frame. The width and height of the table are kind of hard-coded in this example, but that shouldn't be super hard to change based on the shape of your input data.
We create an animation based on the existing fig and update function:
a = animation.FuncAnimation(fig, update_table, (width-1)*3,
interval=750, blit=True)
And lastly we show it inline in our notebook:
HTML(a.to_jshtml())
I put this together in a notebook on github, see https://github.com/gurudave/so_examples/blob/master/mpl_animation.ipynb
Hope that's enough to get you going in the right direction!
Related
I'm having problems with heatmap.
I create the following function to show the analysis with heatmap
data = [ 0.00662896, -0.00213044, -0.00156812, 0.01450994, -0.00875174, -0.01561342, -0.00694762, 0.00476027, 0.00470659]
def plot_heatmap(pathOut, data, title, fileName, precis=2, show=False):
from matplotlib import cm
fig = plt.figure()
n = int(np.sqrt(len(data)))
data = data.reshape(n,n)
heatmap = plt.pcolor(data,cmap=cm.YlOrBr)
xLabels = (np.linspace(1,n,n,dtype=int))
yLabels = (np.linspace(1,n,n,dtype=int))
xpos = np.linspace(1,n,n)-0.5
ypos = np.linspace(1,n,n)-0.5
for y in range(n):
for x in range(n):
plt.text(x + 0.5, y + 0.5, f'{data[y, x]:.{precis}f}',
horizontalalignment='center',
verticalalignment='center',
)
plt.colorbar(heatmap, format='%.2f')
plt.xticks(xpos,xLabels)
plt.yticks(ypos,yLabels)
plt.title(f'{title}')
if (show == False ):
plt.close(fig)
elif (show == True):
plt.show()
fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf')
When I call the function the heatmap is created but not correctly, because I would like to show values at a specific precision. I know how to define text precision and scale precision, but how to adjust data precision to generate the correct heatmap?
In the attached figure, I have 7 cells equal to 0, for my desired precision, but the data used has a larger precision what produce different colors.
It is much easier to use seaborn.heatmap, which includes annotations and a colorbar. seaborn is a high-level API for matplotlib.
This significantly reduces the number of lines of code.
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import seaborn as sns
def plot_heatmap(pathOut, fileName, data, title, precis=2, show=False):
    """Draw an annotated seaborn heatmap of *data* and save it as a PDF.

    Parameters
    ----------
    pathOut, fileName : the figure is saved to ``{pathOut}/{fileName}.pdf``.
    data : flat sequence/array with a square number of entries; it is
        reshaped to ``(n, n)``.
    title : axes title.
    precis : number of decimal places shown in the cell annotations.
    show : if truthy the figure is shown, otherwise it is closed.
    """
    data = np.asarray(data)  # also accept plain lists
    n = int(np.sqrt(data.size))
    data = data.reshape(n, n)
    xy_labels = range(1, n+1)

    fig, ax = plt.subplots(figsize=(8, 6))
    # Fixed-point ('f') so `precis` means decimal places, as in the
    # pcolor-based versions; 'g' would mean significant digits.
    sns.heatmap(data=data, annot=True, fmt=f'.{precis}f', ax=ax,
                cmap=cm.YlOrBr, xticklabels=xy_labels, yticklabels=xy_labels)
    ax.invert_yaxis()  # invert the axis if desired
    ax.set_title(f'{title}')

    # Save before show()/close(): afterwards the figure may be gone.
    fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf')
    if show:
        plt.show()
    else:
        plt.close(fig)
data = np.array([ 0.00662896, -0.00213044, -0.00156812, 0.01450994, -0.00875174, -0.01561342, -0.00694762, 0.00476027, 0.00470659])
plot_heatmap('.', 'test', data, 'test', 4, True)
The f-string for plt.text is not correct. It will be easier to round the value and convert it to a str type.
str(round(data[x, y], precis)) instead of f'{data[y, x]:.{precis}f}'
data[x, y] should be data[y, x]
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
def plot_heatmap(pathOut, fileName, data, title, precis=2, show=False):
    """Draw a pcolor heatmap of *data* with per-cell text and save it as a PDF.

    Parameters
    ----------
    pathOut, fileName : the figure is saved to ``{pathOut}/{fileName}.pdf``.
    data : flat sequence/array with a square number of entries; it is
        reshaped to ``(n, n)``.
    title : figure title.
    precis : number of decimals used for the cell text and the colorbar.
    show : if truthy the figure is shown, otherwise it is closed.
    """
    fig = plt.figure(figsize=(8, 6))
    data = np.asarray(data)  # also accept plain lists
    n = int(np.sqrt(data.size))
    data = data.reshape(n, n)
    heatmap = plt.pcolor(data, cmap=cm.YlOrBr)

    labels = np.arange(1, n + 1)
    centers = np.arange(n) + 0.5  # pcolor cell k spans [k, k+1]

    for y in range(n):
        for x in range(n):
            # text for plt.text; data is addressed as [row, column] = [y, x]
            s = str(round(data[y, x], precis))
            plt.text(x + 0.5, y + 0.5, s,
                     horizontalalignment='center',
                     verticalalignment='center',
                     )

    plt.colorbar(heatmap, format=f'%.{precis}f')  # add precis to the colorbar
    plt.xticks(centers, labels)
    plt.yticks(centers, labels)
    plt.title(f'{title}')

    # this should be before plt.show()
    fig.savefig(f'{pathOut}/{fileName}.pdf', format='pdf')
    if show:
        plt.show()
    else:
        plt.close(fig)
# the function expects an array, not a list
data = np.array([ 0.00662896, -0.00213044, -0.00156812, 0.01450994, -0.00875174, -0.01561342, -0.00694762, 0.00476027, 0.00470659])
# function call
plot_heatmap('.', 'test', data, 'test', 4, True)
I am trying to animate a histogram using matplotlib and I want to show the different bars using a colormap, e.g:
I have this working when I clear the complete figure every frame and then redraw everything. But this is very slow, so I am trying out the example by matplotlib itself.
This works and is very fast, but unfortunately I have no idea on how to specify a colormap because it is using the patches.PathPatch object to draw the histogram now. I can only get it to work with the same single color for every individual bar.
How can I specify a gradient or colormap to achieve the desired result shown above?
Here is an example of a working animation with a single color which I am currently using.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.path as path
import matplotlib.animation as animation
# Fixing random state for reproducibility
np.random.seed(19680801)
# histogram our data with numpy
data = np.random.randn(1000)
n, bins = np.histogram(data, 100)
# get the corners of the rectangles for the histogram
left = np.array(bins[:-1])
right = np.array(bins[1:])
bottom = np.zeros(len(left))
top = bottom + n
nrects = len(left)
nverts = nrects * (1 + 3 + 1)
verts = np.zeros((nverts, 2))
codes = np.ones(nverts, int) * path.Path.LINETO
codes[0::5] = path.Path.MOVETO
codes[4::5] = path.Path.CLOSEPOLY
verts[0::5, 0] = left
verts[0::5, 1] = bottom
verts[1::5, 0] = left
verts[1::5, 1] = top
verts[2::5, 0] = right
verts[2::5, 1] = top
verts[3::5, 0] = right
verts[3::5, 1] = bottom
patch = None
def animate(i):
# simulate new data coming in
data = np.random.randn(1000)
n, bins = np.histogram(data, 100)
top = bottom + n
verts[1::5, 1] = top
verts[2::5, 1] = top
return [patch, ]
fig, ax = plt.subplots()
barpath = path.Path(verts, codes)
patch = patches.PathPatch(
barpath, facecolor='green', edgecolor='yellow', alpha=0.5)
ax.add_patch(patch)
ax.set_xlim(left[0], right[-1])
ax.set_ylim(bottom.min(), top.max())
ani = animation.FuncAnimation(fig, animate, 100, repeat=False, blit=True)
plt.show()
I recommend using a BarContainer, which lets you change the color of each bar individually. In your example, the path is a single object, and matplotlib does not seem to support gradient colors for a single patch (not sure, though).
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
# Fixing random state for reproducibility
np.random.seed(19680801)
# histogram our data with numpy
data = np.random.randn(1000)
colors = plt.cm.coolwarm(np.linspace(0, 1, 100))
def animate(i):
    """Draw a fresh 100-bin histogram for frame *i*, one colour per bar.

    Returns the bar container produced by ``ax.hist`` so the animation can
    redraw its bars.
    """
    # ax.hist() adds a new set of bars on every call; drop the previous
    # frame's bars first so artists do not pile up on the axes.
    for container in list(ax.containers):
        container.remove()

    data = np.random.randn(1000)
    bars = ax.hist(data, 100)[2]
    for bar, color in zip(bars, colors):
        bar.set_color(color)
    return bars
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))
ani = animation.FuncAnimation(fig, animate, 100, repeat=False, blit=True)
I have a 4D data set (for those who care, it's an astronomical Position-Position-Temperature-Opacity image) in a numpy array that I need to plot interactively. While there are programs to do this, none of them can handle the unusual form that my data comes in (but I can worry about that; that's not part of the question).
I know how to get it plotting with one Slider, but really I need to plot the image with 2 Sliders, one for each of temperature and opacity.
My MWE of a 3D array code is below:
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
array = np.random.rand(300,300,10)
axis = 2
# range() works on Python 2 and 3 (xrange is Python-2 only).
s = [slice(0, 1) if i == axis else slice(None) for i in range(array.ndim)]
# Index with a tuple: a list of slices is not accepted by current NumPy.
im = array[tuple(s)].squeeze()
fig = plt.figure()
ax = plt.subplot(111)
l = ax.imshow(im, origin = 'lower')
axcolor = 'lightgoldenrodyellow'
# `facecolor` replaces the `axisbg` keyword, which matplotlib removed.
ax = fig.add_axes([0.2, 0.95, 0.65, 0.03], facecolor=axcolor)
slider = Slider(ax, 'Temperature', 0, array.shape[axis] - 1,
                valinit=0, valfmt='%i')
def update(val):
    """Slider callback: show the slice of `array` at the slider index along `axis`."""
    ind = int(slider.val)
    # range() instead of the Python-2-only xrange().
    s = [slice(ind, ind + 1) if i == axis else slice(None)
         for i in range(array.ndim)]
    # Index with a tuple: a list of slices is not accepted by current NumPy.
    im = array[tuple(s)].squeeze()
    l.set_data(im)
    fig.canvas.draw()
slider.on_changed(update)
plt.show()
Any way to do it with 2 sliders?
EDIT: The problem I am having is I dont know how to expand to 2 sliders. Particularly how to adapt the line
s = [slice(0, 1) if i == axis else slice(None) for i in xrange(array.ndim)]
and how to modify the update function when I go from np.random.rand(300,300,10) to np.random.rand(300,300,10,10). I supposed I will need to declare both a T_axis = 2 and B_axis = 3 rather than simply an axis = 2, but beyond that, I am rather stuck as to how to modify it.
As I interpret the data structure, you have an array of shape (300,300,n,m), where n is the number of temperatures and m is the number of opacities. The image to show for the ith temperature and the jth opacity is hence array[:,:,i,j].
You now of course need two different sliders, where one determines the value of i and the other that of j.
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
array = np.random.rand(300,300,10,9)
# assuming you have for each i=Temperature index and j =Opacity index
# an image array(:,:,i,j)
fig, ax = plt.subplots()
l = ax.imshow(array[:,:,0,0], origin = 'lower')
axT = fig.add_axes([0.2, 0.95, 0.65, 0.03])
axO = fig.add_axes([0.2, 0.90, 0.65, 0.03])
sliderT = Slider(axT, 'Temperature', 0, array.shape[2]-1, valinit=0, valfmt='%i')
sliderO = Slider(axO, 'Opacity', 0, array.shape[3]-1, valinit=0, valfmt='%i')
def update(val):
    """Callback for both sliders: show the image at their current indices."""
    temp_idx = int(sliderT.val)
    opac_idx = int(sliderO.val)
    l.set_data(array[:, :, temp_idx, opac_idx])
    fig.canvas.draw_idle()
sliderT.on_changed(update)
sliderO.on_changed(update)
plt.show()
I'm trying to create a scrollable multiplot based on the answer to this question:
Creating a scrollable multiplot with python's pylab
Lines created using ax.plot() are updating correctly; however, I'm unable to figure out how to update artists created using vlines() and fill_between().
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
#create dataframes
dfs={}
for x in range(100):
col1=np.random.normal(10,0.5,30)
col2=(np.repeat([5,8,7],np.round(np.random.dirichlet(np.ones(3),size=1)*31)[0].tolist()))[:30]
col3=np.random.randint(4,size=30)
dfs[x]=pd.DataFrame({'col1':col1,'col2':col2,'col3':col3})
#create figure,axis,subplot
fig = plt.figure()
gs = gridspec.GridSpec(1,1,hspace=0,wspace=0,left=0.1,bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0,12])
#slider
frame=0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0,valfmt='%d')
#plots
ln1,=ax.plot(dfs[0].index,dfs[0]['col1'])
ln2,=ax.plot(dfs[0].index,dfs[0]['col2'],c='black')
#artists
ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==8,facecolor='b',edgecolors='none',alpha=0.5)
ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==7,facecolor='g',edgecolors='none',alpha=0.5)
ax.vlines(x=dfs[0]['col3'].index,ymin=0,ymax=dfs[0]['col3'],color='black')
#update plots
def update(val):
frame = np.floor(sframe.val)
ln1.set_ydata(dfs[frame]['col1'])
ln2.set_ydata(dfs[frame]['col2'])
ax.set_title('Frame ' + str(int(frame)))
plt.draw()
#connect callback to slider
sframe.on_changed(update)
plt.show()
This is what it looks like at the moment
I can't apply the same approach as for plot(), since the following produces an error message:
ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable
This is what it's meant to look like on each frame
fill_between returns a PolyCollection, which expects a list (or several lists) of vertices upon creation. Unfortunately I haven't found a way to retrieve the vertices that were used to create the given PolyCollection, but in your case it is easy enough to create the PolyCollection directly (thereby avoiding the use of fill_between) and then update its vertices upon frame change.
Below a version of your code that does what you are after:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
from matplotlib.collections import PolyCollection
#create dataframes
dfs={}
for x in range(100):
col1=np.random.normal(10,0.5,30)
col2=(np.repeat([5,8,7],np.round(np.random.dirichlet(np.ones(3),size=1)*31)[0].tolist()))[:30]
col3=np.random.randint(4,size=30)
dfs[x]=pd.DataFrame({'col1':col1,'col2':col2,'col3':col3})
#create figure,axis,subplot
fig = plt.figure()
gs = gridspec.GridSpec(1,1,hspace=0,wspace=0,left=0.1,bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0,12])
#slider
frame=0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0,valfmt='%d')
#plots
ln1,=ax.plot(dfs[0].index,dfs[0]['col1'])
ln2,=ax.plot(dfs[0].index,dfs[0]['col2'],c='black')
##additional code to update the PolyCollections
val_r = 5
val_b = 8
val_g = 7
def update_collection(collection, value, frame = 0):
    """Set *collection* to one rectangle covering where ``col2 == value``.

    For ``dfs[frame]`` the rectangle runs from the smallest to the largest
    index at which ``col2`` equals *value* and from ``value - 0.5`` to
    ``value + 0.5`` vertically. If no row matches, an empty (0, 2) vertex
    array is set instead.
    """
    df = dfs[frame]
    xs = np.array(df.index)
    matching_x = xs[np.array(df['col2']) == value]

    # Explicit emptiness check instead of catching np.min's ValueError.
    if matching_x.size:
        minx, maxx = np.min(matching_x), np.max(matching_x)
        miny, maxy = value - 0.5, value + 0.5
        verts = np.array([[minx, miny], [maxx, miny],
                          [maxx, maxy], [minx, maxy]])
    else:
        verts = np.zeros((0, 2))
    collection.set_verts([verts])
#artists
##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
reds = PolyCollection([],facecolors = ['r'], alpha = 0.5)
ax.add_collection(reds)
update_collection(reds,val_r)
##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==8,facecolor='b',edgecolors='none',alpha=0.5)
blues = PolyCollection([],facecolors = ['b'], alpha = 0.5)
ax.add_collection(blues)
update_collection(blues, val_b)
##ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==7,facecolor='g',edgecolors='none',alpha=0.5)
greens = PolyCollection([],facecolors = ['g'], alpha = 0.5)
ax.add_collection(greens)
update_collection(greens, val_g)
ax.vlines(x=dfs[0]['col3'].index,ymin=0,ymax=dfs[0]['col3'],color='black')
#update plots
def update(val):
    """Slider callback: redraw both lines and the three bands for the chosen frame."""
    # One integer key for dict lookup and title; the slider value is
    # non-negative, so int() matches the previous np.floor().
    frame = int(sframe.val)
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    ##updating the PolyCollections:
    for collection, value in ((reds, val_r), (blues, val_b), (greens, val_g)):
        update_collection(collection, value, frame)
    plt.draw()
#connect callback to slider
sframe.on_changed(update)
plt.show()
Each of the three PolyCollections (reds, blues, and greens) has only four vertices (the corners of the rectangles), which are determined based on the given data (which is done in update_collection). The result looks like this:
Tested in Python 3.5
Your error
TypeError: 'PolyCollection' object is not iterable
can be avoided by removing the comma after l3:
l3 = ax.fill_between(xx, y1, y2, **kwargs)
The return value is a PolyCollection, you need to update its vertices during the update() function. An alternative to the other answer posted here is to make fill_between() give you a new PolyCollection, and then get its vertices and use them to update those of l3:
def update(val):
    """Refresh `l3` from a throw-away fill_between built on the current data."""
    # Let fill_between compute the new geometry, copy it over, then discard
    # the temporary collection.
    dummy_l3 = ax.fill_between(xx, y1, y2, **kwargs)
    paths = dummy_l3.get_paths()
    # Public Path attributes instead of the private _vertices/_codes.
    verts = [path.vertices for path in paths]
    codes = [path.codes for path in paths]
    dummy_l3.remove()
    l3.set_verts_and_codes(verts, codes)
    plt.draw()
The above code does not run for me; however, to refresh fill_between the following works for me
%matplotlib inline
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
import time
hdisplay = display.display("", display_id=True)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
x = np.linspace(0,1,100)
ax.set_title("Test")
ax.set_xlim()
y = np.random.random(size=(100))
dy = 0.1
l = ax.plot(x,y,":",color="red")
b = ax.fill_between( x, y-dy, y+dy, color="red", alpha=0.2 )
hdisplay.update(fig)
for i in range(5):
time.sleep(1)
ax.set_title("Test %ld" % i)
y = np.random.random(size=(100))
l[0].set_ydata( y )
b.remove()
b = ax.fill_between( x, y-dy, y+dy, color="red", alpha=0.2 )
plt.draw()
hdisplay.update(fig)
plt.close(fig)
I have the following code that displays the numerical values of a matrix in a matplotlib.table object:
fig = plt.figure(figsize=(20,11))
plt.title('Correlation Matrix')
ticks = np.array(['$F_{sum}$','$F_{dif}$','$x_{sum}$','$x_{dif}$','$y_{sum}$','$y_{dif}$','$HLR_a$','$e1_a$','$e2_a$',
'$HLR_b$','$e1_b$','$e2_b$'])
ticks = ticks[::-1]
ticks = ticks.tolist()
plt.xticks([0.5,1.2,2.1,3.0,3.9,4.8,5.7,6.6,7.5,8.4,9.3,10],ticks,fontsize=15)
plt.yticks([0.5,1.2,2.1,3.0,3.9,4.8,5.7,6.6,7.5,8.4,9.3,10],['$F_{sum}$','$F_{dif}$','$x_{sum}$','$x_{dif}$','$y_{sum}$','$y_{dif}$','$HLR_a$','$e1_a$','$e2_a$',
'$HLR_b$','$e1_b$','$e2_b$'],fontsize=15)
round_mat = np.round(correlation_mat,2)
table = plt.table(cellText=round_mat,loc='center',colWidths=np.ones(correlation_mat.shape[0])/correlation_mat.shape[0],cellLoc='center',bbox=[0,0,1,1])
table.set_fontsize(25)
plt.show()
with the following output:
I want the x-axis and the y-axis ticks to be centered for each rectangle. Here, it seems that the first few ticks are correct and then the rest spread out. I would like them all equally spaced with the tick at the center. I am not sure what to do for this.
One way to do this is to use the row and column labels for the table. By default, they'll have a background and border, which is a touch clunky to turn off:
import numpy as np
import matplotlib.pyplot as plt
# Generate some data...
data = np.random.random((12, 10))
correlation_mat = np.cov(data)
correlation_mat /= np.diag(correlation_mat)
fig, ax = plt.subplots(figsize=(20,11))
ax.set_title('Correlation Matrix')
ticks = ['$F_{sum}$', '$F_{dif}$', '$x_{sum}$', '$x_{dif}$', '$y_{sum}$',
'$y_{dif}$', '$HLR_a$', '$e1_a$', '$e2_a$', '$HLR_b$', '$e1_b$',
'$e2_b$'][::-1]
round_mat = np.round(correlation_mat, 2)
table = ax.table(cellText=round_mat, cellLoc='center', bbox=[0, 0, 1, 1],
rowLabels=ticks, colLabels=ticks)
table.set_fontsize(25)
ax.axis('off')
# Strip background and border from the label cells (row 0 = column labels,
# column -1 = row labels) and set their text placement.
# dict.items() replaces the Python-2-only iteritems().
for key, cell in table.get_celld().items():
    if key[0] == 0 or key[1] == -1:
        cell.set(facecolor='none', edgecolor='none')
    if key[1] == -1:
        # NOTE(review): `_loc` is a private Cell attribute - confirm it is
        # still honoured by the installed matplotlib version.
        cell._loc = 'right'
    elif key[0] == 0:
        cell._loc = 'center'
plt.show()
However, it's sometimes easier to skip using a table for this altogether:
import numpy as np
import matplotlib.pyplot as plt
# Generate some data...
data = np.random.random((12, 10))
correlation_mat = np.cov(data)
correlation_mat /= np.diag(correlation_mat)
num = data.shape[0]
fig, ax = plt.subplots(figsize=(20,11))
ticks = ['$F_{sum}$', '$F_{dif}$', '$x_{sum}$', '$x_{dif}$', '$y_{sum}$',
'$y_{dif}$', '$HLR_a$', '$e1_a$', '$e2_a$', '$HLR_b$', '$e1_b$',
'$e2_b$']
ticks = ticks[::-1]
ax.matshow(correlation_mat, aspect='auto', cmap='cool')
ax.set(title='Correlation Matrix', xticks=range(num), xticklabels=ticks,
yticks=range(num), yticklabels=ticks)
ax.tick_params(labelsize=25)
for (i, j), val in np.ndenumerate(correlation_mat):
ax.annotate('{:0.2f}'.format(val), (j,i), ha='center', va='center', size=25)
plt.show()