I was wondering if it is possible to plot a curve in matplotlib with arrow ticks.
Something like:
from pylab import *
# NOTE(review): numpy's linspace takes the *number of samples* as its third
# argument, not a step; 0.01 was probably meant as a step, i.e. something
# like linspace(0, 10, 1001) -- confirm.
y = linspace(0,10,0.01)
x = cos(y)
# Hypothetical syntax: '->' is the arrow-tick style the question wishes
# existed; it is not an actual matplotlib format string.
plot(x, y, '->')
which should come out with a curve made like this --->---->----> when x increases and like this ---<----<----< when it decreases (and for y as well, of course).
EDIT:
Furthermore, the arrows should be inclined in the curve's direction (for example, 45 degrees for the y=x function)
It is possible to use the same strategy as in matplotlib streamplot function. Based on the example already given by hitzg:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
def add_arrow_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, transform=None):
    """
    Add arrows to a matplotlib.lines.Line2D at selected locations.

    Each arrow is a short FancyArrowPatch drawn from a data point towards the
    midpoint of the following segment. It therefore points in the direction
    of increasing point index and is inclined along the curve.

    Parameters:
    -----------
    axes: Axes to which the arrow patches are added
    line: Line2D object as returned by plot command
    arrow_locs: sequence of locations where to insert arrows, given as
        fractions (0 to 1) of the total length of the line; the length is
        measured in data coordinates. Values outside [0, 1] are clamped to
        the first/last segment.
    arrowstyle: style of the arrow
    arrowsize: size of the arrow (mutation_scale = 10 * arrowsize)
    transform: a matplotlib transform instance, default to data coordinates

    Returns:
    --------
    arrows: list of arrows (FancyArrowPatch), one per entry of arrow_locs

    Raises:
    -------
    ValueError: if `line` is not a Line2D or has fewer than two points
    NotImplementedError: for multicolor or multiwidth lines
    """
    if not isinstance(line, mlines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    # get_xdata() can hand back the sequence originally passed to plot (e.g.
    # a list); work on arrays so slicing and np.diff behave uniformly.
    x = np.asanyarray(line.get_xdata())
    y = np.asanyarray(line.get_ydata())
    if len(x) < 2:
        raise ValueError("the line needs at least two points to place arrows")

    arrow_kw = {
        "arrowstyle": arrowstyle,
        "mutation_scale": 10 * arrowsize,
    }

    color = line.get_color()
    use_multicolor_lines = isinstance(color, np.ndarray)
    if use_multicolor_lines:
        raise NotImplementedError("multicolor lines not supported")
    else:
        arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multiwidth lines not supported")
    else:
        arrow_kw['linewidth'] = linewidth

    if transform is None:
        transform = axes.transData

    # Cumulative length along the line. It does not depend on `loc`, so it
    # is computed once instead of once per arrow.
    s = np.cumsum(np.hypot(np.diff(x), np.diff(y)))
    # Last index that still has a following point, so x[n:n + 2] has 2 items.
    last_start = len(x) - 2

    arrows = []
    for loc in arrow_locs:
        # First segment whose cumulative length reaches loc * total length.
        n = min(int(np.searchsorted(s, s[-1] * loc)), last_start)
        arrow_tail = (x[n], y[n])
        # Head at the midpoint of segment n -> arrow follows the curve.
        arrow_head = (np.mean(x[n:n + 2]), np.mean(y[n:n + 2]))
        p = mpatches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=transform,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
y = np.linspace(0, 100, 200)
x = np.cos(y / 5.)
fig, ax = plt.subplots(1, 1)
# Draw the curve first, then decorate it with arrows along its whole length.
(line,) = ax.plot(x, y, 'k-')
add_arrow_to_line2D(
    ax,
    line,
    arrow_locs=np.linspace(0., 1., 200),
    arrowstyle='->',
)
plt.show()
Also refer to this answer.
Try this:
import numpy as np
import matplotlib.pyplot as plt
y = np.linspace(0, 100, 100)
x = np.cos(y / 5.)
# Masked arrays hide the points that should not get a marker:
# x1 keeps only points after which x decreases, x2 those after which it increases.
x1 = np.ma.masked_array(x[:-1], mask=np.diff(x) >= 0)
x2 = np.ma.masked_array(x[:-1], mask=np.diff(x) <= 0)
# Draw the line and the direction markers in separate calls.
plt.plot(x, y, 'k-')
plt.plot(x1, y[:-1], 'k<')
plt.plot(x2, y[:-1], 'k>')
plt.show()
Related
I've been toying around with this problem and am close to what I want but missing that extra line or two.
Basically, I'd like to plot a single line whose color changes given the value of a third array. Lurking around I have found this works well (albeit pretty slowly) and represents the problem
import numpy as np
import matplotlib.pyplot as plt
c = np.arange(1, 100)
x = np.arange(1, 100)
y = np.arange(1, 100)
cm = plt.get_cmap('hsv')
fig = plt.figure(figsize=(5, 5))
ax1 = plt.subplot(111)
no_points = len(c)
# Axes.set_color_cycle was deprecated in matplotlib 1.5 and later removed;
# set_prop_cycle(color=...) is the supported replacement.
ax1.set_prop_cycle(color=[cm(1. * i / (no_points - 1))
                          for i in range(no_points - 1)])
# One plot call per 2-point segment; each call takes the next colour from
# the cycle (slow for many points -- a LineCollection is much faster).
for i in range(no_points - 1):
    ax1.plot(x[i:i + 2], y[i:i + 2])
plt.show()
Which gives me this:
I'd like to be able to include a colorbar along with this plot. So far I haven't been able to crack it just yet. Potentially there will be other lines included with different x,y's but the same c, so I was thinking that a Normalize object would be the right path.
The bigger picture is that this plot is part of a 2x2 subplot grid. I am already making space for the colorbar axes object with matplotlib.colorbar.make_axes(ax4), where ax4 is the 4th subplot.
Take a look at the multicolored_line example in the Matplotlib gallery and dpsanders' colorline notebook:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
def multicolored_lines():
    """
    Demo: draw sin(x) over [0, 4*pi] as a line coloured with the 'hsv'
    colormap, attach a colorbar and show the figure.

    Based on:
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html
    """
    xs = np.linspace(0, 4. * np.pi, 100)
    ys = np.sin(xs)
    fig, ax = plt.subplots()
    collection = colorline(xs, ys, cmap='hsv')
    plt.colorbar(collection)
    plt.xlim(xs.min(), xs.max())
    plt.ylim(-1.0, 1.0)
    plt.show()
def colorline(
        x, y, z=None, cmap='copper', norm=None,
        linewidth=3, alpha=1.0, ax=None):
    """
    Plot a colored line with coordinates x and y.

    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html

    Parameters
    ----------
    x, y : sequences of point coordinates
    z : sequence or scalar, optional
        Values mapped to colours, one per segment. The default is values
        equally spaced on [0, 1]. A scalar is wrapped into a 1-element array.
    cmap : colormap or name, optional
    norm : Normalize, optional
        None (the default) creates a fresh ``plt.Normalize(0.0, 1.0)`` on every
        call. Collections therefore no longer share one norm instance, which
        changing the limits of one of them would silently modify for all.
    linewidth, alpha : optional
        Passed to the LineCollection.
    ax : Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.

    Returns
    -------
    The LineCollection that was added.
    """
    # Default colors equally spaced on [0,1]:
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))
    # Special case if a single number (also catches 0-d arrays, which have
    # an __iter__ attribute but cannot be iterated):
    if np.ndim(z) == 0:
        z = np.array([z])
    z = np.asarray(z)

    if norm is None:
        norm = plt.Normalize(0.0, 1.0)

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha)

    if ax is None:
        ax = plt.gca()
    ax.add_collection(lc)

    return lc
def make_segments(x, y):
    """
    Turn point coordinates into the segment array a LineCollection expects.

    Consecutive points (x[i], y[i]) -> (x[i+1], y[i+1]) become one segment.
    The result has shape (len(x) - 1, 2, 2): segment index, start/end point,
    and the (x, y) pair.
    """
    pts = np.column_stack((x, y))
    return np.stack((pts[:-1], pts[1:]), axis=1)
# Run the demo as soon as this snippet is executed (module level, no
# __main__ guard).
multicolored_lines()
Note that calling plt.plot hundreds of times tends to kill performance.
Using a LineCollection to build multi-colored line segments is much much faster.
I'd like to add an arrow to a line plot with matplotlib like in the plot below (drawn with pgfplots).
How can I do (position and direction of the arrow should be parameters ideally)?
Here is some code to experiment.
# `from matplotlib import pyplot` alone left the name `plt` (used below)
# undefined; bind it explicitly.
from matplotlib import pyplot as plt
import numpy as np

t = np.linspace(-2, 2, 100)
plt.plot(t, np.sin(t))
plt.show()
Thanks.
In my experience this works best by using annotate. Thereby you avoid the weird warping you get with ax.arrow which is somehow hard to control.
EDIT: I've wrapped it into a little function.
from matplotlib import pyplot as plt
import numpy as np
def add_arrow(line, position=None, direction='right', size=15, color=None):
    """
    add an arrow to a line.

    line: Line2D object
    position: x-position of the arrow. If None, mean of xdata is taken
    direction: 'right' points the arrow from the data point closest to
        `position` towards the next point (higher index); any other value
        points it towards the previous point.
    size: size of the arrow in fontsize points
    color: if None, line color is taken.

    The arrow is an annotation between two neighbouring data points, so it
    follows the local direction of the curve. If the closest point sits at
    the end of the data with no neighbour in the requested direction, the
    arrow is moved one point inwards. Without this, the call would raise
    IndexError or, via index -1, silently draw towards the opposite end of
    the line.

    Returns the Annotation that was created.
    Raises ValueError if the line has fewer than two points.
    """
    if color is None:
        color = line.get_color()

    # get_xdata() may return a plain list, which has no .mean()
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())
    if len(xdata) < 2:
        raise ValueError("the line needs at least two points to draw an arrow")

    if position is None:
        position = xdata.mean()
    # find closest index
    start_ind = int(np.argmin(np.absolute(xdata - position)))
    if direction == 'right':
        start_ind = min(start_ind, len(xdata) - 2)
        end_ind = start_ind + 1
    else:
        start_ind = max(start_ind, 1)
        end_ind = start_ind - 1

    return line.axes.annotate('',
        xytext=(xdata[start_ind], ydata[start_ind]),
        xy=(xdata[end_ind], ydata[end_ind]),
        arrowprops=dict(arrowstyle="->", color=color),
        size=size
    )
t = np.linspace(-2, 2, 100)
y = np.sin(t)
# plot() returns a list of Line2D objects; keep the single one drawn here.
(line,) = plt.plot(t, y)
add_arrow(line)
plt.show()
It's not very intuitive but it works. You can then fiddle with the arrowprops dictionary until it looks right.
Just add a plt.arrow():
from matplotlib import pyplot as plt
import numpy as np
# The curve to decorate with an arrow.
def f(t):
    """Return sin(t) (element-wise for arrays)."""
    result = np.sin(t)
    return result
t = np.linspace(-2, 2, 100)
plt.plot(t, f(t))
# Arrow whose tail sits on the curve at x = 0 and whose (dx, dy) is a tiny
# step along the curve, so only the head is visible (lw=0), aligned with
# the local slope.
plt.arrow(
    0, f(0),
    0.01, f(0.01) - f(0),
    shape='full',
    lw=0,
    length_includes_head=True,
    head_width=.05,
)
plt.show()
EDIT: Changed parameters of arrow to include position & direction of function to draw.
Not the nicest solution, but should work:
import matplotlib.pyplot as plt
import numpy as np
def makeArrow(ax, pos, function, direction):
    """
    Draw an arrow head on the curve y = function(x) at x = pos.

    ax: Axes to draw on
    pos: x position of the arrow tail (on the curve)
    function: callable giving the curve's y value for an x value
    direction: >= 0 points the arrow towards increasing x, otherwise towards
        decreasing x

    The arrow spans a tiny step `delta` along the curve, so essentially only
    its head is visible, aligned with the local slope.

    Returns the FancyArrow created by Axes.arrow.
    """
    delta = 0.0001 if direction >= 0 else -0.0001
    y0 = function(pos)
    # Axes.arrow(x, y, dx, dy) takes a *displacement*, not the end point:
    # pass the step along the curve rather than absolute coordinates.
    return ax.arrow(pos, y0, delta, function(pos + delta) - y0,
                    head_width=0.05, head_length=0.1)
fun = np.sin
t = np.linspace(-2, 2, 100)
ax = plt.axes()
ax.plot(t, fun(t))
# One arrow at x = 0 pointing towards increasing x.
makeArrow(ax, 0, fun, +1)
plt.show()
I know this doesn't exactly answer the question as asked, but I thought this could be useful to other people landing here. I wanted to include the arrow in my plot's legend, but the solutions here don't mention how. There may be an easier way to do this, but here is my solution:
To include the arrow in your legend, you need to make a custom patch handler and use the matplotlib.patches.FancyArrow object. Here is a minimal working solution. This solution piggybacks off of the existing solutions in this thread.
First, the imports...
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerPatch
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import numpy as np
Now, we make a custom legend handler. This handler can create legend artists for any line-patch combination, granted that the line has no markers.
class HandlerLinePatch(HandlerPatch):
    """
    Legend handler that draws a line with a patch (e.g. an arrow) on top.

    linehandle: optional Line2D whose style is copied onto the legend line.
        When None, a solid line in the patch's edge colour is drawn.
    **kw: forwarded to HandlerPatch (e.g. patch_func).

    The legend line is always drawn without markers.
    """

    def __init__(self, linehandle=None, **kw):
        super().__init__(**kw)
        self.linehandle = linehandle

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width,
                       height, fontsize, trans):
        # The patch itself comes from HandlerPatch. This previously passed the
        # undefined name `descent` (a NameError) instead of ydescent.
        p = super().create_artists(legend, orig_handle,
                                   xdescent, ydescent,
                                   width, height, fontsize,
                                   trans)
        line = Line2D([0, width], [height / 2., height / 2.])
        if self.linehandle is None:
            line.set_linestyle('-')
            # public accessors instead of the private _color / _edgecolor
            line.set_color(orig_handle.get_edgecolor())
        else:
            self.update_prop(line, self.linehandle, legend)
        # Applied in both branches: no markers, plain drawstyle, and the
        # legend's transform so the line lands inside the legend box.
        line.set_drawstyle('default')
        line.set_marker('')
        line.set_transform(trans)
        return [p[0], line]
Next, we write a function that specifies the type of patch we want to include in the legend - an arrow in our case. This is courtesy of Javier's answer here.
def make_legend_arrow(legend, orig_handle,
                      xdescent, ydescent,
                      width, height, fontsize):
    """
    patch_func for HandlerPatch: build the arrow shown in the legend box.

    The arrow starts at the centre of the box, points right over a fifth of
    the box width, and its head is as wide and as long as the box height.
    """
    centre_x, centre_y = width / 2., height / 2.
    return patches.FancyArrow(
        centre_x, centre_y, width / 5., 0,
        length_includes_head=True,
        width=0,
        head_width=height,
        head_length=height,
        overhang=0.2,
    )
Next, a modified version of the add_arrow function from Thomas' answer that uses the FancyArrow patch rather than annotations. This solution might cause weird warping like Thomas warned against, but I couldn't figure out how to put the arrow in the legend if the arrow is an annotation.
def add_arrow(line, ax, position=None, direction='right', color=None, label=''):
    """
    add an arrow to a line, as a FancyArrow patch (so it can get a legend entry).

    line: Line2D object
    ax: Axes the arrow patch is added to
    position: x-position of the arrow. If None, mean of xdata is taken
    direction: 'right' points the arrow towards the next data point (higher
        index); any other value points it towards the previous point.
    color: if None, line color is taken.
    label: label for arrow (picked up by the legend)

    If the closest point has no neighbour in the requested direction, the
    arrow is moved one point inwards. Without this, the call would raise
    IndexError or, via index -1, wrap around to the other end of the line.

    Returns the FancyArrow patch that was added.
    Raises ValueError if the line has fewer than two points.
    """
    if color is None:
        color = line.get_color()

    # get_xdata() may return a plain list, which has no .mean()
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())
    if len(xdata) < 2:
        raise ValueError("the line needs at least two points to draw an arrow")

    if position is None:
        position = xdata.mean()
    # find closest index
    start_ind = int(np.argmin(np.absolute(xdata - position)))
    if direction == 'right':
        start_ind = min(start_ind, len(xdata) - 2)
        end_ind = start_ind + 1
    else:
        start_ind = max(start_ind, 1)
        end_ind = start_ind - 1

    dx = xdata[end_ind] - xdata[start_ind]
    dy = ydata[end_ind] - ydata[start_ind]
    # Head size scales with the x-spacing of the data.
    # NOTE(review): for a vertical segment (dx == 0) the size is 0 and the
    # arrow is invisible.
    size = abs(dx) * 5.
    # Shift the arrow start by half a head size in the direction of travel.
    x = xdata[start_ind] + (np.sign(dx) * size / 2.)
    y = ydata[start_ind] + (np.sign(dy) * size / 2.)
    arrow = patches.FancyArrow(x, y, dx, dy, color=color, width=0,
                               head_width=size, head_length=size,
                               label=label, length_includes_head=True,
                               overhang=0.3, zorder=10)
    ax.add_patch(arrow)
    return arrow
Now, a helper function to plot both the arrow and the line. It returns a Line2D object, which is needed for the legend handler we wrote in the first code block
def plot_line_with_arrow(x, y, ax=None, label='', **kw):
    """
    Plot y against x on *ax* (the current axes when None), put an arrow
    labelled *label* on the line, and return the first Line2D drawn. The
    legend handler uses that line as its style template.
    """
    target_ax = plt.gca() if ax is None else ax
    drawn = target_ax.plot(x, y, **kw)
    first_line = drawn[0]
    add_arrow(first_line, target_ax, label=label)
    return first_line
Finally, we make the plot and update the legend's handler_map with our custom handler.
t = np.linspace(-2, 2, 100)
y = np.sin(t)
line = plot_line_with_arrow(t, y, label='Path', linestyle=':')
plt.gca().set_aspect('equal')
# FancyArrow legend handles are drawn by the custom handler, which combines
# the arrow with a copy of `line`'s style.
plt.legend(handler_map={
    patches.FancyArrow: HandlerLinePatch(
        patch_func=make_legend_arrow,
        linehandle=line,
    ),
})
plt.show()
Here is the output:
I've found that quiver() works better than arrow() or annotate() when the x and y axes have very different scales. Here's my helper function for plotting a line with arrows:
def plot_with_arrows(ax, x, y, color="g", label="", n_arrows=2, step=5):
    """
    Plot y against x and add `n_arrows` quiver arrows along the curve that
    show the direction of travel (increasing sample index).

    The direction of each arrow comes from the sample `step` points ahead
    and is divided by the x and y data ranges. The arrows therefore look
    right even when the two axes have very different scales.

    ax: Axes to draw into
    x, y: equal-length 1-D sequences (numpy arrays, lists, pandas Series);
        samples are addressed by position, not by index label
    color, label: passed to ax.plot (color also to ax.quiver)
    n_arrows: number of arrows, spread evenly over the samples
    step: look-ahead in samples used to estimate the direction (default 5)
    """
    ax.plot(x, y, rasterized=True, color=color, label=label)

    # Work positionally on arrays: the previous x.keys() only existed on
    # pandas Series, and x[i + 5] could run past the end of the data.
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    n = len(xs)
    if n < 2 or n_arrows <= 0:
        return

    # A flat curve has a zero range; fall back to 1 to avoid dividing by 0.
    x_range = (xs.max() - xs.min()) or 1.0
    y_range = (ys.max() - ys.min()) or 1.0

    # Odd points of a (2 * n_arrows + 1)-point subdivision -> centred arrows.
    for i in np.linspace(0, n - 1, n_arrows * 2 + 1).astype(int)[1::2]:
        j = min(i + step, n - 1)
        direction = np.array([xs[j] - xs[i], ys[j] - ys[i]])
        length = np.hypot(direction[0], direction[1])
        if length == 0:
            continue  # no movement -> no meaningful direction
        direction = direction / length * 0.05
        direction[0] /= x_range
        direction[1] /= y_range
        ax.quiver(xs[i], ys[i], direction[0], direction[1], color=color)
I'd like to make a scatter plot where each point is colored by the spatial density of nearby points.
I've come across a very similar question, which shows an example of this using R:
R Scatter Plot: symbol color represents number of overlapping points
What's the best way to accomplish something similar in python using matplotlib?
In addition to hist2d or hexbin as @askewchan suggested, you can use the same method that the accepted answer in the question you linked to uses.
If you want to do that:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
# Synthetic, correlated test data
x = np.random.normal(size=1000)
y = 3 * x + np.random.normal(size=1000)
# Colour each point by the Gaussian-KDE density estimated at its location
xy = np.vstack((x, y))
z = gaussian_kde(xy).evaluate(xy)
fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=100)
plt.show()
If you'd like the points to be plotted in order of density so that the densest points are always on top (similar to the linked example), just sort them by the z-values. I'm also going to use a smaller marker size here as it looks a bit better:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
# Synthetic, correlated test data
x = np.random.normal(size=1000)
y = 3 * x + np.random.normal(size=1000)
# Gaussian-KDE density at every sample
xy = np.vstack((x, y))
z = gaussian_kde(xy).evaluate(xy)
# Order by density so the densest points are drawn last (on top)
idx = np.argsort(z)
x, y, z = (arr[idx] for arr in (x, y, z))
fig, ax = plt.subplots()
ax.scatter(x, y, c=z, s=50)
plt.show()
Plotting >100k data points?
The accepted answer, using gaussian_kde() will take a lot of time. On my machine, 100k rows took about 11 minutes. Here I will add two alternative methods (mpl-scatter-density and datashader) and compare the given answers with same dataset.
In the following, I used a test data set of 100k rows:
import matplotlib.pyplot as plt
import numpy as np
# Fake data for testing: 100k samples, y linear in x plus unit-variance noise
x = np.random.normal(size=100000)
y = 3 * x + np.random.normal(size=100000)
Output & computation time comparison
Below is a comparison of different methods.
1: mpl-scatter-density
Installation
pip install mpl-scatter-density
Example code
import mpl_scatter_density # adds projection='scatter_density'
from matplotlib.colors import LinearSegmentedColormap
# Viridis-like colormap whose very lowest value (position 0) is white, while
# anything above 1e-20 already starts at the dark-purple end of viridis.
white_viridis = LinearSegmentedColormap.from_list(
    'white_viridis',
    list(zip(
        (0, 1e-20, 0.2, 0.4, 0.6, 0.8, 1),
        ('#ffffff', '#440053', '#404388', '#2a788e',
         '#21a784', '#78d151', '#fde624'),
    )),
    N=256,
)
def using_mpl_scatter_density(fig, x, y):
    """Add a full-figure 'scatter_density' axes to *fig* plotting (x, y), plus a colorbar."""
    density_ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
    image = density_ax.scatter_density(x, y, cmap=white_viridis)
    fig.colorbar(image, label='Number of points per pixel')


fig = plt.figure()
using_mpl_scatter_density(fig, x, y)
plt.show()
Drawing this took 0.05 seconds:
And the zoom-in looks quite nice:
2: datashader
Datashader is an interesting project. It has added support for matplotlib in datashader 0.12.
Installation
pip install datashader
Code (source & parameter listing for dsshow):
import datashader as ds
from datashader.mpl_ext import dsshow
import pandas as pd
def using_datashader(ax, x, y, vmin=0, vmax=35):
    """
    Draw a per-pixel point-count image of (x, y) on `ax` with datashader's
    dsshow, and add a colorbar for it.

    ax: Axes to draw into
    x, y: 1-D sequences of point coordinates
    vmin, vmax: colour limits for the point counts (previously hard-coded
        to 0 and 35, which remain the defaults)

    Returns the artist created by dsshow.
    """
    df = pd.DataFrame(dict(x=x, y=y))
    dsartist = dsshow(
        df,
        ds.Point("x", "y"),
        ds.count(),
        vmin=vmin,
        vmax=vmax,
        norm="linear",
        aspect="auto",
        ax=ax,
    )
    # Attach the colorbar to ax's own figure; plt.colorbar works on the
    # *current* figure, which need not be the one that owns `ax`.
    ax.figure.colorbar(dsartist, ax=ax)
    return dsartist


fig, ax = plt.subplots()
using_datashader(ax, x, y)
plt.show()
It took 0.83 s to draw this:
There is also possibility to colorize by third variable. The third parameter for dsshow controls the coloring. See more examples here and the source for dsshow here.
3: scatter_with_gaussian_kde
def scatter_with_gaussian_kde(ax, x, y):
    """
    Scatter (x, y) on `ax`, colouring each point by the Gaussian-KDE density
    estimated at its own location. Slow for large inputs (the KDE is
    evaluated at every point).

    Returns the PathCollection created by ax.scatter.
    """
    # https://stackoverflow.com/a/20107592/3015186
    # Answer by Joel Kington
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    # An empty string is not a valid colour in current matplotlib; 'none'
    # is the documented way to draw no marker edges.
    return ax.scatter(x, y, c=z, s=100, edgecolor='none')
It took 11 minutes to draw this:
4: using_hist2d
import matplotlib.pyplot as plt
def using_hist2d(ax, x, y, bins=(50, 50)):
    """Draw a 2-D histogram of (x, y) on *ax* with the 'jet' colormap."""
    # https://stackoverflow.com/a/20105673/3015186
    # Answer by askewchan
    ax.hist2d(x, y, bins=bins, cmap=plt.cm.jet)
It took 0.021 s to draw this bins=(50,50):
It took 0.173 s to draw this bins=(1000,1000):
Cons: The zoomed-in data does not look as good as with mpl-scatter-density or datashader. Also, you have to determine the number of bins yourself.
5: density_scatter
The code is as in the answer by Guillaume.
It took 0.073 s to draw this with bins=(50,50):
It took 0.368 s to draw this with bins=(1000,1000):
Also, if the number of points makes the KDE calculation too slow, the color can be interpolated from np.histogram2d [Update in response to comments: if you wish to show the colorbar, use plt.scatter() instead of ax.scatter(), followed by plt.colorbar()]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from scipy.interpolate import interpn
def density_scatter(x, y, ax=None, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by 2d histogram.

    Each point is coloured by the density-normalised 2-D histogram value at
    its location, interpolated between bin centres with a 2-D spline. This
    is much faster than a KDE for many points.

    x, y: 1-D sequences of coordinates
    ax: Axes to draw into; a new figure and axes are created when None
    sort: draw the densest points last so they end up on top
    bins: passed to np.histogram2d
    **kwargs: forwarded to ax.scatter (e.g. s, cmap)

    Returns the Axes.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        # `fig` used to be defined only when ax was None (NameError below).
        fig = ax.figure

    # Lists do not support the fancy indexing used for sorting.
    x = np.asarray(x)
    y = np.asarray(y)

    data, x_e, y_e = np.histogram2d(x, y, bins=bins, density=True)
    x_centres = 0.5 * (x_e[1:] + x_e[:-1])
    y_centres = 0.5 * (y_e[1:] + y_e[:-1])
    z = interpn((x_centres, y_centres), data, np.vstack([x, y]).T,
                method="splinef2d", bounds_error=False)

    # Points outside the bin-centre grid come back as NaN; still plot them.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    scatter = ax.scatter(x, y, c=z, **kwargs)
    # Colorbar built from the scatter itself so it uses the same cmap/norm,
    # including any cmap/vmin/vmax passed through **kwargs.
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.ax.set_ylabel('Density')

    return ax


if __name__ == "__main__":
    x = np.random.normal(size=100000)
    y = x * 3 + np.random.normal(size=100000)
    density_scatter(x, y, bins=[30, 30])
You could make a histogram:
import numpy as np
import matplotlib.pyplot as plt
# fake data: b is a noisy linear function of a
a = np.random.normal(size=1000)
b = 3 * a + np.random.normal(size=1000)
plt.hist2d(a, b, bins=(50, 50), cmap=plt.cm.jet)
plt.colorbar()
I am trying to draw a series of lines. The lines are all the same length, and randomly switch colors for a random length (blue to orange). I am drawing the lines in blue and then overlaying orange on top. You can see from my picture there are clipped parts of the lines where it is grey. I cannot figure out why this is happening. Also related I believe is that my labels are not moving to a left alignment like they should. Any help is greatly appreciated.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import random
plt.close('all')
fig, ax = plt.subplots(figsize=(15, 11))


def label(xy, text):
    """Write `text` left-aligned on the global `ax`, 2 data units below point xy."""
    x_pos, y_pos = xy[0], xy[1]
    ax.text(x_pos, y_pos - 2, text, ha="left", family='sans-serif', size=14)


def draw_chromosome(start, stop, y, color):
    """Add a 10-pt wide horizontal line from start to stop at height y to the global `ax`."""
    segment = mlines.Line2D(np.array([start, stop]), np.array([y, y]),
                            lw=10., color=color)
    ax.add_line(segment)
x = 50
y = 100
chrom_num = 1  # was `chr`, which shadowed the builtin chr()
for i in range(22):
    # blue background for this chromosome (note: spans x = 50 .. 120 only)
    draw_chromosome(x, 120, y, "#1C2F4D")
    j = 0
    while j < 120:
        print(j)  # was Python-2 `print j`, a SyntaxError on Python 3
        length = 1
        # ~10% of steps start an orange run of random length
        if random.randint(1, 100) > 90:
            length = random.randint(1, 120 - j)
            draw_chromosome(j, j + length, y, "#FA9B00")
        j = j + length + 1
    label([x, y], "Chromosome%i" % chrom_num)
    y -= 3
    chrom_num += 1

plt.axis('equal')
plt.axis('off')
plt.tight_layout()
plt.show()
You're only drawing the blue background from x = 50 to x = 120.
Replace this line:
draw_chromosome(x, 120, y, "#1C2F4D")
with this:
draw_chromosome(0, 120, y, "#1C2F4D")
To draw the blue line all the way across.
Alternately, if you also want to move your labels to the left, you can just set x=0 instead of setting it to 50.
I suggest using LineCollection for this. Below is a little helper function I wrote based on the example at http://matplotlib.org/examples/pylab_examples/multicolored_line.html (it looks long, but there are a lot of comments and docstrings).
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.ticker import NullLocator
from collections import OrderedDict
def binary_state_lines(ax, chrom_data, xmin=0, xmax=120,
                       delta_y=3,
                       off_color="#1C2F4D",
                       on_color="#FA9B00"):
    """
    Draw a whole bunch of chromosomes

    Each entry of `chrom_data` becomes one horizontal line. Its segments are
    coloured `on_color` inside the given windows and `off_color` elsewhere
    between xmin and xmax.

    Parameters
    ----------
    ax : Axes
        The axes to draw stuff to
    chrom_data : OrderedDict
        The chromosome data as a dict, keyed on the label, with a list of
        (start, stop) pairs of where the data is 'on'. Data is plotted
        top-down. An entry with no pairs is drawn entirely 'off'.
    xmin, xmax : float, optional
        The minimum and maximum limits for the x values
    delta_y : float, optional
        The spacing between lines
    off_color, on_color : color, optional
        The colors to use for the off/on state

    Returns
    -------
    collections : dict
        dictionary of the collections added keyed on the label
    """
    # base offset
    y_val = 0
    # state 0 -> off_color, state 1 -> on_color
    cmap = ListedColormap([off_color, on_color])
    norm = BoundaryNorm([0, 0.5, 1], cmap.N)
    # dictionary to hold the returned artists
    ret = dict()
    # loop over the input data draw each collection
    for label, data in chrom_data.items():
        # increment the y offset
        y_val += delta_y
        # flatten the (start, stop) pairs into a sequence of edges ...
        x = np.asarray(data).ravel()
        # ... where the segment starting at an even position is 'on'
        state = np.mod(1 + np.arange(len(x)), 2)
        # Pad with 'off' segments so the line spans [xmin, xmax]. An empty
        # entry (previously an IndexError on x[0]) becomes one 'off' segment.
        if x.size == 0 or x[0] > xmin:
            x = np.r_[xmin, x]
            state = np.r_[0, state]
        if x[-1] < xmax:
            x = np.r_[x, xmax]
            state = np.r_[state, 0]
        # make the matching y values
        y = np.full(len(x), y_val, dtype=float)
        # call helper function to create the collection
        ret[label] = draw_segments(ax, x, y, state, cmap, norm)

    # set up the axes limits
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(0, y_val + delta_y)
    # turn off x-ticks
    ax.xaxis.set_major_locator(NullLocator())
    # make the y-ticks be labeled as per the input
    ax.yaxis.set_ticks((1 + np.arange(len(chrom_data))) * delta_y)
    ax.yaxis.set_ticklabels(list(chrom_data.keys()))
    # invert so that the first data is at the top
    ax.invert_yaxis()
    # turn off the frame and patch
    ax.set_frame_on(False)
    # return the added artists
    return ret
def draw_segments(ax, x, y, state, cmap, norm, lw=10):
    """
    Helper: join consecutive (x, y) points into a LineCollection, colour the
    segments from `state`, add the collection to `ax` and return it.

    Parameters
    ----------
    ax : Axes
        The axes to draw to
    x, y, state : array
        The x edges, the y values and the state of each region
    cmap : matplotlib.colors.Colormap
        The color map to use
    norm : matplotlib.colors.Normalize
        The norm to use with the color map (e.g. a BoundaryNorm)
    lw : float, optional
        The width of the lines
    """
    pts = np.column_stack((x, y))
    segments = np.stack((pts[:-1], pts[1:]), axis=1)
    collection = LineCollection(segments, cmap=cmap, norm=norm)
    collection.set_array(state)
    collection.set_linewidth(lw)
    ax.add_collection(collection)
    return collection
An example:
# 21 entries of 10 random, increasing (start, stop) pairs each
synthetic_data = OrderedDict(
    ('data {:02d}'.format(j),
     np.cumsum(np.random.randint(1, 10, 20)).reshape(-1, 2))
    for j in range(21)
)
fig, ax = plt.subplots(tight_layout=True)
binary_state_lines(ax, synthetic_data, xmax=120)
plt.show()
Separating the plotting logic from everything else will make your code easier to maintain and more reusable.
I also took the liberty of moving your labels from between the lines (where they can be ambiguous) to the yaxis tick labels.
In the figure below, each unit on the x-axis represents a 10-minute interval. I would like to customize the labels of the x-axis so that it shows hours, i.e. it displays a tick every 6 units (60 minutes). I am new to matplotlib. Could someone help me? Thanks!
Here is the code for the above figure.
# NOTE(review): fragment from the question -- size_x, size_y, dx, dy and foo
# are not defined in this snippet, and the bare names (arange, meshgrid,
# pcolor, cm, colorbar, axis, show) presumably come from `from pylab import *`.
x = arange(0, size_x, dx)
y = arange(0, size_y, dy)
X,Y = meshgrid(x, y)
# NOTE(review): foo receives the 1-D x, y rather than the X, Y grids built
# above; Z must match the grid shape for pcolor -- confirm foo broadcasts.
Z = foo(x,y)
pcolor(X, Y, Z, cmap=cm.Reds)
colorbar()
# Limit the view to [0, size_x-1] x [0, size_y-1].
axis([0,size_x-1,0,size_y-1])
show()
There's more than one way to do this.
Let's start out with an example plot:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
# Example data: cos of the distance from the origin on a 141 x 101 grid
x, y = np.mgrid[0:141, 0:101]
z = np.cos(np.hypot(x, y))
# Draw it as a pseudocolour mesh
plt.pcolormesh(x, y, z, cmap=mpl.cm.Reds)
plt.show()
The simple way to do what you want would be something like this:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
# Example data: cos of the distance from the origin on a 141 x 101 grid
x, y = np.mgrid[0:141, 0:101]
z = np.cos(np.hypot(x, y))
plt.pcolormesh(x, y, z, cmap=mpl.cm.Reds)
# One tick every 6 x-units (6 x 10 min = 1 hour), labelled 0, 1, 2, ...
ticks = np.arange(x.min(), x.max(), 6)
labels = range(len(ticks))
plt.xticks(ticks, labels)
plt.xlabel('Hours')
plt.show()
The other way involves subclassing matplotlib's locators and tickers.
For your purposes, the example above is fine.
The advantage of making new locators and tickers is that the axis will automatically be scaled into reasonable intervals of the "dx" units you specify. If you're using it as a part of a larger application, it can be worthwhile. For a single plot, it's more trouble than it's worth.
If you really wanted to go that route, though, you'd do something like this:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
def main():
    """Show the demo pcolormesh with its x axis relabelled in hours (6 units each)."""
    x, y = np.mgrid[0:141, 0:101]
    z = np.cos(np.hypot(x, y))

    fig, ax = plt.subplots()
    ax.pcolormesh(x, y, z, cmap=mpl.cm.Reds)
    ax.set_xlabel('Hours')
    x_axis = ax.xaxis
    x_axis.set_major_locator(ScaledLocator(dx=6))
    x_axis.set_major_formatter(ScaledFormatter(dx=6))
    plt.show()
class ScaledLocator(mpl.ticker.MaxNLocator):
    """
    Locates regular intervals along an axis scaled by *dx* and shifted by
    *x0*. For example, this would locate minutes on an axis plotted in seconds
    if dx=60. This differs from MultipleLocator in that an appropriate interval
    of dx units will be chosen similar to the default MaxNLocator.

    Ticks are chosen by MaxNLocator in the rescaled coordinate
    ``x / dx + x0`` and then mapped back to data coordinates.
    """
    def __init__(self, dx=1.0, x0=0.0):
        self.dx = dx
        self.x0 = x0
        super().__init__(nbins=9, steps=[1, 2, 5, 10])

    def rescale(self, x):
        """Map data coordinates to scaled coordinates."""
        return x / self.dx + self.x0

    def inv_rescale(self, x):
        """Map scaled coordinates back to data coordinates (inverse of rescale)."""
        return (x - self.x0) * self.dx

    def __call__(self):
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

    def tick_values(self, vmin, vmax):
        """Tick locations (in data coordinates) for the data range [vmin, vmax]."""
        vmin, vmax = self.rescale(vmin), self.rescale(vmax)
        vmin, vmax = mpl.transforms.nonsingular(vmin, vmax, expander=0.05)
        # MaxNLocator.bin_boundaries no longer exists; the public tick_values
        # picks "nice" ticks and already applies `prune` and
        # raise_if_exceeds.
        locs = super().tick_values(vmin, vmax)
        return self.inv_rescale(np.asarray(locs))
class ScaledFormatter(mpl.ticker.Formatter):
    """
    Formats tick labels scaled by *dx* and shifted by *x0*.

    The label for data value x is ``x / dx + x0``. Whole numbers (below 1e4)
    are printed as integers. Otherwise the precision is chosen from the width
    of the visible, rescaled range. This formatting is modelled on the former
    ``OldScalarFormatter``, which is no longer part of matplotlib.
    """
    def __init__(self, dx=1.0, x0=0.0, **kwargs):
        # **kwargs is accepted for backward compatibility and ignored.
        self.dx, self.x0 = dx, x0

    def rescale(self, x):
        """Map data coordinates to scaled coordinates."""
        return x / self.dx + self.x0

    def __call__(self, x, pos=None):
        xmin, xmax = self.axis.get_view_interval()
        xmin, xmax = self.rescale(xmin), self.rescale(xmax)
        d = abs(xmax - xmin)
        x = self.rescale(x)
        return self.pprint_val(x, d)

    def pprint_val(self, x, d):
        """Format value `x` with a precision suited to a visible span `d`."""
        # int(x) would raise on inf/nan, so only test finite values.
        if np.isfinite(x) and abs(x) < 1e4 and x == int(x):
            return '%d' % x

        if d < 1e-2:
            fmt = '%1.3e'
        elif d < 1e-1:
            fmt = '%1.3f'
        elif d > 1e5:
            fmt = '%1.1e'
        elif d > 10:
            fmt = '%1.1f'
        elif d > 1:
            fmt = '%1.2f'
        else:
            fmt = '%1.3f'
        s = fmt % x

        mantissa, sep, exponent = s.partition('e')
        if sep:
            # compact exponent: drop trailing zeros, '+' sign, leading zeros
            sign = exponent[0].replace('+', '')
            digits = exponent[1:].lstrip('0')
            return '%se%s%s' % (mantissa.rstrip('0').rstrip('.'), sign, digits)
        return s.rstrip('0').rstrip('.')
# Only draw the demo figure when run as a script, not on import.
if __name__ == '__main__':
    main()