How to add a legend to matplotlib scatter plot - python

I'm attempting to plot a PCA and one of the colours is label 1 and the other should be label 2. When I want to add a legend with ax1.legend() I only get the label for the blue dot or no label at all. How can I add the legend with the correct labels for both the blue and purple dots?
sns.set(style = 'darkgrid')
fig, ax1 = sns.plt.subplots()
x1, x2 = X_bar[:,0], X_bar[:,1]
ax1.scatter(x1, x2, 100, edgecolors='none', c = colors)
fig.set_figheight(8)
fig.set_figwidth(15)

It looks like your points alternate between two colours. As described in the answer to the question "subsampling every nth entry in a numpy array", you can use NumPy array slicing to plot the two sets as separate scatter series, each with its own label, and then call legend as normal.
For some sample data:
import numpy as np
import numpy.random as nprnd
import matplotlib.pyplot as plt
A = nprnd.randint(1000, size=100)
A.shape = (50,2)
x1, x2 = np.sort(A[:,0], axis=0), np.sort(A[:,1], axis=0)
x1
Out[50]:
array([ 46, 63, 84, 96, 118, 127, 137, 142, 181, 187, 187, 207, 210,
238, 238, 330, 334, 335, 346, 346, 350, 392, 400, 426, 467, 531,
550, 567, 569, 572, 583, 625, 637, 661, 671, 677, 698, 713, 777,
796, 837, 850, 866, 868, 874, 890, 919, 972, 992, 993])
x2
Out[51]:
array([ 2, 44, 49, 51, 72, 84, 86, 118, 120, 133, 150, 155, 156,
159, 199, 202, 250, 281, 289, 317, 317, 386, 405, 414, 427, 461,
507, 510, 543, 552, 553, 555, 559, 576, 618, 622, 633, 647, 665,
672, 682, 685, 745, 767, 776, 802, 808, 813, 847, 973])
fig, ax1 = plt.subplots()
# Even-indexed points form one series and odd-indexed points the other,
# so each series gets its own legend entry
ax1.scatter(x1[0::2], x2[0::2], 100, edgecolors='none', c='red', label = 'red')
ax1.scatter(x1[1::2], x2[1::2], 100, edgecolors='none', c='black', label = 'black')
plt.legend()
plt.show()
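For your code, you can do the following (recent seaborn releases no longer expose sns.plt, so matplotlib.pyplot is imported directly):
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'darkgrid')
fig, ax1 = plt.subplots()
x1, x2 = X_bar[:,0], X_bar[:,1]
# assumes colors holds the two alternating colours at positions 0 and 1
ax1.scatter(x1[0::2], x2[0::2], 100, edgecolors='none', c = colors[0], label='one')
ax1.scatter(x1[1::2], x2[1::2], 100, edgecolors='none', c = colors[1], label='two')
fig.set_figheight(8)
fig.set_figwidth(15)
plt.legend()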
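If the two classes are not strictly alternating, the same idea can be applied with one boolean mask per class instead of slicing. This is only a minimal sketch: the random X_bar and the integer labels array are stand-ins invented for illustration, and the colour and label names are taken from the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
X_bar = rng.normal(size=(50, 2))        # stand-in for the PCA scores
labels = rng.integers(0, 2, size=50)    # hypothetical class of each point (0 or 1)

fig, ax1 = plt.subplots(figsize=(15, 8))
for value, colour, name in [(0, "blue", "label 1"), (1, "purple", "label 2")]:
    mask = labels == value              # pick out the points of one class
    # One scatter call per class, so each class gets its own legend entry
    ax1.scatter(X_bar[mask, 0], X_bar[mask, 1], s=100,
                edgecolors="none", c=colour, label=name)

legend = ax1.legend()
legend_texts = [text.get_text() for text in legend.get_texts()]
```

The legend is built from the label of each scatter call, which is why a single scatter call with an array of colours yields at most one entry.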

Related

Pyplot not plotting marker for detected peaks

I'm writing a Python script that plots a candlestick chart with x markers indicating peak candlesticks. The data is a series of USD/JPY rates read with pandas.read_csv() from a CSV file provided by the Oanda API. The result of pandas.DataFrame.head() is as follows:
                  time    close     open     high      low  volume
0  2016/08/19 06:00:00  100.256   99.919  100.471   99.887   30965
1  2016/08/22 06:00:00  100.335  100.832  100.944  100.221   32920
2  2016/08/23 06:00:00  100.253  100.339  100.405   99.950   26069
3  2016/08/24 06:00:00  100.460  100.270  100.619  100.104   22340
4  2016/08/25 06:00:00  100.546  100.464  100.627  100.314   17224
While the candlestick chart itself is displayed properly (although it needs some formatting), I don't see any markers on it.
What I expect is something like the example graph shown in the scipy.signal.find_peaks documentation, only on a candlestick chart instead of a line graph.
Here is my code:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import mpl_finance
df = pd.read_csv(sys.argv[1])
opens = df['open']
highs = df['high']
lows = df['low']
closes = df['close']
indices = find_peaks(highs)[0]
fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 1, 1)
mpl_finance.candlestick2_ohlc(ax1, opens, highs, lows, closes, width=4, colorup='k', colordown='r', alpha=0.75)
ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
ax1.grid()
plt.show()
I suspected that either the x or the y argument of ax1.plot() was empty, but the pdb debugger shows otherwise:
-> ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
(Pdb) indices
array([ 1, 10, 15, 18, 23, 25, 29, 34, 39, 47, 50, 59, 66,
70, 74, 76, 78, 81, 84, 87, 92, 95, 99, 101, 107, 113,
118, 126, 130, 138, 143, 145, 158, 161, 164, 170, 172, 176, 182,
186, 196, 203, 208, 215, 220, 222, 226, 230, 233, 237, 241, 246,
248, 256, 261, 263, 267, 282, 286, 290, 293, 296, 304, 306, 308,
310, 313, 316, 322, 331, 336, 342, 349, 352, 359, 367, 369, 373,
378, 382, 391, 395, 400, 403, 405, 411, 416, 422, 425, 428, 438,
441, 444, 447, 450, 454, 459, 466, 471, 473, 477, 485, 493, 497],
dtype=int32)
(Pdb) [highs[j] for j in indices]
[100.944, 104.33, 103.07, 103.367, 102.79799999999999, 101.258, 101.851, 104.17399999999999, 104.64299999999999, 104.882, 105.544, 106.95700000000001, 111.375, 113.911, 114.837, 114.78399999999999, 114.415, 116.134, 118.676, 118.251, 117.822, 118.624, 117.54299999999999, 116.89, 115.634, 115.38600000000001, 113.538, 114.962, 113.787, 114.765, 115.512, 115.2, 112.213, 111.48, 111.587, 109.23299999999999, 109.5, 111.79, 113.05799999999999, 114.39299999999999, 112.135, 111.721, 110.823, 111.8, 112.47399999999999, 112.935, 113.696, 114.505, 113.583, 112.429, 112.21600000000001, 110.99, 111.05799999999999, 110.95700000000001, 109.833, 109.85600000000001, 110.678, 112.72399999999999, 113.264, 113.20200000000001, 113.446, 112.834, 113.589, 114.10700000000001, 114.25, 114.462, 114.288, 114.742, 113.91799999999999, 111.70100000000001, 113.095, 113.758, 113.64399999999999, 113.398, 113.39299999999999, 111.49, 111.23200000000001, 109.77799999999999, 110.491, 109.79, 107.912, 107.685, 106.47, 107.06200000000001, 107.305, 106.65, 107.01799999999999, 107.499, 107.405, 107.788, 109.552, 110.044, 109.406, 110.02600000000001, 110.461, 111.40299999999999, 109.84899999999999, 110.275, 110.85799999999999, 110.91, 110.765, 111.14399999999999, 112.80799999999999, 113.18700000000001]
Could anyone give me a possible solution or an explanation of the cause?
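The thread includes no answer, so the following is only a hedged sketch of one likely cause. Axes.plot reads its data and format string from positional arguments; x=, y= and fmt= are not data keywords it understands, and depending on the matplotlib version such a call either draws nothing or raises a TypeError. The synthetic highs array below is an assumption standing in for df['high'], and a plain line stands in for the candlesticks.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks

# Synthetic stand-in for the 'high' column: three full sine periods
highs = 100 + np.sin(np.linspace(0, 6 * np.pi, 120))
indices = find_peaks(highs)[0]          # integer positions of the local maxima

fig, ax1 = plt.subplots(figsize=(12, 4))
ax1.plot(np.arange(len(highs)), highs, color="k", alpha=0.75)
# x, y and the "x" marker format are passed positionally
(markers,) = ax1.plot(indices, highs[indices], "x", label="peak highs")
ax1.grid()
ax1.legend()
```

With a pandas Series, highs.iloc[indices] would be the positional equivalent of the list comprehension in the question.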

Fill area of overlap between two normal distributions in seaborn / matplotlib

I want to fill the area overlapping between two normal distributions. I've got the x min and max, but I can't figure out how to set the y boundaries.
I've looked at the plt documentation and some examples. I think this related question and this one come close, but no luck. Here's what I have so far.
import numpy as np
import seaborn as sns
import scipy.stats as stats
import matplotlib.pyplot as plt
pepe_calories = np.array([361, 291, 263, 284, 311, 284, 282, 228, 328, 263, 354, 302, 293,
254, 297, 281, 307, 281, 262, 302, 244, 259, 273, 299, 278, 257,
296, 237, 276, 280, 291, 278, 251, 313, 314, 323, 333, 270, 317,
321, 307, 256, 301, 264, 221, 251, 307, 283, 300, 292, 344, 239,
288, 356, 224, 246, 196, 202, 314, 301, 336, 294, 237, 284, 311,
257, 255, 287, 243, 267, 253, 257, 320, 295, 295, 271, 322, 343,
313, 293, 298, 272, 267, 257, 334, 276, 337, 325, 261, 344, 298,
253, 302, 318, 289, 302, 291, 343, 310, 241])
modern_calories = np.array([310, 315, 303, 360, 339, 416, 278, 326, 316, 314, 333, 317, 357,
304, 363, 387, 279, 350, 367, 321, 366, 311, 308, 303, 299, 363,
335, 357, 392, 321, 361, 285, 321, 290, 392, 341, 331, 338, 326,
314, 327, 320, 293, 333, 297, 315, 365, 408, 352, 359, 312, 300,
263, 358, 345, 360, 336, 378, 315, 354, 318, 300, 372, 305, 336,
286, 296, 413, 383, 328, 418, 388, 416, 371, 313, 321, 321, 317,
402, 290, 328, 344, 330, 319, 309, 327, 351, 324, 278, 369, 416,
359, 381, 324, 306, 350, 385, 335, 395, 308])
ax = sns.distplot(pepe_calories, fit_kws={"color":"blue"}, kde=False,
fit=stats.norm, hist=None, label="Pepe's");
ax = sns.distplot(modern_calories, fit_kws={"color":"orange"}, kde=False,
fit=stats.norm, hist=None, label="Modern");
# Get the two lines from the axes to generate shading
l1 = ax.lines[0]
l2 = ax.lines[1]
# Get the xy data from the lines so that we can shade
x1 = l1.get_xydata()[:,0]
y1 = l1.get_xydata()[:,1]
x2 = l2.get_xydata()[:,0]
y2 = l2.get_xydata()[:,1]
x2min = np.min(x2)
x1max = np.max(x1)
ax.fill_between(x1,y1, where = ((x1 > x2min) & (x1 < x1max)), color="red", alpha=0.3)
#> <matplotlib.collections.PolyCollection at 0x1a200510b8>
plt.legend()
#> <matplotlib.legend.Legend at 0x1a1ff2e390>
plt.show()
Any ideas?
Created on 2018-12-01 by the reprexpy package
import reprexpy
print(reprexpy.SessionInfo())
#> Session info --------------------------------------------------------------------
#> Platform: Darwin-18.2.0-x86_64-i386-64bit (64-bit)
#> Python: 3.6
#> Date: 2018-12-01
#> Packages ------------------------------------------------------------------------
#> matplotlib==2.1.2
#> numpy==1.15.4
#> reprexpy==0.1.1
#> scipy==1.1.0
#> seaborn==0.9.0
While gathering the pdf data from get_xydata is clever, you are now at the mercy of matplotlib's rendering / segmentation algorithm. Having x1 and x2 span different ranges also makes comparing y1 and y2 difficult.
You can avoid these problems by fitting the normals yourself instead of letting sns.distplot do it. Then you have more control over the values you are looking for.
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
norm = stats.norm
pepe_calories = np.array([361, 291, 263, 284, 311, 284, 282, 228, 328, 263, 354, 302, 293,
254, 297, 281, 307, 281, 262, 302, 244, 259, 273, 299, 278, 257,
296, 237, 276, 280, 291, 278, 251, 313, 314, 323, 333, 270, 317,
321, 307, 256, 301, 264, 221, 251, 307, 283, 300, 292, 344, 239,
288, 356, 224, 246, 196, 202, 314, 301, 336, 294, 237, 284, 311,
257, 255, 287, 243, 267, 253, 257, 320, 295, 295, 271, 322, 343,
313, 293, 298, 272, 267, 257, 334, 276, 337, 325, 261, 344, 298,
253, 302, 318, 289, 302, 291, 343, 310, 241])
modern_calories = np.array([310, 315, 303, 360, 339, 416, 278, 326, 316, 314, 333, 317, 357,
304, 363, 387, 279, 350, 367, 321, 366, 311, 308, 303, 299, 363,
335, 357, 392, 321, 361, 285, 321, 290, 392, 341, 331, 338, 326,
314, 327, 320, 293, 333, 297, 315, 365, 408, 352, 359, 312, 300,
263, 358, 345, 360, 336, 378, 315, 354, 318, 300, 372, 305, 336,
286, 296, 413, 383, 328, 418, 388, 416, 371, 313, 321, 321, 317,
402, 290, 328, 344, 330, 319, 309, 327, 351, 324, 278, 369, 416,
359, 381, 324, 306, 350, 385, 335, 395, 308])
pepe_params = norm.fit(pepe_calories)
modern_params = norm.fit(modern_calories)
xmin = min(pepe_calories.min(), modern_calories.min())
xmax = max(pepe_calories.max(), modern_calories.max())
x = np.linspace(xmin, xmax, 100)
pepe_pdf = norm(*pepe_params).pdf(x)
modern_pdf = norm(*modern_params).pdf(x)
y = np.minimum(modern_pdf, pepe_pdf)
fig, ax = plt.subplots()
ax.plot(x, pepe_pdf, label="Pepe's", color='blue')
ax.plot(x, modern_pdf, label="Modern", color='orange')
ax.fill_between(x, y, color='red', alpha=0.3)
plt.legend()
plt.show()
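Because np.minimum of the two pdfs is exactly the curve bounding the shaded region, the size of the overlap can also be estimated by integrating it. The sketch below is illustrative only: the means and standard deviations are made-up values rather than the fitted parameters, and the pdf is written out by hand so that only NumPy is needed.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Density of a normal distribution with the given mean and std."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

x = np.linspace(100, 500, 2001)
pdf_a = normal_pdf(x, 285.0, 35.0)      # illustrative parameters, not fitted values
pdf_b = normal_pdf(x, 335.0, 35.0)

y = np.minimum(pdf_a, pdf_b)            # pointwise lower curve = overlap boundary
# Trapezoidal rule, written out explicitly
overlap_area = float(np.sum((y[1:] + y[:-1]) / 2 * np.diff(x)))
```

Since each pdf integrates to 1, overlap_area lies between 0 (no overlap) and 1 (identical distributions).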
If, let's say, sns.distplot (or some other plotting function) made a plot that you did not want to have to reproduce, then you could use the data from get_xydata this way:
import numpy as np
import seaborn as sns
import scipy.stats as stats
import matplotlib.pyplot as plt
pepe_calories = np.array([361, 291, 263, 284, 311, 284, 282, 228, 328, 263, 354, 302, 293,
254, 297, 281, 307, 281, 262, 302, 244, 259, 273, 299, 278, 257,
296, 237, 276, 280, 291, 278, 251, 313, 314, 323, 333, 270, 317,
321, 307, 256, 301, 264, 221, 251, 307, 283, 300, 292, 344, 239,
288, 356, 224, 246, 196, 202, 314, 301, 336, 294, 237, 284, 311,
257, 255, 287, 243, 267, 253, 257, 320, 295, 295, 271, 322, 343,
313, 293, 298, 272, 267, 257, 334, 276, 337, 325, 261, 344, 298,
253, 302, 318, 289, 302, 291, 343, 310, 241])
modern_calories = np.array([310, 315, 303, 360, 339, 416, 278, 326, 316, 314, 333, 317, 357,
304, 363, 387, 279, 350, 367, 321, 366, 311, 308, 303, 299, 363,
335, 357, 392, 321, 361, 285, 321, 290, 392, 341, 331, 338, 326,
314, 327, 320, 293, 333, 297, 315, 365, 408, 352, 359, 312, 300,
263, 358, 345, 360, 336, 378, 315, 354, 318, 300, 372, 305, 336,
286, 296, 413, 383, 328, 418, 388, 416, 371, 313, 321, 321, 317,
402, 290, 328, 344, 330, 319, 309, 327, 351, 324, 278, 369, 416,
359, 381, 324, 306, 350, 385, 335, 395, 308])
ax = sns.distplot(pepe_calories, fit_kws={"color":"blue"}, kde=False,
fit=stats.norm, hist=None, label="Pepe's");
ax = sns.distplot(modern_calories, fit_kws={"color":"orange"}, kde=False,
fit=stats.norm, hist=None, label="Modern");
# Get the two lines from the axes to generate shading
l1 = ax.lines[0]
l2 = ax.lines[1]
# Get the xy data from the lines so that we can shade
x1, y1 = l1.get_xydata().T
x2, y2 = l2.get_xydata().T
xmin = max(x1.min(), x2.min())
xmax = min(x1.max(), x2.max())
x = np.linspace(xmin, xmax, 100)
y1 = np.interp(x, x1, y1)
y2 = np.interp(x, x2, y2)
y = np.minimum(y1, y2)
ax.fill_between(x, y, color="red", alpha=0.3)
plt.legend()
plt.show()
When you want full control over the resulting plot, it is often simpler not to use seaborn at all. Just calculate the fits, plot them, and use fill_between on the two curves up to the point where they cross each other.
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
pepe_calories = np.array(...)
modern_calories = np.array(...)
x = np.linspace(150,470,1000)
y1 = stats.norm.pdf(x, *stats.norm.fit(pepe_calories))
y2 = stats.norm.pdf(x, *stats.norm.fit(modern_calories))
cross = x[y1-y2 <= 0][0]
fig, ax = plt.subplots()
ax.fill_between(x,y1,y2, where=(x<=cross), color="red", alpha=0.3)
ax.plot(x,y1, label="Pepe's")
ax.plot(x,y2, label="Modern")
ax.legend()
plt.show()

Replace pixel values in a Numpy image

I have a zeros image with dimension 720*1280 and I have a list of pixels' coordinates to change:
x = [623, 623, 583, 526, 571, 669, 686, 697, 600, 594, 606, 657, 657, 657, 617, 646, 611, 657, 674, 571, 693, 688, 698, 700, 686, 687, 687, 693, 690, 686, 694]
y = [231, 281, 270, 270, 202, 287, 366, 428, 422, 517, 608, 422, 518, 608, 208, 214, 208, 231, 653, 652, 436, 441, 457, 457, 453, 461, 467, 469, 475, 477, 467]
Here is the scatter plot:
yy= [720 -x for x in y]
plt.scatter(x, yy, s = 25, c = "r")
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(0, 1280)
plt.ylim(0, 720)
plt.show()
Here is the code to generate a binary image by setting the pixel values to 255:
image_zeros = np.zeros((720, 1280), dtype=np.uint8)
for i ,j in zip (x, y):
    image_zeros[i, j] = 255
plt.imshow(image_zeros, cmap='gray')
plt.show()
Here is the result. What is the problem?
As Goyo pointed out, the resolution of the image is the problem. The default figure size is 6.4 inches by 4.8 inches, and the default resolution is 100 dpi (at least for the current version of matplotlib). So the default image size is 640 x 480. The figure includes not only the imshow image, but also the tick marks, tick labels, the x and y axes and a white border. So there are even fewer than 640 x 480 pixels available for the imshow image by default.
Your image_zeros has shape (720, 1280). The array is too large to be fully rendered in an image of 640 x 480 pixels.
Thus, to generate white dots using imshow, set the figsize and dpi so that the number of pixels available for the imshow image is bigger than (1280, 720):
import numpy as np
import matplotlib.pyplot as plt
x = np.array([623, 623, 583, 526, 571, 669, 686, 697, 600, 594, 606, 657, 657, 657, 617, 646, 611, 657, 674, 571, 693, 688, 698, 700, 686, 687, 687, 693, 690, 686, 694])
y = np.array([231, 281, 270, 270, 202, 287, 366, 428, 422, 517, 608, 422, 518, 608, 208, 214, 208, 231, 653, 652, 436, 441, 457, 457, 453, 461, 467, 469, 475, 477, 467])
image_zeros = np.zeros((720, 1280), dtype=np.uint8)
image_zeros[y, x] = 255
fig, ax = plt.subplots(figsize=(26, 16), dpi=100)
ax.imshow(image_zeros, cmap='gray', origin='lower')
fig.savefig('/tmp/out.png')
Here is a closeup showing some of the white dots:
To make the white dots easier to see, you may wish to use scatter instead of imshow:
import numpy as np
import matplotlib.pyplot as plt
x = np.array([623, 623, 583, 526, 571, 669, 686, 697, 600, 594, 606, 657, 657, 657, 617, 646, 611, 657, 674, 571, 693, 688, 698, 700, 686, 687, 687, 693, 690, 686, 694])
y = np.array([231, 281, 270, 270, 202, 287, 366, 428, 422, 517, 608, 422, 518, 608, 208, 214, 208, 231, 653, 652, 436, 441, 457, 457, 453, 461, 467, 469, 475, 477, 467])
yy = 720 - y
fig, ax = plt.subplots()
ax.patch.set_facecolor('black')
ax.scatter(x, yy, s=25, c='white')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_xlim(0, 1280)
ax.set_ylim(0, 720)
fig.savefig('/tmp/out-scatter.png')
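Two details of the answer can be checked in isolation: the figure size in pixels is the size in inches multiplied by the dpi, and a NumPy image is indexed as [row, column], i.e. [y, x]. A minimal sketch, with the figure size set explicitly to the defaults quoted above and two made-up pixel coordinates:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Size and dpi set explicitly to the defaults mentioned in the answer
fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=100)
width_px, height_px = fig.get_size_inches() * fig.dpi   # inches * dots per inch

image = np.zeros((720, 1280), dtype=np.uint8)   # 720 rows (y), 1280 columns (x)
x = np.array([623, 1200])                       # illustrative column coordinates
y = np.array([231, 700])                        # illustrative row coordinates
image[y, x] = 255                               # row index first, then column
```

Swapping the indices to image[x, y] would set different pixels, and here it would raise an IndexError because 1200 exceeds the 720 rows.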

Imposing one graph onto another

I'm in a physics lab class and we have to write some code to analyze some data we have collected. My question is simple and probably stupid, but I was just wondering how to plot a graph on top of another graph using Python. Here is my code so far. Thanks.
%pylab
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
#SIGNAL DATA
dataSig = [658, 679, 683, 691, 693, 693, 695, 696, 696, 696, 697, 699, 699, 700, 700, 700, 702, 703, 703, 704, 706, 706, 708, 708, 709, 709, 712, 712, 713, 714, 714, 715, 715, 715, 716, 716, 716, 717, 717, 717, 718, 718, 718, 718, 719, 720, 720, 721, 721, 721, 722, 723, 723, 724, 725, 725, 725, 726, 726, 726, 727, 727, 728, 728, 729, 730, 730, 731, 731, 731, 731, 732, 732, 733, 734, 734, 734, 734, 735, 736, 737, 738, 738, 738, 738, 740, 740, 741, 741, 741, 742, 743, 743, 743, 743, 743, 743, 743, 744, 744, 745, 746, 746, 746, 746, 747, 747, 747, 747, 748, 749, 749, 750, 750, 750, 750, 751, 751, 751, 751, 752, 752, 752, 754, 754, 756, 756, 757, 757, 757, 759, 759, 760, 760, 760, 762, 762, 762, 762, 762, 762, 763, 764, 765, 765, 765, 765, 766, 766, 766, 767, 767, 768, 769, 769, 770, 770, 771, 773, 775, 776, 780, 786, 786, 786, 787, 790, 790, 793, 796, 797, 798, 817, 823]
#[658,679,683,691,693,695,696,697,699,700,702,703,704,706,708,709,712,713,714,715,716,717,718,719,720,721,722,723,724,725,726,727,728,729,730,731,732,733,734,735,736,737,738,740,741,742,743,744,745,746,747,748,749,750,751,752,754,756,757,759,760,762,763,764,765,766,767,768,769,770,771,773,775,776,780,786,787,790,793,796,797,798,817,823] #[1,1,1,1,1,1,3,1,2,3,1,2,1,2,2,1,2,1,2,3,2,3,3,1,2,3,1,2,1,3,3,2,2,1,1,4,2,1,4,1,1,1,4,2,3,1,7,2,1,4,4,1,2,4,4,3,2,2,2,2,3,6,1,1,4,3,2,1,2,2,1,1,1,1,1,3,1,2,1,1,1,1,1,1]
#SIGNAL DEFINED VARIABLES
ntestpoints = 175
themean = 739.1
#sigma = ?
#amp = center/guassian
#SIGNAL GAUSSIAN FITTING FUNCTION
def mygauss(x, amp, center, sigma):
    """This is an example gaussian function, which takes in x values, the amplitude (amp),
    the center x value (center) and the sigma of the Gaussian, and returns the respective y values."""
    y = amp * np.exp(-.5*((x-center)/sigma)**2)
    return y
#SIGNAL PLOT, NO GAUSS
plt.figure(figsize=(10,6))
plt.hist(dataSig,bins=ntestpoints/10,histtype="stepfilled",alpha=.5,color='g',range=[600,900])
plt.xlabel('Number of Counts/Second',fontsize=20)
plt.ylabel('Number of Measurements',fontsize=20)
plt.title('Measured Signal Count Rate Fitting with Gaussian Function',fontsize=22)
plt.axvline(themean,linestyle='-',color='r')
#plt.axvline(themean+error_on_mean,linestyle='--',color='b')
#plt.axvline(themean-error_on_mean,linestyle='--',color='b')
#plt.axvline(testmean,color='k',linestyle='-')
plt.show()
#------------------------------------------------------------
# define a function to make a gaussian with input values, used later
def mygauss(x, amp, center, sigma):
    """This is an example gaussian function, which takes in x values, the amplitude (amp),
    the center x value (center) and the sigma of the Gaussian, and returns the respective y values."""
    y = amp * np.exp(-.5*((x-center)/sigma)**2)
    return y
npts = 40 # the number of points on the x axis
x = np.linspace(600,900,npts) # make a series of npts linearly spaced values between 0 and 10
amp = 40
center = 740.5
sigma = 40
y = mygauss(x, amp, center, sigma)
print y
plt.figure(figsize=(10,6))
plt.plot(x,y,'bo', label='data points')
plt.text(center, amp, "<-- peak is here",fontsize=16) # places text at any x/y location on the graph
plt.xlabel('X axis',fontsize=20)
plt.ylabel('Y axis', fontsize=20)
plt.title('A gaussian plot \n with some extras!',fontsize=20)
plt.legend(loc='best')
plt.show()
When you call plt.figure() you create a new plot area in a different figure.
If you don't call it a second time, the second plot is drawn on the same graph as the first one.
That is not always a solution by itself, however: if the two datasets have very different scales, one graph can be massively outscaled by the other.
Fortunately that's not the case here, so I won't go into the details of using two different scales in one graph, but you can read about it here (http://matplotlib.org/examples/api/two_scales.html).
Commenting out the second plt.figure() in your code puts both plots in a single figure. If you run the code as a script rather than interactively, also drop the first plt.show() so that both plots are drawn before the figure is displayed.
Hope it helps!
PS: next time try posting with a matplotlib tag. It will get a faster response than physics, since this is basically a matplotlib question.
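As a sketch of that advice, the histogram and the Gaussian curve can be drawn on one Axes object created once. The random data is a synthetic stand-in for dataSig, and the amp, center and sigma values are the ones from the question's second plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def mygauss(x, amp, center, sigma):
    """Gaussian with peak height amp, centred on center, with width sigma."""
    return amp * np.exp(-0.5 * ((x - center) / sigma) ** 2)

rng = np.random.default_rng(1)
data = rng.normal(740, 30, size=175)          # synthetic stand-in for dataSig

fig, ax = plt.subplots(figsize=(10, 6))       # the only figure that is created
ax.hist(data, bins=17, range=(600, 900), histtype="stepfilled",
        alpha=0.5, color="g", label="measured counts")
x = np.linspace(600, 900, 200)
ax.plot(x, mygauss(x, 40, 740.5, 40), "b-", label="Gaussian")   # same axes as the histogram
ax.set_xlabel("Number of Counts/Second")
ax.set_ylabel("Number of Measurements")
ax.legend(loc="best")
```

Calling the methods on ax rather than on plt makes it explicit which axes each plot lands on.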

Matplotlib: colorbar breaks when using PySAL natural breaks

I'm making a choropleth map based on this tutorial.
But instead of splitting the data into equal intervals, like this:
bins = np.linspace(values.min(), values.max(), 7)
I'm using PySAL's Jenks natural breaks because my data is unevenly distributed:
from pysal.esda.mapclassify import Natural_Breaks as nb
# values is a pandas Series
breaks = nb( values, initial=150, k = 7)
This makes the map colors look good, but it messes up the legend:
So I tried assigning Jenks colors to the map, and equal intervals to the legend, but this happens:
The colorbar is assigned the right tick labels, but at the wrong position. So my question is: how can I get the colorbar to be equal intervals but the tick labels to be the Natural Breaks values in the right position?
Here's the pertinent code for the legend:
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pysal.esda.mapclassify import Natural_Breaks as nb
values = pd.Series([71664, 65456, 60378, 50128, 46618, 44028, 42642, 41237, 35300, 34891, 34848, 33089, 29964, 25193, 25088, 23879, 23458, 18149, 16537, 15576, 15235, 14741, 11981, 11963, 11616, 10280, 9723, 9720, 9709, 9659, 9649, 9631, 9369, 8345, 8211, 7809, 7758, 7119, 7034, 6979, 6455, 5861, 5580, 5498, 5469, 5448, 5317, 4749, 4498, 4254, 4152, 3876, 3861, 3836, 3813, 3786, 3655, 3582, 3475, 2922, 2870, 2866, 2849, 2634, 2598, 2185, 1950, 1924, 1886, 1879, 1794, 1756, 1702, 1700, 1637, 1632, 1524, 1505, 1453, 1415, 1396, 1345, 1327, 1306, 1250, 1125, 1084, 1079, 1025, 976, 920, 903, 877, 868, 842, 815, 803, 799, 799, 792, 762, 725, 718, 714, 710, 660, 654, 647, 617, 616, 611, 600, 588, 572, 572, 567, 547, 536, 522, 482, 463, 439, 434, 428, 419, 415, 412, 410, 395, 390, 389, 386, 375, 374, 370, 345, 338, 325, 324, 285, 276, 272, 250, 236, 229, 227, 226, 216, 213, 209, 203, 200, 186, 186, 182, 182, 175, 173, 170, 169, 164, 164, 159, 155, 153, 148, 147, 140, 131, 129, 127, 127, 126, 124, 119, 117, 115, 114, 111, 109, 105, 103, 101, 97, 90, 89, 89, 85, 84, 77, 76, 74, 72, 71, 70, 70, 69, 62, 61, 61, 60, 57, 54, 53, 53, 51, 50, 50, 48, 44, 43, 42, 35, 34, 30, 29, 26, 23, 20, 19, 16, 15, 15, 12, 11, 9, 8, 8, 5, 3, 1])
num_colors = 7
# Jenks natural breaks for colormap
breaks = nb( values, initial=150, k = num_colors - 1)
bins = breaks.bins
# Orange-Red colormap
cm = plt.get_cmap('OrRd')
scheme = cm(1.*np.arange(num_colors)/num_colors)
fig = plt.figure(figsize=(19, 7))
ax_legend = fig.add_axes([0.35, 0.15, 0.3, 0.03], zorder=3)
cmap = mpl.colors.ListedColormap(scheme)
# Round legend ticks to nearest 100
legend_bins = np.around(bins, decimals = -2)
# Split colormap into equal intervals
legend_colors = np.linspace(values.min(), values.max(), num_colors)
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=legend_bins,
boundaries=legend_colors,
orientation='horizontal' )
After much wrestling, I found the answer. Create the colorbar with both the ticks and the boundaries parameters set to the same thing, i.e. the bins, and then reset the tick positions to legend_colors.
The relevant bit to make it work is:
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=bins,
boundaries=bins,
orientation='horizontal' )
cb.set_ticks(legend_colors[1:])
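An alternative worth trying, which is not the author's method, is to let a BoundaryNorm carry the class edges and ask the colorbar for uniform spacing, so that every class gets an equally wide colour block while the ticks show the actual break values. In this sketch the bounds array is a set of made-up edges standing in for the Natural Breaks bins.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Illustrative class edges, standing in for the Natural Breaks bins
bounds = np.array([0, 1000, 4000, 10000, 25000, 50000, 72000])
cmap = plt.get_cmap("OrRd", len(bounds) - 1)    # one colour per class
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)  # maps a value to its class index

fig = plt.figure(figsize=(19, 7))
ax_legend = fig.add_axes([0.35, 0.15, 0.3, 0.03])
cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                  cax=ax_legend, orientation="horizontal",
                  spacing="uniform",            # equal width for every class
                  ticks=bounds)                 # labels at the unequal break values

class_index = np.asarray(norm(np.array([500, 5000, 60000]))).tolist()
```

The same norm and cmap can be used to colour the map polygons, which keeps the map and the legend consistent.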
