I wrote the following code to create a 2x2 grid of subplots:
# Build a 2x2 grid of axes; three of the four panels are filled in below.
fig, axs = plt.subplots(2, 2, figsize=(20, 10))

# Top-left panel: two candlestick series drawn on the same axes, with
# inverted up/down colours so they can be told apart.
candlestick_ohlc(axs[0, 0], df.values, width=0.6, colorup='green', colordown='red', alpha=0.5)
candlestick_ohlc(axs[0, 0], df1.values, width=0.6, colorup='red', colordown='green', alpha=0.8)

# The formatter only needs to be installed once (the original set it twice).
date_format = mpl_dates.DateFormatter('%d %b %Y')
axs[0, 0].xaxis.set_major_formatter(date_format)

# Each level is an (index, price) pair: draw a horizontal line from the date
# at that index to the latest date in df.
for level in levels:
    axs[0, 0].hlines(level[1], xmin=df['Date'][level[0]], xmax=max(df['Date']), colors='black')
fig.autofmt_xdate()

# Bottom-right panel: close price with the area between ISA_9 and ISB_26
# shaded green where ISA_9 is above ISB_26 and red where it is below.
axs[1, 1].plot(ichi['Close'], label='Close')
# Bottom-left panel: close price scatter coloured by the squeeze signal.
axs[1, 0].scatter(df.index, df.Close, c=squeeze['signal'])
axs[1, 1].fill_between(ichi.index, ichi['ISA_9'], ichi['ISB_26'],
                       where=ichi['ISA_9'] > ichi['ISB_26'], facecolor='green', alpha=0.5)
axs[1, 1].fill_between(ichi.index, ichi['ISA_9'], ichi['ISB_26'],
                       where=ichi['ISA_9'] < ichi['ISB_26'], facecolor='red', alpha=0.5)
axs[1, 1].legend()
And I am quite satisfied with the result:
My subplot
However, I wanted to add one more plot at axs[0,1], for which I used the trendln package to plot support and resistance lines:
# Open a separate 20x10 figure; this is independent of the 2x2 subplot grid.
plt.figure(figsize = (20,10))
# Plot support/resistance for the last 100 closing prices. No axes object is
# passed in, so the drawing target is chosen by trendln itself.
f= trendln.plot_support_resistance(hist[-100:].Close,accuracy = 10)
plt.show()
plt.clf() #clear figure
Support resistance plot
Is there any way to incorporate the support/resistance plot into my initial figure at axs[0,1]?
Unfortunately, the trendln source code calls plt.plot directly for everything, so this is not easy to do without changing the source code yourself. You can find where the source is located like this:
>>> import trendln
>>> trendln.__file__
'/home/username/.local/lib/python3.8/site-packages/trendln/__init__.py'
>>>
Then you can directly modify the plot_support_resistance function to the following. I basically make it take an axs and plot there instead of plt; there were also a few other changes to be made:
def plot_support_resistance(axs, hist, xformatter=None, numbest=2, fromwindows=True,
                            pctbound=0.1, extmethod=METHOD_NUMDIFF, method=METHOD_NSQUREDLOGN,
                            window=125, errpct=0.005, hough_scale=0.01, hough_prob_iter=10,
                            sortError=False, accuracy=1):
    """Draw prices with support/resistance trend lines on an existing Axes.

    Unlike the stock trendln version, nothing is drawn through ``pyplot``:
    every artist is added to ``axs`` so the plot can live in a subplot grid.

    Args:
        axs: Matplotlib Axes to draw on.
        hist: Either a single price series, or a ``(low, high)`` pair. When
            one entry of the pair is ``None`` only the other side is plotted.
        xformatter: Optional tick formatter installed on the x axis.
        numbest: Maximum number of trend lines drawn per trend list.
        fromwindows: If true, draw the per-window trend lines; otherwise draw
            the overall trend lines.
        pctbound: A trend line stops being extended once it leaves the price
            range padded by this fraction of ``max - min``.
        extmethod, method, window, errpct, hough_scale, hough_prob_iter,
        sortError, accuracy: Forwarded unchanged to
            ``calc_support_resistance``.

    Returns:
        None. The plot is a side effect on ``axs``.
    """
    import matplotlib.ticker as ticker

    ret = calc_support_resistance(hist, extmethod, method, window, errpct,
                                  hough_scale, hough_prob_iter, sortError, accuracy)
    if len(ret) == 2:
        # Two result tuples: one for the minima (support), one for the maxima
        # (resistance).
        minimaIdxs, pmin, mintrend, minwindows = ret[0]
        maximaIdxs, pmax, maxtrend, maxwindows = ret[1]
        if (isinstance(hist, tuple) and len(hist) == 2
                and check_num_alike(hist[0]) and check_num_alike(hist[1])):
            low, high = hist
            len_h = len(low)
            min_h, max_h = min(min(low), min(high)), max(max(low), max(high))
            disp = [(low, minimaIdxs, pmin, 'yo', 'Avg. Support', 'y--'),
                    (high, maximaIdxs, pmax, 'bo', 'Avg. Resistance', 'b--')]
            dispwin = [(low, minwindows, 'Support', 'g--'), (high, maxwindows, 'Resistance', 'r--')]
            disptrend = [(low, mintrend, 'Support', 'g--'), (high, maxtrend, 'Resistance', 'r--')]
            axs.plot(range(len_h), low, 'k--', label='Low Price')
            axs.plot(range(len_h), high, 'm--', label='High Price')
        else:
            len_h = len(hist)
            min_h, max_h = min(hist), max(hist)
            disp = [(hist, minimaIdxs, pmin, 'yo', 'Avg. Support', 'y--'),
                    (hist, maximaIdxs, pmax, 'bo', 'Avg. Resistance', 'b--')]
            dispwin = [(hist, minwindows, 'Support', 'g--'), (hist, maxwindows, 'Resistance', 'r--')]
            disptrend = [(hist, mintrend, 'Support', 'g--'), (hist, maxtrend, 'Resistance', 'r--')]
            axs.plot(range(len_h), hist, 'k--', label='Close Price')
    else:
        # One-sided input: a single result tuple describes whichever entry of
        # the (low, high) pair is not None.
        only_high = hist[0] is None
        series = hist[1] if only_high else hist[0]
        len_h = len(series)
        min_h, max_h = min(series), max(series)
        if only_high:
            maximaIdxs, pmax, maxtrend, maxwindows = ret
            disp = [(series, maximaIdxs, pmax, 'bo', 'Avg. Resistance', 'b--')]
            dispwin = [(series, maxwindows, 'Resistance', 'r--')]
            disptrend = [(series, maxtrend, 'Resistance', 'r--')]
        else:
            minimaIdxs, pmin, mintrend, minwindows = ret
            disp = [(series, minimaIdxs, pmin, 'yo', 'Avg. Support', 'y--')]
            dispwin = [(series, minwindows, 'Support', 'g--')]
            disptrend = [(series, mintrend, 'Support', 'g--')]
        axs.plot(range(len_h), series, 'k--',
                 label=('High' if only_high else 'Low') + ' Price')

    # Extrema markers plus the average line pm[0] * x + pm[1] across the
    # whole x range.
    for h, idxs, pm, clrp, lbl, clrl in disp:
        axs.plot(idxs, [h[x] for x in idxs], clrp)
        axs.plot([0, len_h - 1], [pm[1], pm[0] * (len_h - 1) + pm[1]], clrl, label=lbl)

    def add_trend(h, trend, lbl, clr, bFirst):
        """Draw up to ``numbest`` lines of ``trend``; label only the first one."""
        for ln in trend[:numbest]:
            slope, intercept = ln[1][0], ln[1][1]
            # Extend the line to the right of its last defining point until
            # the price crosses it or it leaves the padded price range.
            maxx = ln[0][-1] + 1
            while maxx < len_h:
                ypred = slope * maxx + intercept
                if ((h[maxx] > ypred and h[maxx - 1] < ypred)
                        or (h[maxx] < ypred and h[maxx - 1] > ypred)
                        or ypred > max_h + (max_h - min_h) * pctbound
                        or ypred < min_h - (max_h - min_h) * pctbound):
                    break
                maxx += 1
            x_vals = np.array((ln[0][0], maxx))
            y_vals = slope * x_vals + intercept
            if bFirst:
                axs.plot([ln[0][0], maxx], y_vals, clr, label=lbl)
                bFirst = False
            else:
                axs.plot([ln[0][0], maxx], y_vals, clr)
        return bFirst

    if fromwindows:
        for h, windows, lbl, clr in dispwin:
            bFirst = True
            for trend in windows:
                bFirst = add_trend(h, trend, lbl, clr, bFirst)
    else:
        for h, trend, lbl, clr in disptrend:
            add_trend(h, trend, lbl, clr, True)

    axs.legend()
    axs.xaxis.set_major_locator(ticker.MaxNLocator(6))
    # Honour the caller's formatter; it was accepted but never applied.
    if xformatter is not None:
        axs.xaxis.set_major_formatter(xformatter)
    return None
Now that that is taken care of, you can pass one of the axs to draw things:
import matplotlib.pyplot as plt
import trendln
import yfinance as yf
# 2x2 grid; the patched trendln function below draws into one of these axes.
fig, axs = plt.subplots(2, 2, figsize = (20,10))
# Placeholder line in the top-right panel, just to show the other axes still work.
axs[0, 1].plot([0, 1], [3, 4])
tick = yf.Ticker('^GSPC') # S&P500
hist = tick.history(period="max", rounding=True)
# The axes is now the first argument: support/resistance for the last 1000
# closes is drawn into the top-left panel. Pass axs[0, 1] instead to target
# the top-right one.
f = trendln.plot_support_resistance(axs[0, 0], hist[-1000:].Close, accuracy = 2)
plt.show()
I get:
I hope this helps. You probably were looking for another option, but because of the hard coding they use, it's not easy. I also tried copying the axs that plt draws to instead of modifying the source code, but it didn't work.
I wrote a script with annotations that get displayed upon hovering over data points based on some of the answers to similar questions by the user ImportanceOfBeingErnest. One of the changes I've made is that I only change the text and position of a single annotation and use it for more than one data set. This seems to cause the problem that the annotation only gets displayed for the last data set (or plotter, as I called them in my script) in the list of all data sets/ plotters.
How can I get the annotation to display for all data points of both scatter plots in my script? Do I have to make a new annotation for each data set and update them separately?
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from scipy.stats import linregress
# All data in pA*s
# Outer key: sample name (KAL1..KAL6); inner key: substance name.
# NOTE(review): KAL2 has the exact same value for 'Decan' and '1-Nonannitril'
# (16264.896) -- looks like a copy/paste slip; confirm against the raw data.
gc_data = {
'KAL1':{'Toluol':400754.594,'1-Octen':53695.014,'Decan':6443.483,'1-Nonannitril':48984.504},
'KAL2':{'Toluol':417583.343,'1-Octen':29755.3,'Decan':16264.896,'1-Nonannitril':16264.896},
'KAL3':{'Toluol':442378.88,'1-Octen':18501.12,'Decan':19226.245,'1-Nonannitril':16200.611},
'KAL4':{'Toluol':389679.589,'1-Octen':13381.415,'Decan':68549.002,'1-Nonannitril':11642.123},
'KAL5':{'Toluol':423982.487,'1-Octen':6263.286,'Decan':53580.809,'1-Nonannitril':4946.271},
'KAL6':{'Toluol':351754.329,'1-Octen':8153.602,'Decan':105408.823,'1-Nonannitril':7066.718}
}
# All data in mg
mass_data = {
'KAL1':{'1-Octen':149.3,'Decan':17.8,'1-Nonannitril':154.7},
'KAL2':{'1-Octen':80.6,'Decan':43.7,'1-Nonannitril':82.8},
'KAL3':{'1-Octen':50.4,'Decan':51.8,'1-Nonannitril':51.5},
'KAL4':{'1-Octen':40.9,'Decan':206.9,'1-Nonannitril':40.8},
'KAL5':{'1-Octen':18.0,'Decan':155.2,'1-Nonannitril':16.4},
'KAL6':{'1-Octen':23.4,'Decan':301.4,'1-Nonannitril':23.6},
}
def update_annot(line, annot, ind):
    """Move ``annot`` onto the hovered point of ``line`` and show its x/y.

    Args:
        line: The artist under the cursor; a scatter ``PathCollection`` or a
            ``Line2D``.
        annot: The annotation to reposition and relabel.
        ind: Details dict returned by ``artist.contains(event)``; the first
            entry of ``ind["ind"]`` selects the data point.

    Raises:
        TypeError: If ``line`` is neither a ``PathCollection`` nor a
            ``Line2D``. (The original called ``quit()``, which would tear
            down the interpreter from inside a GUI callback.)
    """
    if isinstance(line, matplotlib.collections.PathCollection):
        x, y = line.get_offsets().transpose()
    elif isinstance(line, matplotlib.lines.Line2D):
        x, y = line.get_data()
    else:
        raise TypeError('No getter of x,y Data for this type of plotter.')
    idx = ind["ind"][0]
    annot.xy = (x[idx], y[idx])
    annot.set_text("x = {}\ny= {}".format(x[idx], y[idx]))
def hover(event, fig, annot):
    """Show ``annot`` on the data point under the mouse, or hide it.

    Every collection of the figure's first axes is tested. The first one that
    contains the event wins and the search stops there. The annotation is
    hidden only when *no* collection contains the event. The original hid it
    again as soon as any later collection missed, so only the last data set
    could ever show the annotation.

    Args:
        event: The Matplotlib mouse event.
        fig: The figure whose first axes holds the scatter collections.
        annot: The shared annotation that is moved between data points.
    """
    if event.inaxes not in fig.axes:
        return
    for plotter in fig.axes[0].collections:
        cont, ind = plotter.contains(event)
        if cont:
            update_annot(plotter, annot, ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
            break
    else:
        # No collection was hit: hide the annotation, redrawing only if it
        # was actually visible.
        if annot.get_visible():
            annot.set_visible(False)
            fig.canvas.draw_idle()
def get_data(substance,standard):
    """Return ``(A, m)``: per-sample substance/standard ratios.

    ``m`` holds the ratios from ``mass_data`` and ``A`` the ratios from
    ``gc_data``. Both lists follow the sample order of ``mass_data``.
    """
    def ratios(table):
        values = []
        for sample in mass_data:
            row = table[sample]
            values.append(row[substance] / row[standard])
        return values

    m = ratios(mass_data)
    A = ratios(gc_data)
    return A, m
def plot(substance, standard, save=None):
    """Plot the mass ratio against the peak-area ratio with a linear fit.

    Args:
        substance: Substance name used as the numerator of both ratios.
        standard: Substance name used as the denominator of both ratios.
        save: If truthy, write the figure to ``<substance>.svg``. Otherwise
            show it interactively with a hover annotation.
    """
    A, m = get_data(substance, standard)
    # The second sample is kept out of the regression and drawn in red.
    A_baddata = A.pop(1)
    m_baddata = m.pop(1)

    # Linear regression. linregress returns the correlation coefficient r
    # (not r squared) and the standard error of the slope.
    a, b, rval, pval, stderr = linregress(A, m)

    # Plotting
    fig, ax = plt.subplots(figsize=(6, 6))

    # Data inputs
    ax.scatter(A, m, marker='o')  # Measured data
    ax.scatter(A_baddata, m_baddata, marker='o', c='r')
    # Remember the limits autoscaled to the scatter data, then restore them
    # after drawing the deliberately over-long regression line.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    line_x = np.array([-2 * max(A), 2 * max(A)])
    ax.plot(line_x, line_x * a + b)  # graph from regression parameters
    ax.set_ylim(ymin, ymax)
    ax.set_xlim(xmin, xmax)

    # General formatting
    ax.tick_params(axis='both', which='both', labelsize=12, direction='in')
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.set_ylabel(r'$m_{\mathrm{Substanz}}\quad/\quadm_{\mathrm{Standard}}$')
    ax.set_xlabel(r'$A_{\mathrm{Substanz}}\quad/\quadA_{\mathrm{Standard}}$')

    # Description Box
    textstr = '{}{}\n'.format('Substanz: ', substance)
    textstr += '{}{}\n'.format('Standard: ', standard)
    textstr += '{}{:.5f}\n'.format('a = ', a)
    textstr += '{}{:.5f}\n'.format('b = ', b)
    # The label says R^2, so square the correlation coefficient.
    textstr += '{}{:.5f}\n'.format(r'$R^{2}$ = ', rval ** 2)
    textstr += '{}{:.5f}\n'.format(r'$p$ = ', pval)
    textstr += '{}{:.5f}'.format(r'$\bar X = $', stderr)
    props = dict(boxstyle='round', fc='#96FBFF', ec='#3CF8FF', alpha=0.5)
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props)

    if save:
        plt.savefig(substance + '.svg', bbox_inches='tight', transparent=True)
    else:
        # Hovering annotation: one hidden annotation that the hover callback
        # moves onto whichever data point is under the cursor.
        annot = ax.annotate("", xy=(0, 0), xytext=(1, 1), textcoords="offset points",
                            bbox=dict(boxstyle="round", fc="w", alpha=0.4),
                            arrowprops=dict(arrowstyle="->"))
        annot.set_visible(False)
        fig.canvas.mpl_connect("motion_notify_event", lambda event: hover(event, fig, annot))
        plt.show()
# save=0 is falsy, so the figure is shown interactively instead of being saved.
plot('1-Nonannitril','Decan',0)
The main problem is that the hover event gets triggered by the line instead of by the nearby scatter dots. So, this line should be excluded when connecting the motion_notify_event.
Since ImportanceOfBeingErnest's and others posts about how to create annotations, they developed the mplcursors library to strongly simplify the creation of this kind of annotations.
With mplcursors you can simply call mplcursors.cursor(ax.collections, hover=True), and an annotation with the x and y positions is created automatically. But you can easily go much further. The example below also shows how to display the artist's label (here the 'artist' is one collection of scatter dots), and how to use the artist's color for the background of the annotation. In addition, an extra attribute holding a list of names is added to each artist, and these names are then added to the annotation.
The code leaves out some of the elements that aren't relevant for the annotations, such as the large text.
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from scipy.stats import linregress
import mplcursors
from matplotlib.colors import to_rgb
# All data in pA*s
# Outer key: sample name (KAL1..KAL6); inner key: substance name.
# NOTE(review): KAL2 has the exact same value for 'Decan' and '1-Nonannitril'
# (16264.896) -- looks like a copy/paste slip; confirm against the raw data.
gc_data = {
'KAL1': {'Toluol': 400754.594, '1-Octen': 53695.014, 'Decan': 6443.483, '1-Nonannitril': 48984.504},
'KAL2': {'Toluol': 417583.343, '1-Octen': 29755.3, 'Decan': 16264.896, '1-Nonannitril': 16264.896},
'KAL3': {'Toluol': 442378.88, '1-Octen': 18501.12, 'Decan': 19226.245, '1-Nonannitril': 16200.611},
'KAL4': {'Toluol': 389679.589, '1-Octen': 13381.415, 'Decan': 68549.002, '1-Nonannitril': 11642.123},
'KAL5': {'Toluol': 423982.487, '1-Octen': 6263.286, 'Decan': 53580.809, '1-Nonannitril': 4946.271},
'KAL6': {'Toluol': 351754.329, '1-Octen': 8153.602, 'Decan': 105408.823, '1-Nonannitril': 7066.718}
}
# All data in mg
mass_data = {
'KAL1': {'1-Octen': 149.3, 'Decan': 17.8, '1-Nonannitril': 154.7},
'KAL2': {'1-Octen': 80.6, 'Decan': 43.7, '1-Nonannitril': 82.8},
'KAL3': {'1-Octen': 50.4, 'Decan': 51.8, '1-Nonannitril': 51.5},
'KAL4': {'1-Octen': 40.9, 'Decan': 206.9, '1-Nonannitril': 40.8},
'KAL5': {'1-Octen': 18.0, 'Decan': 155.2, '1-Nonannitril': 16.4},
'KAL6': {'1-Octen': 23.4, 'Decan': 301.4, '1-Nonannitril': 23.6},
}
def update_annot(sel):
    """Fill in the mplcursors annotation for the selected scatter point.

    The text shows the artist's label, the x/y position to two decimals and
    the entry of the artist's ``data_names`` list for the selected point. The
    annotation background is a lightened version of the artist's face colour.
    """
    artist = sel.artist
    x, y = sel.target
    text_lines = [
        artist.get_label(),
        f'x: {x:.2f}',
        f'y: {y:.2f}',
        # append the name
        artist.data_names[sel.target.index],
    ]
    sel.annotation.set_text('\n'.join(text_lines))
    # Blend each colour channel two thirds of the way towards white.
    r, g, b = to_rgb(artist.get_facecolor())
    lightened = ((r + 2) / 3, (g + 2) / 3, (b + 2) / 3)
    sel.annotation.get_bbox_patch().set(fc=lightened, alpha=0.7)
def get_data(substance, standard):
    """Return ``(A, m)``: per-sample substance/standard ratios.

    ``m`` holds the ratios from ``mass_data`` and ``A`` the ratios from
    ``gc_data``. Both lists follow the sample order of ``mass_data``.
    """
    def ratios(table):
        values = []
        for sample in mass_data:
            row = table[sample]
            values.append(row[substance] / row[standard])
        return values

    m = ratios(mass_data)
    A = ratios(gc_data)
    return A, m
def plot(substance, standard, save=None):
    """Scatter the ratio data with a regression line and hover annotations.

    The second sample is excluded from the regression and drawn in red as
    'Bad data'. Each scatter collection gets a ``data_names`` attribute with
    its sample names, which ``update_annot`` reads when a point is hovered.
    ``save`` is accepted but not used by this version.
    """
    global measured_names, baddata_names
    A, m = get_data(substance, standard)
    measured_names = list(mass_data)
    A_baddata = A.pop(1)
    m_baddata = m.pop(1)
    baddata_names = [measured_names.pop(1)]

    # Linear regression
    a, b, rval, pval, stdev = linregress(A, m)

    # Plotting
    fig, ax = plt.subplots(figsize=(6, 6))

    # Data inputs
    scat1 = ax.scatter(A, m, marker='o', label='Measured data')
    scat1.data_names = measured_names
    scat2 = ax.scatter(A_baddata, m_baddata, marker='o', c='r', label='Bad data')
    scat2.data_names = baddata_names

    # Keep the limits autoscaled to the scatter data; the regression line is
    # drawn far beyond them and would otherwise rescale the axes.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    reach = 2 * max(A)
    ax.plot(np.array([-reach, reach]), np.array([-reach, reach]) * a + b)
    ax.set_ylim(ymin, ymax)
    ax.set_xlim(xmin, xmax)

    ax.set_ylabel(r'$m_{\mathrm{Substanz}}\quad/\quadm_{\mathrm{Standard}}$')
    ax.set_xlabel(r'$A_{\mathrm{Substanz}}\quad/\quadA_{\mathrm{Standard}}$')

    # Hovering annotation, restricted to the two scatter collections so the
    # regression line never triggers it.
    cursor = mplcursors.cursor([scat1, scat2], hover=True)
    cursor.connect("add", update_annot)

    plt.show()
# The third argument (save) is not used by this version of plot().
plot('1-Nonannitril', 'Decan', 0)
I am trying to animate a scatter and bivariate gaussian distribution from a set of xy coordinates. I'll record the specific code that calls the scatter and distribution first and then how I measure the distribution afterwards.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as sts
import matplotlib.animation as animation
''' Below is a section of the script that generates the scatter and contour '''
fig, ax = plt.subplots(figsize = (10,4))


def plotmvs(df, xlim=None, ylim=None, fig=fig, ax=ax):
    """Scatter the two groups of ``df`` and draw their density difference.

    For each of the first two groups of ``df.groupby('group')`` the X/Y points
    are plotted (red, then blue) and a summed density is computed with
    ``mvpdfs``. The difference of the two densities is rescaled to [0, 1] and
    drawn as a filled contour plot. Returns ``(fig, ax)``.
    """
    if xlim is None:
        xlim = datalimits(df['X'])
    if ylim is None:
        ylim = datalimits(df['Y'])

    group_pdfs = []
    for (group, gdf), color in zip(df.groupby('group'), ('red', 'blue')):
        ax.plot(*gdf[['X','Y']].values.T, '.', c=color, alpha = 0.5)
        X, Y, group_pdf = mvpdfs(gdf['X'].values, gdf['Y'].values, xlim=xlim, ylim=ylim)
        group_pdfs.append(group_pdf)

    difference = group_pdfs[0] - group_pdfs[1]
    # Shift to start at 0, then scale so the maximum is 1.
    shifted = difference - difference.min()
    normPDF = shifted / shifted.max()
    ax.contourf(X, Y, normPDF, levels=100, cmap='jet')
    return fig, ax
n = 10
time = [1]
d = ({
'A1_Y' : [10,20,15,20,25,40,50,60,61,65],
'A1_X' : [15,10,15,20,25,25,30,40,60,61],
'A2_Y' : [10,13,17,10,20,24,29,30,33,40],
'A2_X' : [10,13,15,17,18,19,20,21,26,30],
'A3_Y' : [11,12,15,17,19,20,22,25,27,30],
'A3_X' : [15,18,20,21,22,28,30,32,35,40],
'A4_Y' : [15,20,15,20,25,40,50,60,61,65],
'A4_X' : [16,20,15,30,45,30,40,10,11,15],
'B1_Y' : [18,10,11,13,18,10,30,40,31,45],
'B1_X' : [17,20,15,10,25,20,10,12,14,25],
'B2_Y' : [13,10,14,20,21,12,30,20,11,35],
'B2_X' : [12,20,16,22,15,20,10,20,16,15],
'B3_Y' : [15,20,15,20,25,10,20,10,15,25],
'B3_X' : [18,15,13,20,21,10,20,10,11,15],
'B4_Y' : [19,12,15,18,14,19,13,12,11,18],
'B4_X' : [20,10,12,18,17,15,13,14,19,13],
})
# Turn each 'A1_X'-style key of d into a (time, group letter, numeric id, axis)
# tuple paired with one value. For every entry t of `time` at position i the
# i-th list element is taken, so with time = [1] only the first element of
# each list is used.
tuples = [((t, k.split('_')[0][0], int(k.split('_')[0][1:]), k.split('_')[1]), v[i]) for k,v in d.items() for i,t in enumerate(time)]
# Series indexed by those 4-tuples; unstack(-1) turns the last level (X/Y) into columns.
df = pd.Series(dict(tuples)).unstack(-1)
df.index.names = ['time', 'group', 'id']
# NOTE: this loop rebinds the name `time` (previously the list of time steps).
for time,tdf in df.groupby('time'):
plotmvs(tdf)
'''MY ATTEMPT AT ANIMATING THE PLOT '''
# Animation attempt from the question (not working as written).
# NOTE(review): tdf is the DataFrame left over from the groupby loop above,
# not a scatter artist -- set_offsets looks misplaced here; confirm the
# intended target.
# NOTE(review): n is the integer 10 defined earlier, so n[i,:,0,:] cannot be
# indexed like an array.
# NOTE(review): cfs, X and Y are only assigned inside plotmvs and are not
# visible at this scope.
def animate(i) :
tdf.set_offsets([[tdf.iloc[0:,1][0+i][0], tdf.iloc[0:,0][0+i][0]], [tdf.iloc[0:,1][0+i][1], tdf.iloc[0:,0][0+i][1]], [tdf.iloc[0:,1][0+i][2], tdf.iloc[0:,0][0+i][2]], [tdf.iloc[0:,1][0+i][3], tdf.iloc[0:,0][0+i][3]], [tdf.iloc[0:,1][0+i][4], tdf.iloc[0:,0][0+i][4]]])
normPDF = n[i,:,0,:].T
cfs.set_data(X, Y, normPDF)
# Calls animate for i = 0..9, 10 ms apart, redrawing the whole figure (no blitting).
ani = animation.FuncAnimation(fig, animate, np.arange(0,10),# init_func = init,
interval = 10, blit = False)
Here is the full working code showing how the distribution is generated and plotted for a single frame:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as sts
import matplotlib.animation as animation
def datalimits(*data, pad=.15):
    """Return ``(low, high)`` spanning all arrays in ``data``, padded outward.

    The padding on each side is ``pad`` times the overall range (max - min).
    """
    lows = [values.min() for values in data]
    highs = [values.max() for values in data]
    lowest, highest = min(lows), max(highs)
    margin = pad * (highest - lowest)
    return lowest - margin, highest + margin
def rot(theta):
    """Return the 2x2 counter-clockwise rotation matrix for ``theta`` degrees."""
    theta = np.deg2rad(theta)
    return np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])


def getcov(radius=1, scale=1, theta=0):
    """Return a 2x2 covariance matrix: an axis-aligned ellipse rotated by ``theta``.

    The unrotated covariance is ``diag(radius*(scale+1), radius/(scale+1))``.
    It is rotated with ``R @ cov @ R.T``. (The matrix-multiplication operators
    had been garbled into ``#`` comments, so only ``R`` was returned.)
    """
    cov = np.array([
        [radius*(scale + 1), 0],
        [0, radius/(scale + 1)]
    ])
    r = rot(theta)
    return r @ cov @ r.T


def mvpdf(x, y, xlim, ylim, radius=1, velocity=0, scale=0, theta=0):
    """Evaluate one bivariate normal density on a regular grid.

    Args:
        x, y: Position of the point the density belongs to.
        xlim, ylim: Arguments unpacked into ``np.linspace`` for each axis,
            e.g. ``(start, stop)`` or ``(start, stop, num)``.
        radius, scale, theta: Passed to ``getcov`` to build the covariance.
        velocity: The mean is shifted from ``(x, y)`` by half of this amount
            along the direction given by ``theta``.

    Returns:
        ``(X, Y, PDF)``: the meshgrid coordinates and the density on them.
    """
    X, Y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
    XY = np.stack([X, Y], 2)
    # Mean = point position plus the rotated (velocity/2, 0) offset. The `@`
    # here had also been garbled into `#`.
    x, y = rot(theta) @ (velocity/2, 0) + (x, y)
    cov = getcov(radius=radius, scale=scale, theta=theta)
    PDF = sts.multivariate_normal([x, y], cov).pdf(XY)
    return X, Y, PDF
def mvpdfs(xs, ys, xlim, ylim, radius=None, velocity=None, scale=None, theta=None):
    """Sum the densities of several points on one shared grid.

    Args:
        xs, ys: Point coordinates, iterated pairwise.
        xlim, ylim: Grid limits, unpacked into ``np.linspace`` for each axis.
        radius, velocity, scale, theta: Optional per-point shape parameters.
            Each may be a sequence indexed like ``xs`` or a scalar applied to
            every point. ``None`` uses 1 for radius and 0 for the others.
            (These parameters used to be accepted but ignored.)

    Returns:
        ``(X, Y, PDF)``: the meshgrid coordinates and the summed density. With
        no points the density is all zeros instead of raising.
    """
    def pick(values, i, default):
        # None -> default; scalar -> same value for every point; else index.
        if values is None:
            return default
        return values if np.ndim(values) == 0 else values[i]

    X, Y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
    PDFs = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        kwargs = {
            'radius': pick(radius, i, 1),
            'velocity': pick(velocity, i, 0),
            'scale': pick(scale, i, 0),
            'theta': pick(theta, i, 0),
            'xlim': xlim,
            'ylim': ylim
        }
        _, _, PDF = mvpdf(x, y, **kwargs)
        PDFs.append(PDF)
    if not PDFs:
        return X, Y, np.zeros_like(X, dtype=float)
    return X, Y, np.sum(PDFs, axis=0)
fig, ax = plt.subplots(figsize = (10,4))


def plotmvs(df, xlim=None, ylim=None, fig=fig, ax=ax):
    """Scatter the two groups of ``df`` and draw their density difference.

    For each of the first two groups of ``df.groupby('group')`` the X/Y points
    are plotted (red, then blue) and a summed density is computed with
    ``mvpdfs``. The difference of the two densities is rescaled to [0, 1] and
    drawn as a filled contour plot. Returns ``(fig, ax)``.
    """
    if xlim is None:
        xlim = datalimits(df['X'])
    if ylim is None:
        ylim = datalimits(df['Y'])

    group_pdfs = []
    for (group, gdf), color in zip(df.groupby('group'), ('red', 'blue')):
        # Animate this scatter
        ax.plot(*gdf[['X','Y']].values.T, '.', c=color, alpha = 0.5)
        X, Y, group_pdf = mvpdfs(gdf['X'].values, gdf['Y'].values, xlim=xlim, ylim=ylim)
        group_pdfs.append(group_pdf)

    difference = group_pdfs[0] - group_pdfs[1]
    # Shift to start at 0, then scale so the maximum is 1.
    shifted = difference - difference.min()
    normPDF = shifted / shifted.max()
    # Animate this contour
    ax.contourf(X, Y, normPDF, levels=100, cmap='jet')
    return fig, ax
n = 10
time = [1]
d = ({
'A1_Y' : [10,20,15,20,25,40,50,60,61,65],
'A1_X' : [15,10,15,20,25,25,30,40,60,61],
'A2_Y' : [10,13,17,10,20,24,29,30,33,40],
'A2_X' : [10,13,15,17,18,19,20,21,26,30],
'A3_Y' : [11,12,15,17,19,20,22,25,27,30],
'A3_X' : [15,18,20,21,22,28,30,32,35,40],
'A4_Y' : [15,20,15,20,25,40,50,60,61,65],
'A4_X' : [16,20,15,30,45,30,40,10,11,15],
'B1_Y' : [18,10,11,13,18,10,30,40,31,45],
'B1_X' : [17,20,15,10,25,20,10,12,14,25],
'B2_Y' : [13,10,14,20,21,12,30,20,11,35],
'B2_X' : [12,20,16,22,15,20,10,20,16,15],
'B3_Y' : [15,20,15,20,25,10,20,10,15,25],
'B3_X' : [18,15,13,20,21,10,20,10,11,15],
'B4_Y' : [19,12,15,18,14,19,13,12,11,18],
'B4_X' : [20,10,12,18,17,15,13,14,19,13],
})
# Turn each 'A1_X'-style key of d into a (time, group letter, numeric id, axis)
# tuple paired with one value. For every entry t of `time` at position i the
# i-th list element is taken, so with time = [1] only the first element of
# each list is used.
tuples = [((t, k.split('_')[0][0], int(k.split('_')[0][1:]), k.split('_')[1]), v[i]) for k,v in d.items() for i,t in enumerate(time)]
# Series indexed by those 4-tuples; unstack(-1) turns the last level (X/Y) into columns.
df = pd.Series(dict(tuples)).unstack(-1)
df.index.names = ['time', 'group', 'id']
# NOTE: this loop rebinds the name `time` (previously the list of time steps).
for time,tdf in df.groupby('time'):
plotmvs(tdf)
I essentially want to animate this code by iterating over each row of xy coordinates.
Here's a very quick and dirty modification of the OP's code, fixing the scatter animation and adding (a form of) contour animation.
Basically, you start by creating the artists for your animation (in this case Line2D objects, as returned by plot()). Subsequently, you create an update function (and, optionally, an initialization function). In that function, you update the existing artists. I think the example in the matplotlib docs explains it all.
In this case, I modified the OP's plotmvs function to be used as the update function (instead of the OP's proposed animate function).
The QuadContourSet returned by contourf (i.e. your cfs) cannot be used as an artist in itself, but you can make it work using cfs.collections (props to this SO answer). However, you still need to create a new contour plot and remove the old one, instead of just updating the contour data. Personally I would prefer a lower level approach: try to get the contour-data without calling contourf, then initialize and update the contour lines just like you do for the scatter.
Nevertheless, the approach above is implemented in the OP's code below (just copy, paste, and run):
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as sts
from matplotlib.animation import FuncAnimation
# quick and dirty override of datalimits(), to get a fixed contour-plot size
DATA_LIMITS = [0, 70]


def datalimits(*data, pad=.15):
    """Return the fixed ``DATA_LIMITS`` list, ignoring ``data`` and ``pad``.

    This stands in for the data-driven limits so every animation frame uses
    the same contour-plot extent. The shared list object itself is returned.
    """
    return DATA_LIMITS
def rot(theta):
    """Return the 2x2 counter-clockwise rotation matrix for ``theta`` degrees."""
    theta = np.deg2rad(theta)
    return np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])


def getcov(radius=1, scale=1, theta=0):
    """Return a 2x2 covariance matrix: an axis-aligned ellipse rotated by ``theta``.

    The unrotated covariance is ``diag(radius*(scale+1), radius/(scale+1))``.
    It is rotated with ``R @ cov @ R.T``. (The matrix-multiplication operators
    had been garbled into ``#`` comments, so only ``R`` was returned.)
    """
    cov = np.array([
        [radius*(scale + 1), 0],
        [0, radius/(scale + 1)]
    ])
    r = rot(theta)
    return r @ cov @ r.T


def mvpdf(x, y, xlim, ylim, radius=1, velocity=0, scale=0, theta=0):
    """Evaluate one bivariate normal density on a regular grid.

    Args:
        x, y: Position of the point the density belongs to.
        xlim, ylim: Arguments unpacked into ``np.linspace`` for each axis,
            e.g. ``(start, stop)`` or ``(start, stop, num)``.
        radius, scale, theta: Passed to ``getcov`` to build the covariance.
        velocity: The mean is shifted from ``(x, y)`` by half of this amount
            along the direction given by ``theta``.

    Returns:
        ``(X, Y, PDF)``: the meshgrid coordinates and the density on them.
    """
    X, Y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
    XY = np.stack([X, Y], 2)
    # Mean = point position plus the rotated (velocity/2, 0) offset. The `@`
    # here had also been garbled into `#`.
    x, y = rot(theta) @ (velocity/2, 0) + (x, y)
    cov = getcov(radius=radius, scale=scale, theta=theta)
    PDF = sts.multivariate_normal([x, y], cov).pdf(XY)
    return X, Y, PDF
def mvpdfs(xs, ys, xlim, ylim, radius=None, velocity=None, scale=None, theta=None):
    """Sum the densities of several points on one shared grid.

    Args:
        xs, ys: Point coordinates, iterated pairwise.
        xlim, ylim: Grid limits, unpacked into ``np.linspace`` for each axis.
        radius, velocity, scale, theta: Optional per-point shape parameters.
            Each may be a sequence indexed like ``xs`` or a scalar applied to
            every point. ``None`` uses 1 for radius and 0 for the others.
            (These parameters used to be accepted but ignored.)

    Returns:
        ``(X, Y, PDF)``: the meshgrid coordinates and the summed density. With
        no points the density is all zeros instead of raising.
    """
    def pick(values, i, default):
        # None -> default; scalar -> same value for every point; else index.
        if values is None:
            return default
        return values if np.ndim(values) == 0 else values[i]

    X, Y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
    PDFs = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        kwargs = {
            'radius': pick(radius, i, 1),
            'velocity': pick(velocity, i, 0),
            'scale': pick(scale, i, 0),
            'theta': pick(theta, i, 0),
            'xlim': xlim,
            'ylim': ylim
        }
        _, _, PDF = mvpdf(x, y, **kwargs)
        PDFs.append(PDF)
    if not PDFs:
        return X, Y, np.zeros_like(X, dtype=float)
    return X, Y, np.sum(PDFs, axis=0)
fig, ax = plt.subplots(figsize = (10,4))
# Fix both axes to the same limits the contour grid uses, so the view does not
# change between frames.
ax.set_xlim(DATA_LIMITS)
ax.set_ylim(DATA_LIMITS)
# Initialize empty lines for the scatter (increased marker size to make them more visible)
# One Line2D per group; their data is filled in by the update function.
line_a, = ax.plot([], [], '.', c='red', alpha = 0.5, markersize=20, animated=True)
line_b, = ax.plot([], [], '.', c='blue', alpha = 0.5, markersize=20, animated=True)
# Holds the contour set of the current frame; None until the first frame is drawn.
cfs = None
# Modify the plotmvs function so it updates the lines
# (might as well rename the function to "update")
def _contour_artists(contour_set):
    """Return the artists that make up ``contour_set`` as a plain list.

    Older Matplotlib exposes them as ``contour_set.collections``. Where that
    attribute no longer exists, the contour set itself is the artist.
    """
    collections = getattr(contour_set, 'collections', None)
    return [contour_set] if collections is None else list(collections)


def plotmvs(tdf, xlim=None, ylim=None):
    """Animation update callback: redraw the scatter and contour for one frame.

    Args:
        tdf: One item of ``df.groupby('time')``, i.e. a ``(time, frame)``
            pair. Only the data frame is used.
        xlim, ylim: Optional grid limits; default to ``datalimits(...)``.

    Returns:
        List of artists changed in this frame (the new contour artists plus
        the two scatter lines), as needed for blitting.
    """
    global cfs  # as noted: quick and dirty...
    if cfs is not None:
        # Remove the existing contours; the contour data cannot be updated in
        # place, so a fresh contour plot is drawn every frame.
        for artist in _contour_artists(cfs):
            artist.remove()

    # Get the data frame for time t
    df = tdf[1]

    if xlim is None:
        xlim = datalimits(df['X'])
    if ylim is None:
        ylim = datalimits(df['Y'])

    PDFs = []
    for (group, gdf), group_line in zip(df.groupby('group'), (line_a, line_b)):
        # Update the scatter line data
        group_line.set_data(*gdf[['X','Y']].values.T)
        X, Y, PDF = mvpdfs(gdf['X'].values, gdf['Y'].values, xlim=xlim, ylim=ylim)
        PDFs.append(PDF)

    # Difference of the two group densities, rescaled to [0, 1].
    PDF = PDFs[0] - PDFs[1]
    normPDF = PDF - PDF.min()
    normPDF = normPDF / normPDF.max()

    # Plot a new contour
    cfs = ax.contourf(X, Y, normPDF, levels=100, cmap='jet')

    # Return the artists (contour artists first, then the scatter lines)
    return _contour_artists(cfs) + [line_a, line_b]
n = 10
time = range(n) # assuming n represents the length of the time vector...
d = ({
'A1_Y' : [10,20,15,20,25,40,50,60,61,65],
'A1_X' : [15,10,15,20,25,25,30,40,60,61],
'A2_Y' : [10,13,17,10,20,24,29,30,33,40],
'A2_X' : [10,13,15,17,18,19,20,21,26,30],
'A3_Y' : [11,12,15,17,19,20,22,25,27,30],
'A3_X' : [15,18,20,21,22,28,30,32,35,40],
'A4_Y' : [15,20,15,20,25,40,50,60,61,65],
'A4_X' : [16,20,15,30,45,30,40,10,11,15],
'B1_Y' : [18,10,11,13,18,10,30,40,31,45],
'B1_X' : [17,20,15,10,25,20,10,12,14,25],
'B2_Y' : [13,10,14,20,21,12,30,20,11,35],
'B2_X' : [12,20,16,22,15,20,10,20,16,15],
'B3_Y' : [15,20,15,20,25,10,20,10,15,25],
'B3_X' : [18,15,13,20,21,10,20,10,11,15],
'B4_Y' : [19,12,15,18,14,19,13,12,11,18],
'B4_X' : [20,10,12,18,17,15,13,14,19,13],
})
# Turn each 'A1_X'-style key of d into a (time, group letter, numeric id, axis)
# tuple paired with the list element at that time step's position.
tuples = [((t, k.split('_')[0][0], int(k.split('_')[0][1:]), k.split('_')[1]), v[i])
for k,v in d.items() for i,t in enumerate(time)]
# Series indexed by those 4-tuples; unstack(-1) turns the last level (X/Y) into columns.
df = pd.Series(dict(tuples)).unstack(-1)
df.index.names = ['time', 'group', 'id']
# Use the modified plotmvs as the update function, and supply the data frames
# Each frame passed to plotmvs is one (time, sub-frame) pair from the groupby.
interval_ms = 200
delay_ms = 1000
ani = FuncAnimation(fig, plotmvs, frames=df.groupby('time'),
blit=True, interval=interval_ms, repeat_delay=delay_ms)
# Start the animation
plt.show()