Modifying a statsmodels graph - python

I am following the statsmodels documentation here:
https://www.statsmodels.org/stable/vector_ar.html
I get to the part at the middle of the page that says:
irf.plot(orth=False)
which produces the following graph for my data:
I need to modify elements of the graph. For example, I need to apply tight_layout and decrease the y-tick label size so that the labels don't overlap the subplots to their left.
The documentation talks about passing "subplot plotting functions" to the subplot_params argument of irf.plot(). But when I try something like:
irf.plot(subplot_params = {'fontsize': 8, 'figsize' : (100, 100), 'tight_layout': True})
only the fontsize parameter works. I also tried passing these parameters to the plot_params argument, to no avail.
So my question is: how can I access the other parameters of irf.plot, especially the figure size and the y-tick size? I also need to force it to show a grid, as well as every value on the x axis (1, 2, 3, 4, ..., 10).
Is there any way I can create a blank plot the fig, ax = plt.subplots() way and then draw the irf.plot on that figure?

It looks like the function returns a matplotlib Figure.
Try doing this:
fig = irf.plot(orth=False,..)
fig.set_figheight(100)
fig.set_figwidth(100)
fig.tight_layout()  # call after resizing so the layout is computed for the final size
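Since irf.plot hands back an ordinary matplotlib Figure, the other requests in the question (smaller y-tick labels, a grid, and a tick at every step on the x axis) can be applied afterwards by looping over fig.axes. A minimal sketch follows; style_grid_figure is a hypothetical helper name, and a plain 3x3 plt.subplots grid stands in for the IRF figure so the snippet runs without statsmodels.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def style_grid_figure(fig, ytick_size=6, xmax=10, size=(12, 12)):
    """Shrink y tick labels, add a grid and put a tick at every step 0..xmax."""
    for ax in fig.axes:
        ax.tick_params(axis="y", labelsize=ytick_size)
        ax.grid(True)
        ax.set_xticks(range(0, xmax + 1))
    fig.set_size_inches(*size)
    fig.tight_layout()  # after resizing, so the layout uses the final size
    return fig


# Stand-in for fig = irf.plot(orth=False): a 3x3 grid of decaying responses.
fig, axes = plt.subplots(3, 3)
for ax in axes.flat:
    ax.plot(np.arange(11), np.exp(-0.3 * np.arange(11)))
style_grid_figure(fig)
```

With the real model the call would simply be style_grid_figure(irf.plot(orth=False)).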
If I run it with this example, it works:
import numpy as np
import pandas
import statsmodels.api as sm
from statsmodels.tsa.api import VAR
mdata = sm.datasets.macrodata.load_pandas().data
dates = mdata[['year', 'quarter']].astype(int).astype(str)
quarterly = dates["year"] + "Q" + dates["quarter"]
from statsmodels.tsa.base.datetools import dates_from_str
quarterly = dates_from_str(quarterly)
mdata = mdata[['realgdp','realcons','realinv']]
mdata.index = pandas.DatetimeIndex(quarterly)
data = np.log(mdata).diff().dropna()
model = VAR(data)
results = model.fit(maxlags=15, ic='aic')
irf = results.irf(10)
fig = irf.plot(orth=False)
fig.set_figheight(30)
fig.set_figwidth(30)
fig.tight_layout()  # call after resizing so the layout is computed for the final size
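For the last part of the question (drawing on axes you create yourself), one option is to skip irf.plot and plot the raw responses. This sketch assumes irf.irfs is an array of shape (periods + 1, k, k) whose [:, i, j] slice is the response of variable i to a shock in variable j, which is how I understand statsmodels lays it out; check it against your version. Confidence bands are not drawn. plot_irf_grid is a hypothetical helper, and the stand-in array is made up: the powers of an invented 3x3 VAR(1) coefficient matrix A (for a VAR(1) the moving-average coefficients are A^h).

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def plot_irf_grid(irfs, names, figsize=(9, 9)):
    """Draw an impulse-response grid on our own plt.subplots axes.

    Assumes irfs[:, i, j] is the response of names[i] to a shock in names[j].
    """
    periods = irfs.shape[0] - 1
    k = len(names)
    steps = np.arange(periods + 1)
    fig, axes = plt.subplots(k, k, figsize=figsize, sharex=True)
    for i in range(k):
        for j in range(k):
            ax = axes[i, j]
            ax.plot(steps, irfs[:, i, j])
            ax.axhline(0, color="black", linewidth=0.5)
            ax.set_title(f"{names[j]} -> {names[i]}", fontsize=8)
            ax.set_xticks(steps)        # every step on the x axis
            ax.tick_params(labelsize=6)  # small tick labels
            ax.grid(True)
    fig.tight_layout()
    return fig, axes


# Hypothetical stand-in for irf.irfs with 10 periods and 3 variables.
A = np.array([[0.5, 0.1, 0.0], [0.2, 0.3, 0.1], [0.0, 0.1, 0.4]])
fake_irfs = np.stack([np.linalg.matrix_power(A, h) for h in range(11)])
fig, axes = plot_irf_grid(fake_irfs, ["realgdp", "realcons", "realinv"])
```

With the model fitted above, the call would be something like plot_irf_grid(irf.irfs, list(data.columns)).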

Related

Displot dips between whole numbers

I am trying to plot a density curve of vehicle ages with seaborn.
My density curve dips between the whole numbers, even though all my age values are whole numbers.
I can't seem to find anything related to this issue, so I thought I would try my luck here; any input is appreciated.
My fix currently is just using a histogram with a larger bin but would like to get this working with a density plot.
Thanks!
In seaborn.displot you are passing the kind = 'kde' parameter in order to get a continuous curve. However, this parameter triggers a kernel density estimation (KDE), which computes values for all numbers, including non-integer ones. Because your ages sit only on whole numbers, a kernel bandwidth narrower than the 1-year spacing makes the estimate peak at each integer and dip in between.
Instead, you can tune seaborn.histplot to get a continuous step curve with the element and fill parameters (I created a fake dataframe just to draw a plot, since you didn't provide your data):
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
N = 10000
df = pd.DataFrame({'age': np.random.poisson(lam = 4, size = N)})
df['age'] = df['age'] + 1
fig, ax = plt.subplots(1, 2, figsize = (8, 4))
sns.histplot(ax = ax[0], data = df, bins = np.arange(0.5, df['age'].max() + 1, 1))
sns.histplot(ax = ax[1], data = df, bins = np.arange(0.5, df['age'].max() + 1, 1), element = 'step', fill = False)
ax[0].set_xticks(range(1, 14))
ax[1].set_xticks(range(1, 14))
plt.show()
For comparison, here is seaborn.displot on the same dataframe with kind = 'kde':
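To see why the dips appear, and why a wider kernel removes them, here is a minimal hand-rolled Gaussian KDE in NumPy (a sketch, not seaborn's own code). It uses Scott's rule, bandwidth = std × n^(-1/5), which as far as I know is also seaborn's default through scipy. On these integer ages that bandwidth is about 0.3, so the curve dips between whole numbers; a bandwidth of about 1 does not. In seaborn itself the corresponding knob should be bw_adjust (seaborn 0.11+), which multiplies the default bandwidth, e.g. sns.displot(df, x='age', kind='kde', bw_adjust=3). gaussian_kde_1d and count_local_minima are hypothetical helper names.

```python
import numpy as np


def gaussian_kde_1d(samples, grid, bw_factor=None):
    """Evaluate a 1-D Gaussian KDE of samples on grid.

    Kernel standard deviation = bw_factor * std(samples); when bw_factor is
    None, Scott's rule n ** (-1/5) is used.
    """
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    if bw_factor is None:
        bw_factor = n ** (-1 / 5)  # Scott's rule in one dimension
    h = bw_factor * samples.std(ddof=1)
    z = (grid[:, None] - samples[None, :]) / h
    return np.exp(-0.5 * z**2).sum(axis=1) / (n * h * np.sqrt(2 * np.pi))


def count_local_minima(y):
    """Count interior points that are lower than both neighbours."""
    return int(np.sum((y[1:-1] < y[:-2]) & (y[1:-1] < y[2:])))


rng = np.random.default_rng(0)
ages = rng.poisson(lam=4, size=10_000) + 1  # integer ages, like the fake data above
grid = np.linspace(1, 10, 181)              # step of 0.05

default_density = gaussian_kde_1d(ages, grid)              # h is roughly 0.3
wide_density = gaussian_kde_1d(ages, grid, bw_factor=0.5)  # h is roughly 1

print("dips with Scott's rule:", count_local_minima(default_density))
print("dips with a wider kernel:", count_local_minima(wide_density))
```

A wider kernel also smooths the real shape of the distribution, so the histogram-based curve above stays the more faithful picture of integer-valued data.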

Automate matplotlib figure dimension based on the title

I have a set of three variables for which I would like to draw boxplots. The three variables share the same title and have different x labels, and I want all three of them to be created in the same figure.
Look at the following example (with fake data):
import numpy as np
import matplotlib.pyplot as plt
data = dict(var1=np.random.normal(0, 1, 1000), var2=np.random.normal(0, 2, 1000), var3=np.random.normal(1, 2, 1000))
var_title = 'Really really long overlapping title'
fig = plt.figure()
for i in range(len(data)):
    plt.subplot(1, 3, i + 1)
    plt.boxplot(data[list(data.keys())[i]])  # dict keys are not indexable in Python 3
    plt.title(var_title)
plt.show()
This code generates the following figure:
Now, what I want is either to set just one title over the three subplots (since the title is the same) or to have Python automatically resize the figure so that the title fits and can be read.
I am asking because this figure creation is part of a batch process in which the figures are saved automatically and used to generate PDF documents, so I cannot adjust the figure dimensions by hand one at a time.
To create a title for multiple subplots, you can use Figure.suptitle instead of plt.title; its documentation describes how to set the font size and other formatting.
So your code will look like:
import numpy as np
import matplotlib.pyplot as plt
data = dict(var1=np.random.normal(0, 1, 1000),
var2=np.random.normal(0, 2, 1000), var3=np.random.normal(1, 2, 1000))
var_title = 'Really really long overlapping title'
fig = plt.figure()
fig.suptitle(var_title)
for i in range(len(data)):
    plt.subplot(1, 3, i + 1)
    plt.boxplot(data[list(data.keys())[i]])  # dict keys are not indexable in Python 3
plt.show()
This question is also answered here by orbeckst.
Based on the comment by @ImportanceOfBeingErnest, you can do this:
import numpy as np
import matplotlib.pyplot as plt
data = dict(var1=np.random.normal(0, 1, 1000), var2=np.random.normal(0, 2, 1000), var3=np.random.normal(1, 2, 1000))
var_title = 'Really really long overlapping title'
f, ax = plt.subplots(1, 3)
for i in range(len(data)):
    ax[i].boxplot(data[list(data.keys())[i]])
f.suptitle(var_title)
plt.show()
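For a batch job, suptitle can be combined with a layout engine so nothing has to be resized by hand. This sketch assumes matplotlib 3.x: constrained_layout=True lets matplotlib make room for the suptitle and the per-variable x labels, and bbox_inches='tight' in savefig enlarges the saved area if any text still sticks out. boxplot_row is a hypothetical helper name, and the PNG goes to an in-memory buffer instead of a real file.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for batch jobs
import matplotlib.pyplot as plt
import numpy as np


def boxplot_row(data, title):
    """One boxplot per variable, side by side, under a single shared title."""
    fig, axes = plt.subplots(1, len(data), figsize=(3 * len(data), 4),
                             constrained_layout=True)
    for ax, (name, values) in zip(np.atleast_1d(axes), data.items()):
        ax.boxplot(values)
        ax.set_xlabel(name)
    fig.suptitle(title)
    return fig


rng = np.random.default_rng(0)
data = dict(var1=rng.normal(0, 1, 1000), var2=rng.normal(0, 2, 1000),
            var3=rng.normal(1, 2, 1000))
fig = boxplot_row(data, "Really really long overlapping title")

buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")  # the saved area grows to fit all text
plt.close(fig)
```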

Change Error Bar Markers (Caplines) in Pandas Bar Plot

I am plotting error bars on a pandas DataFrame bar plot. The error bars have a weird arrow at the top, but what I want is a horizontal line. For example, a figure like this:
But right now my error bars end with an arrow instead of a horizontal line.
Here is the code I used to generate it:
plot = meansum.plot(
kind="bar",
yerr=stdsum,
colormap="OrRd_r",
edgecolor="black",
grid=False,
figsize=(8, 2),
ax=ax,
position=0.45,
error_kw=dict(ecolor="black", elinewidth=0.5, lolims=True, marker="o"),
width=0.8,
)
So what should I change to get the error bars I want? Thanks.
Using plt.errorbar from matplotlib makes this easier, since it returns several objects, including the caplines that hold the marker you want to change (the arrow, which is used automatically when lolims is set to True; see the docs).
With pandas, you just need to dig out the correct line among the children of plot and change its marker:
import pandas as pd
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
df = pd.DataFrame({"val":[1,2,3,4],"error":[.4,.3,.6,.9]})
meansum = df["val"]
stdsum = df["error"]
plot = meansum.plot(kind='bar', yerr=stdsum, colormap='OrRd_r', edgecolor='black', grid=False, figsize=(8, 2), ax=ax, position=0.45, error_kw=dict(ecolor='black', elinewidth=0.5, lolims=True), width=0.8)
for ch in plot.get_children():
    if str(ch).startswith('Line2D'):  # this is silly, but it appears that the first Line2D in the children holds the caplines...
        ch.set_marker('_')
        ch.set_markersize(10)  # to change its size
        break
plt.show()
The result looks like:
Just don't set lolims=True and you are good to go. An example with sample data:
import pandas as pd
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
df = pd.DataFrame({"val":[1,2,3,4],"error":[.4,.3,.6,.9]})
meansum = df["val"]
stdsum = df["error"]
plot = meansum.plot(kind='bar',yerr=stdsum,colormap='OrRd_r',edgecolor='black',grid=False,figsize=(8,2),ax=ax,position=0.45,error_kw=dict(ecolor='black',elinewidth=0.5),width=0.8)
plt.show()
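As the first answer points out, matplotlib's own errorbar returns the cap lines directly, so there is no need to search plot.get_children(). Below is a minimal sketch with plain matplotlib and the same sample values: with lolims=True the caps are drawn as arrows, and the loop swaps them for a horizontal '_' marker. If you do not actually need lolims, dropping it and passing a cap size instead (e.g. error_kw=dict(capsize=4) in the pandas call) should give ordinary horizontal caps.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

x = [0, 1, 2, 3]
vals = [1, 2, 3, 4]
errs = [0.4, 0.3, 0.6, 0.9]

fig, ax = plt.subplots(figsize=(8, 2))
ax.bar(x, vals, width=0.8, color="tab:red", edgecolor="black")
# errorbar returns (data line, cap lines, bar line collections);
# fmt="none" means no marker/line is drawn for the data points themselves
data_line, caplines, barlinecols = ax.errorbar(
    x, vals, yerr=errs, fmt="none", ecolor="black", elinewidth=0.5, lolims=True
)
for cap in caplines:
    cap.set_marker("_")     # horizontal bar instead of the lolims arrow
    cap.set_markersize(10)
```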

How to assign a plot to a variable and use the variable as the return value in a Python function

I am creating two Python scripts to produce some plots for a technical report. In the first script I am defining functions that produce plots from raw data on my hard disk. Each function produces one specific kind of plot that I need. The second script is more like a batch file which is supposed to loop over those functions and store the produced plots on my hard disk.
What I need is a way to return a plot in Python. So basically I want to do this:
fig = some_function_that_returns_a_plot(args)
fig.savefig('plot_name')
But I do not know how to make a plot a variable that I can return. Is this possible? If so, how?
You can define your plotting functions like this:
import numpy as np
import matplotlib.pyplot as plt
# an example graph type
def fig_barh(ylabels, xvalues, title=''):
    # create a new figure and axes
    fig, ax = plt.subplots()
    # plot to it; barh centres each bar on its y value by default
    yvalues = np.arange(len(ylabels))
    ax.barh(yvalues, xvalues)
    ax.set_yticks(yvalues)
    ax.set_yticklabels(ylabels)
    if title:
        ax.set_title(title)
    # return it
    return fig
then use them like this:
from matplotlib.backends.backend_pdf import PdfPages
def write_pdf(fname, figures):
    doc = PdfPages(fname)
    for fig in figures:
        fig.savefig(doc, format='pdf')
    doc.close()
def main():
    a = fig_barh(['a', 'b', 'c'], [1, 2, 3], 'Test #1')
    b = fig_barh(['x', 'y', 'z'], [5, 3, 1], 'Test #2')
    write_pdf('test.pdf', [a, b])
if __name__ == "__main__":
    main()
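For the batch script described in the question, it may also be worth closing each figure once it has been saved, since pyplot keeps every figure alive until it is closed. A minimal sketch, assuming matplotlib 3.x and using PdfPages as a context manager; line_plot and save_all are hypothetical names, and the output goes to a temporary directory.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # no windows are opened in a batch job
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages


def line_plot(x, y, title):
    """Build a figure without showing it and hand it back to the caller."""
    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_title(title)
    return fig


def save_all(pdf_path, png_dir, specs):
    """Render each (x, y, title) spec, save it twice, then close it."""
    with PdfPages(pdf_path) as pdf:
        for i, (x, y, title) in enumerate(specs):
            fig = line_plot(x, y, title)
            pdf.savefig(fig)                                        # one PDF page
            fig.savefig(os.path.join(png_dir, f"plot_{i}.png"))     # one PNG file
            plt.close(fig)                                          # free the memory


out_dir = tempfile.mkdtemp()
x = np.linspace(0, 1, 50)
specs = [(x, x**2, "square"), (x, np.sqrt(x), "root")]
save_all(os.path.join(out_dir, "report.pdf"), out_dir, specs)
```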
If you don't want the figure to be displayed and only want the rendered image back as a variable, you can draw on a Figure attached to an Agg canvas (the extra calls hide the axes):
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
def myplot(t, x):
    fig = Figure(figsize=(2, 1), dpi=80)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot()
    ax.fill_between(t, x)
    ax.autoscale(tight=True)
    ax.axis('off')
    canvas.draw()
    buf = canvas.buffer_rgba()
    X = np.asarray(buf)  # RGBA array of shape (80, 160, 4)
    return X
The imports above (numpy, Figure and FigureCanvasAgg) must be included. The returned variable X can then be used with OpenCV, for example:
cv2.imshow('', cv2.cvtColor(X, cv2.COLOR_RGBA2BGR))
The conversion is needed because OpenCV expects BGR channel order.
The currently accepted answer didn't work for me as-is, because I was using scipy.stats.probplot() to plot. Instead, I used matplotlib.pyplot.gca() to access the Axes instance directly:
"""
For my plotting ideas, see:
https://pythonfordatascience.org/independent-t-test-python/
For the dataset, see:
https://github.com/Opensourcefordatascience/Data-sets
"""
# Import modules.
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
from tempfile import gettempdir
from os import path
from slugify import slugify  # e.g. from the python-slugify package
# Define plot func.
def get_plots(data):
    # plt.figure(): Create a new P-P plot. If we're inside a loop, and want
    # a new plot for every iteration, this is important!
    plt.figure()
    stats.probplot(data, plot=plt)
    plt.title('Sepal Width P-P Plot')
    pp_p = plt.gca()  # Assign an Axes instance of the plot.
    # Plot histogram. pandas' plot() returns an instance of the Axes directly;
    # ax=plt.figure().gca() draws it on a new figure again.
    hist_p = data.plot(kind='hist', title='Sepal Width Histogram Plot',
                       ax=plt.figure().gca())
    return pp_p, hist_p
# Import raw data.
df = pd.read_csv('https://raw.githubusercontent.com/'
                 'Opensourcefordatascience/Data-sets/master//Iris_Data.csv')
# Subset the dataset.
setosa = df[(df['species'] == 'Iris-setosa')]
setosa.reset_index(inplace=True)
versicolor = df[(df['species'] == 'Iris-versicolor')]
versicolor.reset_index(inplace=True)
# Calculate a variable for analysis.
diff = setosa['sepal_width'] - versicolor['sepal_width']
# Create plots, save each of them to a temp file, and show them afterwards.
# As they're just Axes instances, we need to call get_figure() at first.
for plot in get_plots(diff):
    outfn = path.join(gettempdir(), slugify(plot.title.get_text()) + '.png')
    print('Saving a plot to "' + outfn + '".')
    plot.get_figure().savefig(outfn)
    plot.get_figure().show()

What is the corresponding matplotlib code for this MATLAB code?

I'm trying to move away from MATLAB and use Python + matplotlib instead. However, I haven't really figured out what the matplotlib equivalent of MATLAB 'handles' is. So here's some MATLAB code where I return the handles so that I can change certain properties. What is the exact equivalent of this code using matplotlib? I very often use the 'Tag' property of handles in MATLAB and use 'findobj' with it. Can this be done with matplotlib as well?
% create figure and return figure handle
h = figure();
% add a plot and tag it so we can find the handle later
plot(1:10, 1:10, 'Tag', 'dummy')
% add a legend
my_legend = legend('a line')
% change figure name
set(h, 'name', 'myfigure')
% find current axes
my_axis = gca();
% change xlimits
set(my_axis, 'XLim', [0 5])
% find the plot object generated above and modify YData
set(findobj('Tag', 'dummy'), 'YData', repmat(10, 1, 10))
There is a findobj method in matplotlib too:
import matplotlib.pyplot as plt
import numpy as np
h = plt.figure()
plt.plot(range(1,11), range(1,11), gid='dummy')
my_legend = plt.legend(['a line'])
plt.title('myfigure')  # sets the axes title; MATLAB's figure 'Name' is the window title, closer to h.canvas.manager.set_window_title('myfigure')
my_axis = plt.gca()
my_axis.set_xlim(0,5)
for p in set(h.findobj(lambda x: x.get_gid()=='dummy')):
    p.set_ydata(np.ones(10)*10.0)
plt.show()
Note that matplotlib itself usually uses the gid parameter of plt.plot only when the backend is set to 'svg'. It uses the gid as the id attribute of some grouping elements (like Line2D, Patch, Text).
I have not used MATLAB, but I think this is what you want:
import matplotlib
import matplotlib.pyplot as plt
x = [1,3,4,5,6]
y = [1,9,16,25,36]
fig = plt.figure()
ax = fig.add_subplot(111) # add a plot
ax.set_title('y = x^2')
line1, = ax.plot(x, y, 'o-')  # x, y are lists of equal size
y2 = [v * 2 for v in y]  # new y values, same length as x
line1.set_ydata(y2)  # use this to modify YData
plt.show()
Of course, this is just a basic plot; there is more to it. Go through this to find the graph you want and view its source code.
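# pylab-style imports for the bare names used below
from matplotlib.pyplot import figure, plot, legend, gca, draw
from matplotlib.lines import Line2D
from numpy import arange, ones
# create figure and return figure handle
h = figure()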
# add a plot but tagging like matlab is not available here. But you can
# set one of the attributes to find it later. url seems harmless to modify.
# plot() returns a list of Line2D instances which you can store in a variable
p = plot(arange(1,11), arange(1,11), url='my_tag')
# add a legend
my_legend = legend(p,('a line',))
# you could also do
# p = plot(arange(1,11), arange(1,11), label='a line', url='my_tag')
# legend()
# or
# p[0].set_label('a line')
# legend()
# change figure name: not sure what this is for.
# set(h, 'name', 'myfigure')
# find current axes
my_axis = gca()
# change xlimits
my_axis.set_xlim(0, 5)
# You could compress the above two lines of code into:
# xlim(start, end)
# find the plot object generated above and modify YData
# findobj in matplotlib needs you to write a boolean function to
# match selection criteria.
# Here we use a lambda function to return only Line2D objects
# with the url property set to 'my_tag'
q = h.findobj(lambda x: isinstance(x, Line2D) and x.get_url() == 'my_tag')
# findobj returns duplicate objects in the list. We can take the first entry.
q[0].set_ydata(ones(10)*10.0)
# now refresh the figure
draw()
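Putting the pieces together, here is a minimal sketch of the MATLAB 'Tag' plus findobj workflow with matplotlib's object-oriented API: the gid property plays the role of Tag, and Figure.findobj takes a match function. tagged is a hypothetical helper name; only the line tagged 'dummy' should change.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

fig, ax = plt.subplots()  # h = figure(); handles come back as Python objects
ax.plot(np.arange(1, 11), np.arange(1, 11), gid="dummy")      # 'Tag', 'dummy'
ax.plot(np.arange(1, 11), np.arange(10, 0, -1), gid="other")  # an untagged-for-us line
ax.set_xlim(0, 5)  # set(gca, 'XLim', [0 5])


def tagged(tag):
    """Return a findobj predicate matching Line2D artists whose gid equals tag."""
    return lambda artist: isinstance(artist, Line2D) and artist.get_gid() == tag


matches = fig.findobj(match=tagged("dummy"))
for line in matches:
    line.set_ydata(np.full(10, 10.0))  # set(findobj('Tag','dummy'), 'YData', ...)
```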
