Use line.get_color in stackplot matplotlib - python

I have a chart with 30 categories (in the example there are only 6). Therefore, the use of the legend is not very convenient. It is possible to automatically add the category name, as done in the example with ax.text(color=line.get_color())?
I've tried using text inside the line, but sometimes the line is too narrow to fit the text that can be read.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.rcParams["figure.figsize"] = (25,15)
population_by_continent = {'Category1': [0, 12, 45, 83, 237, 1071, 1349, 1863, 2517, 2941, 4876, 7539, 10358, 7951], 'Category6': [2, 3, 43, 69, 129, 561, 887, 1434, 2006, 2306, 4238, 6426, 7232, 5695], 'Category2': [2, 100, 329, 553, 877, 1870, 2663, 3372, 4243, 5558, 10140, 16572, 17875, 11932], 'Category3': [0, 32, 114, 123, 218, 483, 643, 808, 1037, 1188, 1915, 3007, 3059, 1900], 'Category5': [1, 70, 188, 321, 370, 467, 574, 722, 814, 884, 1347, 1916, 1925, 1634], 'Category1': [1, 13, 31, 107, 155, 311, 432, 502, 551, 529, 732, 1141, 1505, 1924], 'Category4': [2, 104, 331, 622, 1094, 2246, 2529, 2825, 3825, 4521, 9352, 15842, 22365, 17646]}
Years = [2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022]
mycolors = ["#590600", "#6aed95", "#5d1a8c", "#c6df51", "#2088ff", "#92cd49", "#9a0583", "#8aea7a",]
fig, ax = plt.subplots()
ax.stackplot(Years, population_by_continent.values(),
labels=population_by_continent.keys(), alpha=0.8,
colors=mycolors,
baseline = 'zero'
)
ax.legend(loc='upper left', fontsize=18,frameon=False)
plt.show()

Related

How to create a histogram from counts with bins spaced every 0.1

I have the following dataframe:
df = {'count1': [2.2336, 2.2454, 2.2538, 2.2716999999999996, 2.2798000000000003, 2.2843, 2.2906, 2.2969, 2.3223000000000003, 2.3282, 2.3356999999999997, 2.3544, 2.3651999999999997, 2.3727, 2.3775, 2.3823000000000003, 2.392, 2.4051, 2.4092, 2.4133, 2.4168000000000003, 2.4175, 2.4209, 2.4392, 2.4476, 2.456, 2.461, 2.4723, 2.4776, 2.4882, 2.4989, 2.5095, 2.5221999999999998, 2.5318, 2.5422, 2.5494, 2.559, 2.5654, 2.5814, 2.5878, 2.6238, 2.6178000000000003, 2.624, 2.6303, 2.6366, 2.6425, 2.6481999999999997, 2.6525, 2.6553, 2.663, 2.6712, 2.6898, 2.7051, 2.7144, 2.727, 2.7416, 2.7472, 2.7512, 2.7557, 2.7574, 2.7594000000000003, 2.7636, 2.7699000000000003, 2.7761, 2.7809, 2.7855, 2.7902, 2.7948000000000004, 2.7995, 2.8043, 2.815, 2.8249, 2.8352, 2.8455, 2.8708, 2.8874, 2.9004000000000003, 2.9301, 2.9399, 2.9513000000000003, 2.9634, 2.9745999999999997, 2.9852, 2.9959000000000002, 3.0037, 3.0093, 3.015, 3.0184, 3.0206, 3.0225, 3.0245, 3.0264, 3.0282, 3.0305999999999997, 3.0331, 3.0334, 3.0361, 3.0388, 3.0418000000000003, 3.0443000000000002, 3.0463, 3.0464, 3.0481, 3.0496999999999996, 3.0514, 3.0530999999999997, 3.0544000000000002, 3.0556, 3.0569, 3.0581, 3.0623, 3.0627, 3.0633000000000004, 3.0638, 3.0643000000000002, 3.0648, 3.0652, 3.0656999999999996, 3.0663, 3.0675, 3.0682, 3.0688, 3.0695, 3.0702, 3.0721, 3.0741, 3.0761, 3.078, 3.08, 3.082, 3.0839000000000003, 3.0859, 3.0879000000000003, 3.0898000000000003, 3.0918, 3.0938000000000003, 3.0994, 3.1050999999999997, 3.1144000000000003, 3.1613, 3.1649000000000003, 3.1752, 3.1869, 3.1899, 3.1925, 3.1976, 3.2001, 3.2051999999999996, 3.2098, 3.2123000000000004],
'count2': [3144, 3944, 7888, 4428, 68874, 5480, 56697, 20560, 8744, 91190, 352, 924, 1308611, 480, 51146, 170373, 58792, 11424, 1288673, 1845105, 401464, 657930, 1361172, 199373, 19753, 39082, 776, 7533, 9289, 36731, 53865, 100140, 59274, 35740, 2648, 144998, 78616, 848241, 34579, 216591, 22512, 4024, 17168, 1552, 13760, 8344, 65589, 43104, 44672, 917115, 16256, 4168, 29679, 22571, 7720, 452, 8836, 6888, 18578, 5148, 9289, 442, 214, 485, 3164, 1101, 1010, 9048, 293, 1628, 960, 517, 2362, 1262, 1524, 1173, 1348, 1288, 25568, 8416, 5792, 4944, 504, 4696, 2336, 458, 453, 1220, 1149, 6688, 6956, 7324, 7100, 7784, 5650, 5076, 5336, 6792, 5212, 4592, 5260, 1279, 654, 842, 990, 782, 1412, 1363, 935, 996, 775, 1471, 1525, 1398, 1097, 1082, 1668, 1007, 497, 598, 645, 698, 541, 504, 549, 540, 1568, 514, 578, 2906, 4360, 3916, 11944, 1434, 1589, 732, 641, 477, 307, 1884, 3232, 2408, 1016, 332, 139, 344, 4784, 1784, 1324, 204]}
df = pd.DataFrame(df)
And I want to plot a barplot with it, where the x axis is count1 and the y axis count2, with bins spaced every 0.1 intervals.
I used this:
plt.bar(x=df['count1'], y=df['count2'], width=0.1)
But it returns me this error:
TypeError: bar() missing 1 required positional argument: 'height'
I'm trying to replicate an R code:
ggplot(df, aes(x= count1,
y= count2)) +
geom_col() +
ylim(0, 2000000) +
scale_x_binned()
That generates the following graph:
To get a histogram from values and counts, you can use the weights= parameter of plt.hist.
To create bins with a width of 0.1, you can use np.arange(...,..., 0.1).
The rwidth=0.9 parameter makes the bars a bit narrower.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
df = {'count1': [2.2336, 2.2454, 2.2538, 2.2716999999999996, 2.2798000000000003, 2.2843, 2.2906, 2.2969, 2.3223000000000003, 2.3282, 2.3356999999999997, 2.3544, 2.3651999999999997, 2.3727, 2.3775, 2.3823000000000003, 2.392, 2.4051, 2.4092, 2.4133, 2.4168000000000003, 2.4175, 2.4209, 2.4392, 2.4476, 2.456, 2.461, 2.4723, 2.4776, 2.4882, 2.4989, 2.5095, 2.5221999999999998, 2.5318, 2.5422, 2.5494, 2.559, 2.5654, 2.5814, 2.5878, 2.6238, 2.6178000000000003, 2.624, 2.6303, 2.6366, 2.6425, 2.6481999999999997, 2.6525, 2.6553, 2.663, 2.6712, 2.6898, 2.7051, 2.7144, 2.727, 2.7416, 2.7472, 2.7512, 2.7557, 2.7574, 2.7594000000000003, 2.7636, 2.7699000000000003, 2.7761, 2.7809, 2.7855, 2.7902, 2.7948000000000004, 2.7995, 2.8043, 2.815, 2.8249, 2.8352, 2.8455, 2.8708, 2.8874, 2.9004000000000003, 2.9301, 2.9399, 2.9513000000000003, 2.9634, 2.9745999999999997, 2.9852, 2.9959000000000002, 3.0037, 3.0093, 3.015, 3.0184, 3.0206, 3.0225, 3.0245, 3.0264, 3.0282, 3.0305999999999997, 3.0331, 3.0334, 3.0361, 3.0388, 3.0418000000000003, 3.0443000000000002, 3.0463, 3.0464, 3.0481, 3.0496999999999996, 3.0514, 3.0530999999999997, 3.0544000000000002, 3.0556, 3.0569, 3.0581, 3.0623, 3.0627, 3.0633000000000004, 3.0638, 3.0643000000000002, 3.0648, 3.0652, 3.0656999999999996, 3.0663, 3.0675, 3.0682, 3.0688, 3.0695, 3.0702, 3.0721, 3.0741, 3.0761, 3.078, 3.08, 3.082, 3.0839000000000003, 3.0859, 3.0879000000000003, 3.0898000000000003, 3.0918, 3.0938000000000003, 3.0994, 3.1050999999999997, 3.1144000000000003, 3.1613, 3.1649000000000003, 3.1752, 3.1869, 3.1899, 3.1925, 3.1976, 3.2001, 3.2051999999999996, 3.2098, 3.2123000000000004],
'count2': [3144, 3944, 7888, 4428, 68874, 5480, 56697, 20560, 8744, 91190, 352, 924, 1308611, 480, 51146, 170373, 58792, 11424, 1288673, 1845105, 401464, 657930, 1361172, 199373, 19753, 39082, 776, 7533, 9289, 36731, 53865, 100140, 59274, 35740, 2648, 144998, 78616, 848241, 34579, 216591, 22512, 4024, 17168, 1552, 13760, 8344, 65589, 43104, 44672, 917115, 16256, 4168, 29679, 22571, 7720, 452, 8836, 6888, 18578, 5148, 9289, 442, 214, 485, 3164, 1101, 1010, 9048, 293, 1628, 960, 517, 2362, 1262, 1524, 1173, 1348, 1288, 25568, 8416, 5792, 4944, 504, 4696, 2336, 458, 453, 1220, 1149, 6688, 6956, 7324, 7100, 7784, 5650, 5076, 5336, 6792, 5212, 4592, 5260, 1279, 654, 842, 990, 782, 1412, 1363, 935, 996, 775, 1471, 1525, 1398, 1097, 1082, 1668, 1007, 497, 598, 645, 698, 541, 504, 549, 540, 1568, 514, 578, 2906, 4360, 3916, 11944, 1434, 1589, 732, 641, 477, 307, 1884, 3232, 2408, 1016, 332, 139, 344, 4784, 1784, 1324, 204]}
df = pd.DataFrame(df)
bin_start = np.trunc(df['count1'].min() * 10) / 10
bin_end = df['count1'].max() + 0.1
plt.style.use('ggplot')
plt.hist(x=df['count1'], weights=df['count2'], bins=np.arange(bin_start, bin_end, 0.1), rwidth=0.9)
plt.gca().get_yaxis().get_major_formatter().set_scientific(False)
plt.xlabel('count1')
plt.ylabel('count2')
plt.tight_layout()
plt.show()

How to implement different sequences in shell sort in python?

Hi I have the following code for implementing Shell sort in Python. How can I implement the following sequences in Shell sort using the code below (Note this is not the list I want to sort) :
1, 4, 13, 40, 121, 364, 1093, 3280, 9841, 29524 (Knuth’s sequence)
1, 5, 17, 53, 149, 373, 1123, 3371, 10111, 30341
1, 10, 30, 60, 120, 360, 1080, 3240, 9720, 29160
interval = n // 2
while interval > 0:
for i in range(interval, n):
temp = array[i]
j = i
while j >= interval and array[j - interval] > temp:
array[j] = array[j - interval]
j -= interval
array[j] = temp
interval //= 2
You could modify the pseudo-code provided in the Wikipedia article for Shellsort to take in the gap sequence as a parameter:
from random import choices
from timeit import timeit
RAND_SEQUENCE_SIZE = 500
GAP_SEQUENCES = {
'CIURA_A102549': [701, 301, 132, 57, 23, 10, 4, 1],
'KNUTH_A003462': [29524, 9841, 3280, 1093, 364, 121, 40, 13, 4, 1],
'SPACED_OUT_PRIME_GAPS': [30341, 10111, 3371, 1123, 373, 149, 53, 17, 5, 1],
'SPACED_OUT_EVEN_GAPS': [29160, 9720, 3240, 1080, 360, 120, 60, 30, 10, 1],
}
def shell_sort(seq: list[int], gap_sequence: list[int]) -> None:
n = len(seq)
# Start with the largest gap and work down to a gap of 1. Similar to
# insertion sort but instead of 1, gap is being used in each step.
for gap in gap_sequence:
# Do a gapped insertion sort for every element in gaps.
# Each gap sort includes (0..gap-1) offset interleaved sorting.
for offset in range(gap):
for i in range(offset, n, gap):
# Save seq[i] in temp and make a hole at position i.
temp = seq[i]
# Shift earlier gap-sorted elements up until the correct
# location for seq[i] is found.
j = i
while j >= gap and seq[j - gap] > temp:
seq[j] = seq[j - gap]
j -= gap
# Put temp (the original seq[i]) in its correct location.
seq[j] = temp
def main() -> None:
seq = choices(population=range(1000), k=RAND_SEQUENCE_SIZE)
print(f'{seq = }')
print(f'{len(seq) = }')
for name, gap_sequence in GAP_SEQUENCES.items():
print(f'Shell sort using {name} gap sequence: {gap_sequence}')
print(f'Time taken to sort 100 times: {timeit(lambda: shell_sort(seq.copy(), gap_sequence), number=100)} seconds')
if __name__ == '__main__':
main()
Example Output:
seq = [331, 799, 153, 700, 373, 38, 203, 535, 894, 500, 922, 939, 507, 506, 89, 40, 442, 108, 112, 359, 280, 946, 395, 708, 140, 435, 588, 306, 202, 23, 6, 189, 570, 600, 857, 949, 606, 617, 556, 863, 521, 776, 436, 801, 501, 588, 927, 279, 210, 72, 460, 52, 340, 632, 385, 965, 730, 360, 88, 216, 991, 520, 74, 112, 770, 853, 483, 787, 229, 812, 259, 349, 967, 227, 957, 728, 780, 51, 604, 748, 3, 679, 33, 488, 130, 203, 493, 471, 397, 53, 49, 172, 7, 306, 613, 519, 575, 64, 168, 161, 376, 903, 338, 800, 58, 729, 421, 238, 967, 294, 967, 218, 456, 823, 649, 569, 144, 103, 970, 780, 859, 719, 15, 536, 263, 917, 0, 54, 370, 703, 911, 518, 78, 41, 106, 452, 355, 571, 249, 58, 274, 327, 500, 341, 743, 536, 432, 799, 597, 681, 301, 856, 219, 63, 653, 680, 891, 725, 537, 673, 815, 504, 720, 573, 60, 91, 909, 892, 964, 119, 793, 540, 303, 538, 130, 717, 755, 968, 46, 229, 837, 398, 182, 303, 99, 808, 56, 780, 415, 33, 511, 771, 875, 593, 120, 727, 505, 905, 619, 295, 958, 566, 8, 291, 811, 529, 789, 523, 545, 5, 631, 28, 107, 292, 831, 657, 952, 239, 814, 862, 912, 2, 147, 750, 132, 528, 408, 916, 718, 261, 488, 621, 261, 963, 880, 625, 151, 982, 819, 749, 224, 572, 690, 766, 278, 417, 248, 987, 664, 515, 691, 940, 860, 172, 898, 321, 381, 662, 293, 354, 642, 219, 133, 133, 854, 162, 254, 816, 630, 21, 577, 486, 792, 731, 714, 581, 633, 794, 120, 386, 874, 177, 652, 159, 264, 414, 417, 730, 728, 716, 973, 688, 106, 345, 153, 909, 382, 505, 721, 363, 230, 588, 765, 340, 142, 549, 558, 189, 547, 728, 974, 468, 182, 255, 637, 317, 40, 775, 696, 135, 985, 884, 131, 797, 84, 89, 962, 810, 520, 843, 24, 400, 717, 834, 170, 681, 333, 68, 159, 688, 422, 198, 621, 386, 391, 839, 283, 167, 655, 314, 820, 432, 412, 181, 440, 864, 828, 217, 491, 593, 298, 885, 831, 535, 92, 305, 510, 90, 949, 461, 627, 851, 606, 280, 413, 624, 916, 16, 517, 700, 776, 323, 161, 329, 25, 868, 258, 97, 219, 620, 69, 24, 794, 981, 361, 691, 20, 90, 825, 442, 531, 562, 240, 0, 440, 418, 338, 526, 34, 230, 381, 598, 734, 925, 209, 231, 980, 122, 374, 752, 144, 105, 920, 780, 828, 948, 515, 443, 810, 81, 303, 751, 779, 516, 394, 455, 116, 448, 652, 293, 327, 367, 793, 47, 946, 653, 927, 910, 583, 845, 442, 989, 393, 490, 564, 54, 656, 689, 626, 531, 941, 575, 628, 865, 705, 219, 42, 19, 10, 155, 436, 319, 510, 520, 869, 101, 918, 170, 826, 146, 389, 200, 992, 404, 982, 889, 818, 684, 524, 642, 991, 973, 561, 104, 418, 207, 963, 192, 410, 33]
len(seq) = 500
Shell sort using CIURA_A102549 gap sequence: [701, 301, 132, 57, 23, 10, 4, 1]
Time taken to sort 100 times: 0.06717020808719099 seconds
Shell sort using KNUTH_A003462 gap sequence: [29524, 9841, 3280, 1093, 364, 121, 40, 13, 4, 1]
Time taken to sort 100 times: 0.34870366705581546 seconds
Shell sort using SPACED_OUT_PRIME_GAPS gap sequence: [30341, 10111, 3371, 1123, 373, 149, 53, 17, 5, 1]
Time taken to sort 100 times: 0.3563524999190122 seconds
Shell sort using SPACED_OUT_EVEN_GAPS gap sequence: [29160, 9720, 3240, 1080, 360, 120, 60, 30, 10, 1]
Time taken to sort 100 times: 0.38147866702638566 seconds

Pyplot not plotting marker for detected peaks

I'm writing a Python script that plots a candlestick chart of with x markers indicating peak candlesticks. The used data is a series of USD/JPY rates read using pandas.read_csv() from a csv file provided by Oanda API. The result of pandas.DataFrame.head() is as follows:
time close open high low volume
0 2016/08/19 06:00:00 100.256 99.919 100.471 99.887 30965
1 2016/08/22 06:00:00 100.335 100.832 100.944 100.221 32920
2 2016/08/23 06:00:00 100.253 100.339 100.405 99.950 26069
3 2016/08/24 06:00:00 100.460 100.270 100.619 100.104 22340
4 2016/08/25 06:00:00 100.546 100.464 100.627 100.314 17224
While the candlestick chart itself is displayed properly (although it needs some foramtting), I don't see any markers on it.
What I expect is something like an example graph output shown on the scipy.signal.find_peaks document, only it is a candlestick chart instead of a line graph.
Here is my code:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import mpl_finance
df = pd.read_csv(sys.argv[1])
opens = df['open']
highs = df['high']
lows = df['low']
closes = df['close']
indices = find_peaks(highs)[0]
fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 1, 1)
mpl_finance.candlestick2_ohlc(ax1, opens, highs, lows, closes, width=4, colorup='k', colordown='r', alpha=0.75)
ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
ax1.grid()
plt.show()
I suspected it's either the x or y parameter of ax1.plot() being empty, which is shown otherwise using pdb debugger:
-> ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
(Pdb) indices
array([ 1, 10, 15, 18, 23, 25, 29, 34, 39, 47, 50, 59, 66,
70, 74, 76, 78, 81, 84, 87, 92, 95, 99, 101, 107, 113,
118, 126, 130, 138, 143, 145, 158, 161, 164, 170, 172, 176, 182,
186, 196, 203, 208, 215, 220, 222, 226, 230, 233, 237, 241, 246,
248, 256, 261, 263, 267, 282, 286, 290, 293, 296, 304, 306, 308,
310, 313, 316, 322, 331, 336, 342, 349, 352, 359, 367, 369, 373,
378, 382, 391, 395, 400, 403, 405, 411, 416, 422, 425, 428, 438,
441, 444, 447, 450, 454, 459, 466, 471, 473, 477, 485, 493, 497],
dtype=int32)
(Pdb) [highs[j] for j in indices]
[100.944, 104.33, 103.07, 103.367, 102.79799999999999, 101.258, 101.851, 104.17399999999999, 104.64299999999999, 104.882, 105.544, 106.95700000000001, 111.375, 113.911, 114.837, 114.78399999999999, 114.415, 116.134, 118.676, 118.251, 117.822, 118.624, 117.54299999999999, 116.89, 115.634, 115.38600000000001, 113.538, 114.962, 113.787, 114.765, 115.512, 115.2, 112.213, 111.48, 111.587, 109.23299999999999, 109.5, 111.79, 113.05799999999999, 114.39299999999999, 112.135, 111.721, 110.823, 111.8, 112.47399999999999, 112.935, 113.696, 114.505, 113.583, 112.429, 112.21600000000001, 110.99, 111.05799999999999, 110.95700000000001, 109.833, 109.85600000000001, 110.678, 112.72399999999999, 113.264, 113.20200000000001, 113.446, 112.834, 113.589, 114.10700000000001, 114.25, 114.462, 114.288, 114.742, 113.91799999999999, 111.70100000000001, 113.095, 113.758, 113.64399999999999, 113.398, 113.39299999999999, 111.49, 111.23200000000001, 109.77799999999999, 110.491, 109.79, 107.912, 107.685, 106.47, 107.06200000000001, 107.305, 106.65, 107.01799999999999, 107.499, 107.405, 107.788, 109.552, 110.044, 109.406, 110.02600000000001, 110.461, 111.40299999999999, 109.84899999999999, 110.275, 110.85799999999999, 110.91, 110.765, 111.14399999999999, 112.80799999999999, 113.18700000000001]
Could anyone give me a possible solution or an explanation of the cause?

How to add a legend to matplotlib scatter plot

I'm attempting to plot a PCA and one of the colours is label 1 and the other should be label 2. When I want to add a legend with ax1.legend() I only get the label for the blue dot or no label at all. How can I add the legend with the correct labels for both the blue and purple dots?
sns.set(style = 'darkgrid')
fig, ax1 = sns.plt.subplots()
x1, x2 = X_bar[:,0], X_bar[:,1]
ax1.scatter(x1, x2, 100, edgecolors='none', c = colors)
fig.set_figheight(8)
fig.set_figwidth(15)
It looks like you are plotting each point oscillating between two colours. As per the answer to this question subsampling every nth entry in a numpy array You can use numpys array slicing to plot two separate arrays, then do legend as normal.
For some sample data:
import numpy as np
import numpy.random as nprnd
import matplotlib.pyplot as plt
A = nprnd.randint(1000, size=100)
A.shape = (50,2)
x1, x2 = np.sort(A[:,0], axis=0), np.sort(A[:,1], axis=0)
x1
Out[50]:
array([ 46, 63, 84, 96, 118, 127, 137, 142, 181, 187, 187, 207, 210,
238, 238, 330, 334, 335, 346, 346, 350, 392, 400, 426, 467, 531,
550, 567, 569, 572, 583, 625, 637, 661, 671, 677, 698, 713, 777,
796, 837, 850, 866, 868, 874, 890, 919, 972, 992, 993])
x2
Out[51]:
array([ 2, 44, 49, 51, 72, 84, 86, 118, 120, 133, 150, 155, 156,
159, 199, 202, 250, 281, 289, 317, 317, 386, 405, 414, 427, 461,
507, 510, 543, 552, 553, 555, 559, 576, 618, 622, 633, 647, 665,
672, 682, 685, 745, 767, 776, 802, 808, 813, 847, 973])
labels=['blue','red']
fig, ax1 = plt.subplots()
ax1.scatter(x1[0::2], x2[0::2], 100, edgecolors='none', c='red', label = 'red')
ax1.scatter(x1[1::2], x2[1::2], 100, edgecolors='none', c='black', label = 'black')
plt.legend()
plt.show()
For your code, you can do:
sns.set(style = 'darkgrid')
fig, ax1 = sns.plt.subplots()
x1, x2 = X_bar[:,0], X_bar[:,1]
ax1.scatter(x1[0::2], x2[0::2], 100, edgecolors='none', c = colors[0], label='one')
ax1.scatter(x1[1::2], x2[1::2], 100, edgecolors='none', c = colors[1], label='two')
fig.set_figheight(8)
fig.set_figwidth(15)
plt.legend()

Matplotlib: colorbar breaks when using PySAL natural breaks

I'm making a choropleth map based on this tutorial.
But instead of splitting the data into equal intervals, like this:
bins = np.linspace(values.min(), values.max(), 7)
I'm using PySAL's Jenks natural breaks because my data is unevenly distributed:
from pysal.esda.mapclassify import Natural_Breaks as nb
# values is a pandas Series
breaks = nb( values, initial=150, k = 7)
This makes the map colors look good, but it messes up the legend:
So I tried assigning Jenks colors to the map, and equal intervals to the legend, but this happens:
The colorbar is assigned the right tick labels, but at the wrong position. So my question is: how can I get the colorbar to be equal intervals but the tick labels to be the Natural Breaks values in the right position?
Here's the pertinent code for the legend:
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pysal.esda.mapclassify import Natural_Breaks as nb
values = pd.Series([71664, 65456, 60378, 50128, 46618, 44028, 42642, 41237, 35300, 34891, 34848, 33089, 29964, 25193, 25088, 23879, 23458, 18149, 16537, 15576, 15235, 14741, 11981, 11963, 11616, 10280, 9723, 9720, 9709, 9659, 9649, 9631, 9369, 8345, 8211, 7809, 7758, 7119, 7034, 6979, 6455, 5861, 5580, 5498, 5469, 5448, 5317, 4749, 4498, 4254, 4152, 3876, 3861, 3836, 3813, 3786, 3655, 3582, 3475, 2922, 2870, 2866, 2849, 2634, 2598, 2185, 1950, 1924, 1886, 1879, 1794, 1756, 1702, 1700, 1637, 1632, 1524, 1505, 1453, 1415, 1396, 1345, 1327, 1306, 1250, 1125, 1084, 1079, 1025, 976, 920, 903, 877, 868, 842, 815, 803, 799, 799, 792, 762, 725, 718, 714, 710, 660, 654, 647, 617, 616, 611, 600, 588, 572, 572, 567, 547, 536, 522, 482, 463, 439, 434, 428, 419, 415, 412, 410, 395, 390, 389, 386, 375, 374, 370, 345, 338, 325, 324, 285, 276, 272, 250, 236, 229, 227, 226, 216, 213, 209, 203, 200, 186, 186, 182, 182, 175, 173, 170, 169, 164, 164, 159, 155, 153, 148, 147, 140, 131, 129, 127, 127, 126, 124, 119, 117, 115, 114, 111, 109, 105, 103, 101, 97, 90, 89, 89, 85, 84, 77, 76, 74, 72, 71, 70, 70, 69, 62, 61, 61, 60, 57, 54, 53, 53, 51, 50, 50, 48, 44, 43, 42, 35, 34, 30, 29, 26, 23, 20, 19, 16, 15, 15, 12, 11, 9, 8, 8, 5, 3, 1])
num_colors = 7
# Jenks natural breaks for colormap
breaks = nb( values, initial=150, k = num_colors - 1)
bins = breaks.bins
# Orange-Red colormap
cm = plt.get_cmap('OrRd')
scheme = cm(1.*np.arange(num_colors)/num_colors)
fig = plt.figure(figsize=(19, 7))
ax_legend = fig.add_axes([0.35, 0.15, 0.3, 0.03], zorder=3)
cmap = mpl.colors.ListedColormap(scheme)
# Round legend ticks to nearest 100
legend_bins = np.around(bins, decimals = -2)
# Split colormap into equal intervals
legend_colors = np.linspace(values.min(), values.max(), num_colors)
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=legend_bins,
boundaries=legend_colors,
orientation='horizontal' )
After much wrestling, I found the answer. It's all about setting the ticks and boundaries parameters to the same thing, i.e. the bins. Then set the ticks to legend_colors.
The relevant bit to make it work is:
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=bins,
boundaries=bins,
orientation='horizontal' )
cb.set_ticks(legend_colors[1:])

Categories

Resources