Matplotlib - Reversing label and line in legend - python

I'm trying to reverse the label and key columns in a matplotlib legend and I'm really struggling to even know where to start.
In a normal matplotlib legend the pattern is key, then label, like in the example below where it goes key (blue line), then label (First Line):
To match our company plotting style we plot things the reverse, i.e., label first then key (see the legend below). So the plot above would be First line, then the key (blue line).
The additional complication is that the keys should be in one column (so they align in one vertical column) regardless of the length of the label.

Well, there is the keyword markerfirst for this.
from matplotlib import pyplot as plt
import numpy as np
np.random.seed(1234)
n=7
fig, ax = plt.subplots()
ax.plot(np.arange(n), np.random.random(n), label="ABCDEF")
ax.plot(np.arange(n), np.random.random(n), label="G")
ax.legend(markerfirst=False)
plt.show()
Sample output
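The question also asks for the keys to line up in one vertical column. A minimal sketch of how that could be checked is below; it assumes an off-screen Agg backend and reads the pixel extents of the legend entries through `Legend.get_texts()` and `Legend.get_lines()` after a draw. The printed numbers depend on figure size and fonts, so none are quoted here.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no GUI window needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1234)
n = 7
fig, ax = plt.subplots()
ax.plot(np.arange(n), rng.random(n), label="ABCDEF")
ax.plot(np.arange(n), rng.random(n), label="G")
legend = ax.legend(markerfirst=False)

# Legend positions are only final once the figure has been drawn.
fig.canvas.draw()

# Pixel bounding boxes of each label and of each key (the short sample line).
text_boxes = [t.get_window_extent() for t in legend.get_texts()]
key_boxes = [h.get_window_extent() for h in legend.get_lines()]

for tb, kb in zip(text_boxes, key_boxes):
    print(f"label ends at x={tb.x1:.1f}px, key starts at x={kb.x0:.1f}px")
```

If every key starts at the same x position and every label ends to the left of its key, the layout matches the "label first, keys in one column" style.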

I would be tempted to write a standalone function that ignores ax.legend() entirely and instead draws a white box, the labels, and the markers where you need them. All the coordinates would be expressed in ax coordinates via transform=ax.transAxes to ensure a proper positioning and replace the locator keyword of ax.legend().
The following code will automatically cram all the artists found on the ax in the legend box boundaries that you defined. You might need to adjust the "padding" a bit.
Note that for some reason it does not work with lines of width 0 that only use a marker, but it shouldn't be an issue considering your question.
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
# Dummy data.
X = np.linspace(-5, +5, 100)
Y1 = np.sin(X)
Y2 = np.cos(X/3)
Y3 = Y2-Y1
Y4 = Y3*Y1
ax.plot(Y1, label="Y1")
ax.plot(Y2, label="Y2")
ax.plot(Y3, label="Y3", linestyle="--")
ax.plot(Y4, label="Y4", marker="d", markersize=4, linewidth=0)
fig.show()
def custom_legend(ax):
    """Adds a custom legend to the provided ax. Its labels are aligned
    on the left and the markers on the right. Both are taken automatically
    from the ax."""
    handles, labels = ax.get_legend_handles_labels()
    # Boundaries of your custom legend.
    xmin, xmax = 0.7, 0.9
    ymin, ymax = 0.5, 0.9
    N = len(handles)
    width = xmax-xmin
    height = ymax-ymin
    dy = height/N
    r = plt.Rectangle((xmin, ymin),
                      width=width,
                      height=height,
                      transform=ax.transAxes,
                      fill=True,
                      facecolor="white",
                      edgecolor="black",
                      zorder=1000)
    ax.add_artist(r)
    # Grab the tiny lines that would be created by a call to `ax.legend()` so
    # that we don't have to retrieve all the attributes ourselves.
    legend = ax.legend()
    # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7.
    if hasattr(legend, "legend_handles"):
        handles = list(legend.legend_handles)
    else:
        handles = list(legend.legendHandles)
    legend.remove()
    for n, (handle, label) in enumerate(zip(handles, labels)):
        # Place the labels on the left of the legend box.
        x = xmin + 0.01
        y = ymax - n*dy - 0.05
        ax.text(x, y, label, transform=ax.transAxes, va="center", ha="left", zorder=1001)
        # Move a bit to the right and place the line artists.
        x0 = (xmax - 1/2*width)
        x1 = (xmax - 1/8*width)
        y0, y1 = (y, y)
        handle.set_data(((x0, x1), (y0, y1)))
        handle.set_transform(ax.transAxes)
        handle.set_zorder(1002)
        ax.add_artist(handle)

custom_legend(ax)
fig.canvas.draw()

Related

Graphing multiple lines for axvline [duplicate]

Given a plot of a signal in time representation, how can I draw lines marking the corresponding time index?
Specifically, given a signal plot with a time index ranging from 0 to 2.6 (seconds), I want to draw vertical red lines indicating the corresponding time index for the list [0.22058956, 0.33088437, 2.20589566]. How can I do it?
The standard way to add vertical lines that will cover your entire plot window without you having to specify their actual height is plt.axvline
import matplotlib.pyplot as plt
plt.axvline(x=0.22058956)
plt.axvline(x=0.33088437)
plt.axvline(x=2.20589566)
OR
xcoords = [0.22058956, 0.33088437, 2.20589566]
for xc in xcoords:
    plt.axvline(x=xc)
You can use many of the keywords available for other plot commands (e.g. color, linestyle, linewidth ...). You can also pass the keyword arguments ymin and ymax, given in axes coordinates (e.g. ymin=0.25, ymax=0.75 will cover the middle half of the plot). There are corresponding functions for horizontal lines (axhline) and rectangles (axvspan).
matplotlib.pyplot.vlines vs. matplotlib.pyplot.axvline
These methods are applicable to plots generated with seaborn and pandas.DataFrame.plot, which both use matplotlib.
The difference is that vlines accepts one or more locations for x, while axvline permits one location.
Single location: x=37.
Multiple locations: x=[37, 38, 39].
vlines takes ymin and ymax as positions on the y-axis (data coordinates), while axvline takes ymin and ymax as fractions of the axes height, from 0 (bottom) to 1 (top).
When passing multiple lines to vlines, pass a list to ymin and ymax.
Also matplotlib.axes.Axes.vlines and matplotlib.axes.Axes.axvline for the object-oriented API.
If you're plotting a figure with something like fig, ax = plt.subplots(), then replace plt.vlines or plt.axvline with ax.vlines or ax.axvline, respectively.
See this answer for horizontal lines with .hlines.
import numpy as np
import matplotlib.pyplot as plt
xs = np.linspace(1, 21, 200)
plt.figure(figsize=(10, 7))
# only one line may be specified; full height
plt.axvline(x=36, color='b', label='axvline - full height')
# only one line may be specified; ymin & ymax specified as a percentage of y-range
plt.axvline(x=36.25, ymin=0.05, ymax=0.95, color='b', label='axvline - % of full height')
# multiple lines all full height
plt.vlines(x=[37, 37.25, 37.5], ymin=0, ymax=len(xs), colors='purple', ls='--', lw=2, label='vline_multiple - full height')
# multiple lines with varying ymin and ymax
plt.vlines(x=[38, 38.25, 38.5], ymin=[0, 25, 75], ymax=[200, 175, 150], colors='teal', ls='--', lw=2, label='vline_multiple - partial height')
# single vline with full ymin and ymax
plt.vlines(x=39, ymin=0, ymax=len(xs), colors='green', ls=':', lw=2, label='vline_single - full height')
# single vline with specific ymin and ymax
plt.vlines(x=39.25, ymin=25, ymax=150, colors='green', ls=':', lw=2, label='vline_single - partial height')
# place the legend outside
plt.legend(bbox_to_anchor=(1.0, 1), loc='upper left')
plt.show()
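The coordinate difference between the two functions can be inspected on the returned artists. The sketch below is illustrative only: the y-limits and line positions are made-up values, and the Agg backend is assumed so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 4)
ax.set_ylim(0, 200)

# axvline: ymin/ymax are fractions of the axes height (0 = bottom, 1 = top).
line = ax.axvline(x=1.0, ymin=0.25, ymax=0.75, color="b")

# vlines: ymin/ymax are data coordinates on the y-axis, one value per line.
coll = ax.vlines(x=[2.0, 3.0], ymin=[50, 0], ymax=[150, 200], colors="teal")

# axvline returns a single Line2D whose y data are the stored fractions.
print(list(line.get_ydata()))

# vlines returns one LineCollection; each segment holds two (x, y) data points.
for seg in coll.get_segments():
    print(seg.tolist())
```

Because `axvline` stores fractions, its line keeps spanning the same share of the axes when the y-limits change, whereas the `vlines` segments stay fixed at their data values.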
Seaborn axes-level plot
import seaborn as sns
# sample data
fmri = sns.load_dataset("fmri")
# x index for max y values for stim and cue
c_max, s_max = fmri.pivot_table(index='timepoint', columns='event', values='signal', aggfunc='mean').idxmax()
# plot
g = sns.lineplot(data=fmri, x="timepoint", y="signal", hue="event")
# y min and max
ymin, ymax = g.get_ylim()
# vertical lines
g.vlines(x=[c_max, s_max], ymin=ymin, ymax=ymax, colors=['tab:orange', 'tab:blue'], ls='--', lw=2)
Seaborn figure-level plot
Each axes must be iterated through.
import seaborn as sns
# sample data
fmri = sns.load_dataset("fmri")
# used to get the index values (x) for max y for each event in each region
fpt = fmri.pivot_table(index=['region', 'timepoint'], columns='event', values='signal', aggfunc='mean')
# plot
g = sns.relplot(data=fmri, x="timepoint", y="signal", col="region", hue="event", kind="line")
# iterate through the axes
for ax in g.axes.flat:
    # get y min and max
    ymin, ymax = ax.get_ylim()
    # extract the region from the title for use in selecting the index of fpt
    region = ax.get_title().split(' = ')[1]
    # get x values for max event
    c_max, s_max = fpt.loc[region].idxmax()
    # add vertical lines
    ax.vlines(x=[c_max, s_max], ymin=ymin, ymax=ymax, colors=['tab:orange', 'tab:blue'], ls='--', lw=2, alpha=0.5)
For 'region = frontal' the maximum value of both events occurs at 5.
Barplot and Histograms
Note that bar plot tick locations have a zero-based index, regardless of the axis tick labels, so select x based on the bar index, not the tick label.
ax.get_xticklabels() will show the locations and labels.
import pandas as pd
import seaborn as sns
# load data
tips = sns.load_dataset('tips')
# histogram
ax = tips.plot(kind='hist', y='total_bill', bins=30, ec='k', title='Histogram with Vertical Line')
_ = ax.vlines(x=16.5, ymin=0, ymax=30, colors='r')
# barplot
ax = tips.loc[5:25, ['total_bill', 'tip']].plot(kind='bar', figsize=(15, 4), title='Barplot with Vertical Lines', rot=0)
_ = ax.vlines(x=[0, 17], ymin=0, ymax=45, colors='r')
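The zero-based bar positions can be seen on a small frame. This is a sketch with made-up values and a hypothetical three-row index (5, 6, 7); it assumes the Agg backend so no window is opened.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt
import pandas as pd

# Illustrative data: the index labels start at 5, not 0.
df = pd.DataFrame({"total_bill": [25.3, 8.8, 26.9]}, index=[5, 6, 7])
ax = df.plot(kind="bar", rot=0)

# Tick locations are bar positions; tick labels come from the index.
positions = [int(p) for p in ax.get_xticks()]
labels = [t.get_text() for t in ax.get_xticklabels()]
print(positions, labels)

# To mark the bar labelled "7", look up its position instead of passing x=7.
x_pos = labels.index("7")
ax.vlines(x=x_pos, ymin=0, ymax=30, colors="r")
```

Passing `x=7` directly would place the line well to the right of the last bar, because the bars sit at positions 0, 1 and 2.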
Time Series Axis
The dates in the dataframe to be the x-axis must be a datetime dtype. If the column or index is not the correct type, it must be converted with pd.to_datetime.
If an array or list of dates is being used, refer to Converting numpy array of strings to datetime or Convert datetime list into date python, respectively.
x will accept a date like '2020-09-24' or datetime(2020, 9, 2).
import pandas_datareader as web # conda or pip install this; not part of pandas
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
# get test data; this data is downloaded with the Date column in the index as a datetime dtype
df = web.DataReader('^gspc', data_source='yahoo', start='2020-09-01', end='2020-09-28').iloc[:, :2]
# display(df.head(2))
#                    High          Low
# Date
# 2020-09-01  3528.030029  3494.600098
# 2020-09-02  3588.110107  3535.229980
# plot dataframe; the index is a datetime index
ax = df.plot(figsize=(9, 6), title='S&P 500', ylabel='Price')
# add vertical lines
ax.vlines(x=[datetime(2020, 9, 2), '2020-09-24'], ymin=3200, ymax=3600, color='r', label='test lines')
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
plt.show()
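For readers who cannot download the data, the same idea can be sketched offline. The frame below is a stand-in: the dates are stored as strings on purpose, the first two prices are rounded from the table above and the other two are invented, and the plot is drawn with plain Matplotlib rather than `DataFrame.plot`.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd

# Stand-in data; the Date column starts out as plain strings.
raw = pd.DataFrame({
    "Date": ["2020-09-01", "2020-09-02", "2020-09-03", "2020-09-04"],
    "High": [3528.0, 3588.1, 3564.9, 3479.2],  # last two values are invented
})

# Convert the strings to a datetime dtype before using them as the x-axis.
raw["Date"] = pd.to_datetime(raw["Date"])
df = raw.set_index("Date")

fig, ax = plt.subplots()
ax.plot(df.index, df["High"], label="High")

# x takes datetime objects because the x-axis now holds dates.
coll = ax.vlines(x=[datetime(2020, 9, 2), datetime(2020, 9, 3)],
                 ymin=3450, ymax=3600, colors="r", label="test lines")
ax.legend()
```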
For multiple lines
xposition = [0.3, 0.4, 0.45]
for xc in xposition:
    plt.axvline(x=xc, color='k', linestyle='--')
To add a legend and/or colors to some vertical lines, then use this:
import matplotlib.pyplot as plt
# x coordinates for the lines
xcoords = [0.1, 0.3, 0.5]
# colors for the lines
colors = ['r','k','b']
for xc,c in zip(xcoords,colors):
    plt.axvline(x=xc, label='line at x = {}'.format(xc), c=c)
plt.legend()
plt.show()
Results
Calling axvline in a loop, as others have suggested, works, but it can be inconvenient because
Each line is a separate plot object, which causes things to be very slow when you have many lines.
When you create the legend each line has a new entry, which may not be what you want.
Instead, you can use the following convenience functions which create all the lines as a single plot object:
import matplotlib.pyplot as plt
import numpy as np
def axhlines(ys, ax=None, lims=None, **plot_kwargs):
    """
    Draw horizontal lines across plot
    :param ys: A scalar, list, or 1D array of vertical offsets
    :param ax: The axis (or none to use gca)
    :param lims: Optionally the (xmin, xmax) of the lines
    :param plot_kwargs: Keyword arguments to be passed to plot
    :return: The plot object corresponding to the lines.
    """
    if ax is None:
        ax = plt.gca()
    # atleast_1d turns a scalar into a 1D array; np.array(..., copy=False)
    # raises in NumPy 2 whenever a copy is required
    ys = np.atleast_1d(ys)
    if lims is None:
        lims = ax.get_xlim()
    y_points = np.repeat(ys[:, None], repeats=3, axis=1).flatten()
    # each line is (start, end, nan); the nan breaks the path between lines
    x_points = np.repeat(np.array(tuple(lims) + (np.nan, ))[None, :], repeats=len(ys), axis=0).flatten()
    plot = ax.plot(x_points, y_points, scalex=False, **plot_kwargs)
    return plot

def axvlines(xs, ax=None, lims=None, **plot_kwargs):
    """
    Draw vertical lines on plot
    :param xs: A scalar, list, or 1D array of horizontal offsets
    :param ax: The axis (or none to use gca)
    :param lims: Optionally the (ymin, ymax) of the lines
    :param plot_kwargs: Keyword arguments to be passed to plot
    :return: The plot object corresponding to the lines.
    """
    if ax is None:
        ax = plt.gca()
    xs = np.atleast_1d(xs)
    if lims is None:
        lims = ax.get_ylim()
    x_points = np.repeat(xs[:, None], repeats=3, axis=1).flatten()
    y_points = np.repeat(np.array(tuple(lims) + (np.nan, ))[None, :], repeats=len(xs), axis=0).flatten()
    plot = ax.plot(x_points, y_points, scaley=False, **plot_kwargs)
    return plot
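The trick behind these helpers is that a NaN in the data breaks the drawn path, so many separate lines can live in one Line2D. A stripped-down sketch of just that step, with made-up positions and limits and an assumed Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

xs = np.array([0.2, 0.5, 0.8])  # illustrative x positions
lims = (0.0, 1.0)               # illustrative vertical extent

# Each vertical line contributes three points: bottom, top, and a NaN gap.
x_points = np.repeat(xs, 3)
y_points = np.tile([lims[0], lims[1], np.nan], len(xs))

fig, ax = plt.subplots()
(line,) = ax.plot(x_points, y_points, color="r", label="markers")

# All three lines are one artist, so the legend gets a single entry.
handles, labels = ax.get_legend_handles_labels()
print(len(ax.get_lines()), labels)
```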
In addition to the plt.axvline and plt.plot((x1, x2), (y1, y2)) or plt.plot([x1, x2], [y1, y2]) as provided in the answers above, one can also use
plt.vlines(x_pos, ymin=y1, ymax=y2)
to plot a vertical line at x_pos spanning from y1 to y2 where the values y1 and y2 are in absolute data coordinates.

Write a text inside a subplot

I'm working on this plot:
I need to write something inside the first plot, between the red and the black lines, I tried with ax1.text() but it shows the text between the two plots and not inside the first one. How can I do that?
The plot was set out as such:
fig, (ax1,ax2) = plt.subplots(nrows=2, ncols=1, figsize = (12,7), tight_layout = True)
Thank you in advance for your help!
Without more code details, it's quite hard to guess what is wrong.
matplotlib.axes.Axes.text works well for placing text on subplots. I encourage you to look at its documentation (arguments...) and try it yourself.
The text location depends on the following two arguments:
text(x, y, ...): x and y give the position of the text. By default they are in data coordinates; the coordinate system can be changed with the transform parameter.
transform=ax.transAxes: indicates that the coordinates are given relative to the axes bounding box, with (0, 0) being the lower left of the axes and (1, 1) the upper right.
Here is an example:
# import modules
import matplotlib.pyplot as plt
import numpy as np
# Create random data
x = np.arange(0,20)
y1 = np.random.randint(0,10, 20)
y2 = np.random.randint(0,10, 20) + 15
# Create figure
fig, (ax1,ax2) = plt.subplots(nrows=2, ncols=1, figsize = (12,7), tight_layout = True)
# Add subplots
ax1.plot(x, y1)
ax1.plot(x, y2)
ax2.plot(x, y1)
ax2.plot(x, y2)
# Show texts
ax1.text(0.1, 0.5, 'Begin text', horizontalalignment='center', verticalalignment='center', transform=ax1.transAxes)
ax2.text(0.9, 0.5, 'End text', horizontalalignment='center', verticalalignment='center', transform=ax2.transAxes)
plt.show()
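One possible cause of text landing outside the subplot is that x and y are read as data coordinates unless a transform is given. The sketch below compares the two coordinate systems on a single axes; the limits are arbitrary and the Agg backend is an assumption so that it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 100)
ax.set_ylim(0, 10)

# Axes coordinates: (0.5, 0.5) is the centre of the axes, whatever the limits.
t_axes = ax.text(0.5, 0.5, "axes coords", transform=ax.transAxes, ha="center")

# Data coordinates (the default): (0.5, 0.5) is near the lower-left corner here.
t_data = ax.text(0.5, 0.5, "data coords")

# Convert the axes-centre point to data coordinates to compare the two systems.
center_px = ax.transAxes.transform((0.5, 0.5))
center_data = ax.transData.inverted().transform(center_px)
print(center_data)
```

A y value in data coordinates that lies outside the y-limits is drawn above or below the axes, which is consistent with text showing up between two stacked subplots.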
output

matplotlib - autosize of text according to shape size

I'm adding a text inside a shape by:
ax.text(x,y,'text', ha='center', va='center',bbox=dict(boxstyle='circle', fc="w", ec="k"),fontsize=10) (ax is AxesSubplot)
The problem is that I can't keep the circle size constant when the string length changes. I want the text size to adjust to the circle size, not the other way around.
The circle even disappears completely if the string is empty.
The only workaround I have found is to set the fontsize parameter dynamically according to the length of the string, but that is ugly and the circle size is still not completely constant.
EDIT (adding a MVCE):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.text(0.5,0.5,'long_text', ha='center', va='center',bbox=dict(boxstyle='circle', fc="w", ec="k"),fontsize=10)
ax.text(0.3,0.7,'short', ha='center', va='center',bbox=dict(boxstyle='circle', fc="w", ec="k"),fontsize=10)
plt.show()
Trying to make both circles the same size although the string len is different. Currently looks like this:
I have a very dirty and hard-core solution which requires quite deep knowledge of matplotlib. It is not perfect but might give you some ideas how to start.
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
plt.close('all')
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
t1 = ax.text(0.5,0.5,'long_text', ha='center', va='center',fontsize=10)
t2 = ax.text(0.3,0.7,'short', ha='center', va='center', fontsize=10)
t3 = ax.text(0.1,0.7,'super-long-text-that-is-long', ha='center', va='center', fontsize=10)
fig.show()
def text_with_circle(text_obj, axis, color, border=1.5):
    # Get the box containing the text
    box1 = text_obj.get_window_extent()
    # It turned out that what you get from the box is
    # in screen pixels, so we need to transform them
    # to "data"-coordinates. This is done with the
    # transformer-function below
    transformer = axis.transData.inverted().transform
    # Now transform the corner coordinates of the box
    # to data-coordinates
    [x0, y0] = transformer([box1.x0, box1.y0])
    [x1, y1] = transformer([box1.x1, box1.y1])
    # Find the x and y center coordinate
    x_center = (x0+x1)/2.
    y_center = (y0+y1)/2.
    # Find the radius, add some extra to make a nice border around it
    r = np.max([(x1-x0)/2., (y1-y0)/2.])*border
    # Create a circle at the center of the text, with radius r.
    circle = Circle((x_center, y_center), r, color=color)
    # The caller adds the circle to the axis and redraws the canvas.
    return circle
circle1 = text_with_circle(t1, ax, 'g')
ax.add_artist(circle1)
circle2 = text_with_circle(t2, ax, 'r', 5)
ax.add_artist(circle2)
circle3 = text_with_circle(t3, ax, 'y', 1.1)
ax.add_artist(circle3)
fig.canvas.draw()
At the moment you have to run this in IPython, because the figure has to be drawn BEFORE you call get_window_extent(). Therefore fig.show() has to be called AFTER the text is added, but BEFORE the circle can be drawn. Then we can get the coordinates of the text, figure out where the middle is, and add a circle around the text with a certain radius. When this is done we redraw the canvas to update it with the new circle. Of course you can customize the circle a lot more (edge color, face color, line width, etc.); look into the Circle class.
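The answer above sizes the circle to the text. For the opposite direction that the question asks for (fixed circle, text shrunk to fit), one possible sketch is below. The helper `fit_text_in_circle`, its parameters, and the "half the box diagonal must not exceed the radius" fitting rule are all assumptions made for illustration, not part of Matplotlib; the Agg backend is assumed so that extents can be measured without a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Circle


def fit_text_in_circle(ax, xy, radius, s, max_fontsize=20.0,
                       min_fontsize=1.0, step=0.5):
    """Hypothetical helper: fixed-radius circle, font shrunk until text fits."""
    circle = Circle(xy, radius, fc="w", ec="k")
    ax.add_patch(circle)
    text = ax.text(xy[0], xy[1], s, ha="center", va="center",
                   fontsize=max_fontsize)
    # Transforms and text extents are only reliable after a draw.
    ax.figure.canvas.draw()
    # Radius of the circle in pixels, measured along x.
    x0, _ = ax.transData.transform(xy)
    x1, _ = ax.transData.transform((xy[0] + radius, xy[1]))
    radius_px = abs(x1 - x0)
    # Step the font size down until half the text-box diagonal fits.
    for size in np.arange(max_fontsize, min_fontsize - step / 2, -step):
        text.set_fontsize(float(size))
        box = text.get_window_extent()
        if np.hypot(box.width, box.height) / 2 <= radius_px:
            break
    return circle, text


fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect("equal")  # keeps the data-space circle round on screen
c_long, t_long = fit_text_in_circle(ax, (0.5, 0.5), 0.1, "long_text")
c_short, t_short = fit_text_in_circle(ax, (0.3, 0.7), 0.1, "short")
print(t_long.get_fontsize(), t_short.get_fontsize())
```

Both circles keep the radius that was passed in, and an empty string leaves the circle in place because the patch no longer depends on the text. Resizing the figure afterwards changes the pixel radius, so the fit would have to be recomputed.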
Example of output plot:

how to format axis in matplotlib.pyplot and include legend?

Got this function for the graph, want to format axis so that the graph starts from (0,0), also how do I write legends so I can label which line belongs to y1 and which to y2 and label axis.
import matplotlib.pyplot as plt
def graph_cust(cust_type):
    """function produces a graph of day against customer number for a given customer type"""
    s = show_all_states_list(cust_type)
    x = list(i['day'] for i in s)
    y1 = list(i['custtypeA_nondp'] for i in s)
    y2 = list(i['custtypeA_dp'] for i in s)
    plt.scatter(x, y1, color='k')
    plt.scatter(x, y2, color='g')
    plt.show()
Thanks
You can set the limits on either axis using plt.xlim(x_low, x_high) or plt.ylim(y_low, y_high). If you do not want to set the upper limit manually (e.g. you're happy with the current upper limit), then try:
ax = plt.subplot(111) # Create axis instance
ax.scatter(x, y1, color='k') # Same as you have above but use ax instead of plt
ax.set_xlim(0.0, ax.get_xlim()[1])
Note the slight difference here: we use an axis instance. This lets us read the current x limits with ax.get_xlim(), which returns a tuple (x_low, x_high); we pick the second element using [1].
A minimal example of a legend:
plt.plot(x, y, label="some text")
plt.legend()
For more on legends see any of these examples
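Putting the pieces together for the scatter plot in the question, a sketch might look like the following. The data are made up (the original `show_all_states_list` function is not available), the axis label text is a guess at what the asker wants, and the Agg backend is assumed so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering
import matplotlib.pyplot as plt

# Made-up stand-ins for the day / customer-count lists in the question.
x = [1, 2, 3, 4, 5]
y1 = [3, 5, 4, 6, 8]
y2 = [1, 2, 2, 3, 5]

fig, ax = plt.subplots()
# The label= text is what the legend displays for each series.
ax.scatter(x, y1, color="k", label="y1 (non-dp)")
ax.scatter(x, y2, color="g", label="y2 (dp)")

# Start both axes at 0 and keep the automatically chosen upper limits.
ax.set_xlim(0, ax.get_xlim()[1])
ax.set_ylim(0, ax.get_ylim()[1])

ax.set_xlabel("day")
ax.set_ylabel("number of customers")
legend = ax.legend()
```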
