How to plot 3D hist in python - python

I have a dataset of accidents, each recorded with a region and a year.
> Accident_ID Region Year
> 213 1 2003
> 234 2 2001
> 334 2 2004
> ..
years= [0.0, 1661.0, 1665.0, 1706.0, 1729.0, 1765.0, 1779.0, 1780.0, 1785.0, 1798.0, 1799.0, 1800.0, 1801.0, 1802.0, 1804.0, 1805.0, 1812.0, 1814.0, 1816.0, 1821.0, 1822.0, 1824.0, 1825.0, 1826.0, 1827.0, 1829.0, 1830.0, 1831.0, 1832.0, 1833.0, 1834.0, 1835.0, 1836.0, 1837.0, 1838.0, 1839.0, 1840.0, 1841.0, 1842.0, 1843.0, 1844.0, 1845.0, 1846.0, 1847.0, 1848.0, 1849.0, 1850.0, 1851.0, 1852.0, 1853.0, 1854.0, 1855.0, 1856.0, 1857.0, 1858.0, 1859.0, 1860.0, 1861.0, 1862.0, 1863.0, 1864.0, 1865.0, 1866.0, 1867.0, 1868.0, 1869.0, 1870.0, 1871.0, 1872.0, 1873.0, 1874.0, 1875.0, 1876.0, 1877.0, 1878.0, 1879.0, 1880.0, 1881.0, 1882.0, 1883.0, 1884.0, 1885.0, 1886.0, 1887.0, 1888.0, 1889.0, 1890.0, 1891.0, 1892.0, 1893.0, 1894.0, 1895.0, 1896.0, 1897.0, 1898.0, 1899.0, 1900.0, 1901.0, 1902.0, 1903.0, 1904.0, 1905.0, 1906.0, 1907.0, 1908.0, 1909.0, 1910.0, 1911.0, 1912.0, 1913.0, 1914.0, 1915.0, 1916.0, 1917.0, 1918.0, 1919.0, 1920.0, 1921.0, 1922.0, 1923.0, 1924.0, 1925.0, 1926.0, 1927.0, 1928.0, 1929.0, 1930.0, 1931.0, 1932.0, 1933.0, 1934.0, 1935.0, 1936.0, 1937.0, 1938.0, 1939.0, 1940.0, 1941.0, 1942.0, 1943.0, 1944.0, 1945.0, 1946.0, 1947.0, 1948.0, 1949.0, 1950.0, 1951.0, 1952.0, 1953.0, 1954.0, 1955.0, 1956.0, 1957.0, 1958.0, 1959.0, 1960.0, 1961.0, 1962.0, 1963.0, 1964.0, 1965.0, 1966.0, 1967.0, 1968.0, 1969.0, 1970.0, 1971.0, 1972.0, 1973.0, 1974.0, 1975.0, 1976.0, 1977.0, 1978.0, 1979.0, 1980.0, 1981.0, 1982.0, 1983.0, 1984.0, 1985.0, 1986.0, 1987.0, 1988.0, 1989.0, 1990.0, 1991.0, 1992.0, 1993.0, 1994.0, 1995.0, 1996.0, 1997.0, 1998.0, 1999.0, 2000.0, 2001.0, 2002.0, 2003.0, 2004.0, 2005.0, 2006.0, 2007.0, 2008.0, 2009.0, 2010.0, 2011.0, 2012.0, 2013.0]
Frequency_accidents_years= [44815, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 556, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 2, 5, 3, 20, 11, 6, 5, 1, 7, 6, 6, 2, 4, 6, 19, 9, 11, 10, 18, 18, 8, 9, 13, 20, 43, 7, 13, 11, 6, 13, 12, 6, 7, 9, 34, 3, 2, 3, 1, 7, 6, 4, 8, 11, 56, 18, 5, 4, 4, 16, 2, 1, 3, 3, 146, 49, 10, 7, 10, 22, 18, 14, 18, 17, 397, 46, 12, 14, 12, 53, 39, 18, 28, 25095, 9663, 26717, 131, 180, 268, 7660, 754, 641, 354, 873, 47024, 705, 720, 578, 598, 16547, 653, 516, 255, 296, 92079, 1161, 1175, 1634, 2111, 71121, 3158, 3289, 4355, 2136, 77654, 33007, 1253, 983, 365, 25554, 651, 665, 762, 968, 38485, 745, 326, 199, 176, 25048, 343, 368, 604, 753, 46674, 775, 683, 562, 645, 26992, 768, 959, 816, 922, 37271, 796, 915, 1101, 945, 19687, 618, 614, 620, 509, 17169, 497, 623, 853, 854, 9755, 662, 725, 999, 593, 5469, 554, 778, 1163, 1342, 3470, 3755, 3810, 3597, 3613, 3504, 2263, 3173, 2465, 2135, 2558, 3476, 3164, 2755, 3715, 4187, 4540, 4203, 4445, 6541, 5994, 4873, 4085, 2899, 1806, 1157, 1331, 1246, 424]
regions = xrange(1,100)  # Can be generated this way, e.g. region1, region2, ...
I want to plot this data as a 3D histogram to better analyse the dataset.
I want to plot the frequency of accidents by region and year.
from collections import Counter
data.pandas("file.csv")
..
..
#Make 3D Hist
# Create a figure with a single set of 3D axes.
fig = plt.figure()
ax = plt.axes(projection='3d')
# NOTE(review): ax.plot draws a marker/line series, not histogram bars, and it
# pairs xs, ys and zs element by element, so the three sequences presumably
# need the same length. `regions` (xrange(1,100)) has 99 elements — verify
# that against the lengths of `years` and `Frequency_accidents_years`.
ax.plot(xs=years, ys=Frequency_accidents_years, zs=regions, marker='o', linestyle='--', color='r',label="name")
plt.show()
I ended up having this problem.
I got this error:
line 49 ..
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/pyplot.py", line 561, in savefig
return fig.savefig(*args, **kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/figure.py", line 1421, in savefig
self.canvas.print_figure(*args, **kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/backend_bases.py", line 2220, in print_figure
**kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/backend_bases.py", line 1962, in print_png
return agg.print_png(*args, **kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/backends/backend_agg.py", line 505, in print_png
FigureCanvasAgg.draw(self)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/backends/backend_agg.py", line 451, in draw
self.figure.draw(self.renderer)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/figure.py", line 1034, in draw
func(*args)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 270, in draw
Axes.draw(self, renderer)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/matplotlib/axes.py", line 2086, in draw
a.draw(renderer)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/art3d.py", line 117, in draw
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/proj3d.py", line 194, in proj_transform
return proj_transform_vec(vec, M)
File "/Users/macbook/anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/proj3d.py", line 153, in proj_transform_vec
vecw = np.dot(M, vec)
ValueError: operands could not be broadcast together with shapes (210) (858312)

I've tried to come up with an example of how to plot a Matplotlib bar3d 3D histogram given your dataset structure, using pandas' groupby as pointed out by Davidmh.
I used the following dataset, based on what you posted above (accidents.csv):
Accident_ID,Region,Year
213,1,2003
214,1,2003
214,2,2008
213,2,2007
210,2,2007
210,3,2004
210,1,2004
213,1,2004
210,1,2004
This script reads the dataset, groups the data and builds a 3d histogram plot:
import matplotlib
import matplotlib.pyplot as plt
from pandas import read_csv
from mpl_toolkits.mplot3d import Axes3D
# Read CSV dataset file
df = read_csv('accidents.csv')

# Group by year and region. The keys are sorted so that bar positions, bar
# heights and ticks are all built in the same deterministic order.
group_year_region = df.groupby(['Year', 'Region'])
group_keys = sorted(group_year_region.groups)

# Anchor each bar half a unit before its (year, region) so the unit-wide bar
# is centred on the tick. Real lists are built here (instead of map objects)
# because the positions are measured with len() and handed to bar3d; on
# Python 3 map() returns a lazy iterator that supports neither reliably.
xpos = [year - 0.5 for year, _ in group_keys]
ypos = [region - 0.5 for _, region in group_keys]
zpos = [0] * len(group_keys)

# Count number of accidents by (year, region) group
acc_by_year_region = group_year_region.count()['Accident_ID']
dx = 1
dy = 1
dz = [acc_by_year_region[key] for key in group_keys]

# Plot bar3d histogram
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b', zsort='average')

# Years should not be presented in exponential notation
x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
ax.xaxis.set_major_formatter(x_formatter)

# Set ticks (each distinct value once, in ascending order), labels and show plot
ax.set_xticks(sorted({year for year, _ in group_keys}))
ax.set_yticks(sorted({region for _, region in group_keys}))
ax.set_zticks(sorted(set(dz)))
ax.set_xlabel('Years')
ax.set_ylabel('Regions')
ax.set_zlabel('# Accidents')
plt.show()
The plot looks like this:
EDIT
There was a problem with the way the axes were built and the values were unordered. It is fixed now and the plot image has been updated.

The problem is that the shape of your data does not add up. You have three vectors, and they have to have the same length.
Running your example, years and Frequency_... both have 210 elements, but regions has only 99. I think what you want to do is to filter the accidents by region, for which you can use pandas' groupby.

Related

Use line.get_color in stackplot matplotlib

I have a chart with 30 categories (in the example there are only 6), so using a legend is not very convenient. Is it possible to automatically add the category name, as done in the example with ax.text(color=line.get_color())?
I've tried placing text inside the line, but sometimes the line is too narrow to fit text that can be read.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Default figure size for every figure created below.
plt.rcParams["figure.figsize"] = (25,15)
# Series name -> one value per entry of `Years`.
# NOTE(review): 'Category1' appears twice in this literal. Python keeps only
# the last value given for a repeated key, so the first 'Category1' series is
# silently discarded and the dict ends up with 6 entries — confirm which
# category name was intended for the second occurrence.
population_by_continent = {'Category1': [0, 12, 45, 83, 237, 1071, 1349, 1863, 2517, 2941, 4876, 7539, 10358, 7951], 'Category6': [2, 3, 43, 69, 129, 561, 887, 1434, 2006, 2306, 4238, 6426, 7232, 5695], 'Category2': [2, 100, 329, 553, 877, 1870, 2663, 3372, 4243, 5558, 10140, 16572, 17875, 11932], 'Category3': [0, 32, 114, 123, 218, 483, 643, 808, 1037, 1188, 1915, 3007, 3059, 1900], 'Category5': [1, 70, 188, 321, 370, 467, 574, 722, 814, 884, 1347, 1916, 1925, 1634], 'Category1': [1, 13, 31, 107, 155, 311, 432, 502, 551, 529, 732, 1141, 1505, 1924], 'Category4': [2, 104, 331, 622, 1094, 2246, 2529, 2825, 3825, 4521, 9352, 15842, 22365, 17646]}
Years = [2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022]
# Eight colours are supplied for the stacked series.
mycolors = ["#590600", "#6aed95", "#5d1a8c", "#c6df51", "#2088ff", "#92cd49", "#9a0583", "#8aea7a",]
fig, ax = plt.subplots()
# One stacked area per dict entry, labelled with the entry's key.
ax.stackplot(Years, population_by_continent.values(),
labels=population_by_continent.keys(), alpha=0.8,
colors=mycolors,
baseline = 'zero'
)
ax.legend(loc='upper left', fontsize=18,frameon=False)
plt.show()

How can I correct my Collatz Conjecture code?

I want to take the last number, multiply it by the multiplier, add the increment, and put that number back into the list. I do not know how to put s into the list. As you can see, the output ends with "...8, 4, 2, 1], 4)"; I want to put the 4 into the list.
def sfcollatz(n, divs=(2,), mult=3, inc=1, maxSize=-1):
    """Build a generalized Collatz sequence starting at ``n``.

    At each step the current value is divided by the first entry of ``divs``
    that divides it evenly; when none does, it becomes ``n*mult + inc``.
    Generation stops as soon as the next value is already in the sequence, or
    once ``maxSize`` values have been collected (the default ``-1`` never
    equals a length, so it means "no limit").

    Args:
        n: Starting value.
        divs: Candidate divisors, tried in order (any iterable of non-zero ints).
        mult: Multiplier applied when no divisor fits.
        inc: Increment added after multiplying.
        maxSize: Maximum number of values to collect, or -1 for no limit.

    Returns:
        A tuple ``(sequence, s)``. ``sequence`` is the list of collected
        values, followed by ``'...'`` when generation was cut off by
        ``maxSize`` before a value repeated. ``s`` is ``mult*last + inc`` for
        the last collected value, or ``None`` if nothing was collected.
    """
    result = []
    while n not in result and len(result) != maxSize:
        result.append(n)
        d = next((d for d in divs if n % d == 0), None)
        n = (n*mult + inc) if d is None else n // d
    # s depends only on the final element, so it is computed once after the
    # loop; the guard covers maxSize=0, where the loop body never runs.
    s = mult*result[-1] + inc if result else None
    return result + ['...']*(n not in result), s
print(sfcollatz(27,[2],3,1,maxSize=200))
([27, 82, 41, 124, 62, 31, 94, 47, 142, 71, 214, 107, 322, 161, 484, 242, 121, 364, 182, 91, 274, 137, 412, 206, 103, 310, 155, 466, 233, 700, 350, 175, 526, 263, 790, 395, 1186, 593, 1780, 890, 445, 1336, 668, 334, 167, 502, 251, 754, 377, 1132, 566, 283, 850, 425, 1276, 638, 319, 958, 479, 1438, 719, 2158, 1079, 3238, 1619, 4858, 2429, 7288, 3644, 1822, 911, 2734, 1367, 4102, 2051, 6154, 3077, 9232, 4616, 2308, 1154, 577, 1732, 866, 433, 1300, 650, 325, 976, 488, 244, 122, 61, 184, 92, 46, 23, 70, 35, 106, 53, 160, 80, 40, 20, 10, 5, 16, 8, 4, 2, 1], 4)
I'm not exactly sure what you mean by "adding 4 into your array". Do you just mean you want it to go "...8, 4, 2, 1, 4]"? If so, you can just use the .append() function again. You also do not need to keep defining s in your while loop since it will always just take the final value:
def sfcollatz(n, divs=(2,), mult=3, inc=1, maxSize=-1):
    """Build a generalized Collatz sequence and append the follow-up value.

    At each step the current value is divided by the first entry of ``divs``
    that divides it evenly; when none does, it becomes ``n*mult + inc``.
    Generation stops as soon as the next value is already in the sequence, or
    once ``maxSize`` values have been collected (``-1`` means no limit).

    Args:
        n: Starting value.
        divs: Candidate divisors, tried in order (any iterable of non-zero ints).
        mult: Multiplier applied when no divisor fits.
        inc: Increment added after multiplying.
        maxSize: Maximum number of values to collect, or -1 for no limit.

    Returns:
        The list of collected values with ``mult*last + inc`` appended as the
        final element. An empty list is returned when nothing was collected.
    """
    result = []
    while n not in result and len(result) != maxSize:
        result.append(n)
        d = next((d for d in divs if n % d == 0), None)
        n = (n*mult + inc) if d is None else n // d
    # s only depends on the last collected value, so it is computed once here;
    # the guard avoids indexing an empty list when maxSize is 0.
    if result:
        s = mult*result[-1] + inc
        result.append(s)
    return result
Your question was a little vague, but I hope this answers it. Let me know if you need any further help or clarification :)

Pyplot not plotting marker for detected peaks

I'm writing a Python script that plots a candlestick chart with x markers indicating peak candlesticks. The data used is a series of USD/JPY rates read using pandas.read_csv() from a csv file provided by the Oanda API. The result of pandas.DataFrame.head() is as follows:
time close open high low volume
0 2016/08/19 06:00:00 100.256 99.919 100.471 99.887 30965
1 2016/08/22 06:00:00 100.335 100.832 100.944 100.221 32920
2 2016/08/23 06:00:00 100.253 100.339 100.405 99.950 26069
3 2016/08/24 06:00:00 100.460 100.270 100.619 100.104 22340
4 2016/08/25 06:00:00 100.546 100.464 100.627 100.314 17224
While the candlestick chart itself is displayed properly (although it needs some formatting), I don't see any markers on it.
What I expect is something like the example graph shown in the scipy.signal.find_peaks documentation, only as a candlestick chart instead of a line graph.
Here is my code:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import mpl_finance
# Load the rate data from the CSV file named on the command line.
df = pd.read_csv(sys.argv[1])
opens = df['open']
highs = df['high']
lows = df['low']
closes = df['close']
# find_peaks returns a tuple; only its first element (the peak indices) is kept.
indices = find_peaks(highs)[0]
fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 1, 1)
mpl_finance.candlestick2_ohlc(ax1, opens, highs, lows, closes, width=4, colorup='k', colordown='r', alpha=0.75)
# NOTE(review): Axes.plot() takes its x/y data and format string as positional
# arguments; x=, y= and fmt= are not data parameters, so this call presumably
# draws no line at all, which would explain the missing markers. The positional
# form would be ax1.plot(indices, [highs[j] for j in indices], "x", label="peak highs")
# — confirm against the installed matplotlib version.
ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
ax1.grid()
plt.show()
I suspected that either the x or the y argument of ax1.plot() was empty, but the pdb debugger shows otherwise:
-> ax1.plot(x=indices, y=[highs[j] for j in indices], fmt="x", label="peak highs")
(Pdb) indices
array([ 1, 10, 15, 18, 23, 25, 29, 34, 39, 47, 50, 59, 66,
70, 74, 76, 78, 81, 84, 87, 92, 95, 99, 101, 107, 113,
118, 126, 130, 138, 143, 145, 158, 161, 164, 170, 172, 176, 182,
186, 196, 203, 208, 215, 220, 222, 226, 230, 233, 237, 241, 246,
248, 256, 261, 263, 267, 282, 286, 290, 293, 296, 304, 306, 308,
310, 313, 316, 322, 331, 336, 342, 349, 352, 359, 367, 369, 373,
378, 382, 391, 395, 400, 403, 405, 411, 416, 422, 425, 428, 438,
441, 444, 447, 450, 454, 459, 466, 471, 473, 477, 485, 493, 497],
dtype=int32)
(Pdb) [highs[j] for j in indices]
[100.944, 104.33, 103.07, 103.367, 102.79799999999999, 101.258, 101.851, 104.17399999999999, 104.64299999999999, 104.882, 105.544, 106.95700000000001, 111.375, 113.911, 114.837, 114.78399999999999, 114.415, 116.134, 118.676, 118.251, 117.822, 118.624, 117.54299999999999, 116.89, 115.634, 115.38600000000001, 113.538, 114.962, 113.787, 114.765, 115.512, 115.2, 112.213, 111.48, 111.587, 109.23299999999999, 109.5, 111.79, 113.05799999999999, 114.39299999999999, 112.135, 111.721, 110.823, 111.8, 112.47399999999999, 112.935, 113.696, 114.505, 113.583, 112.429, 112.21600000000001, 110.99, 111.05799999999999, 110.95700000000001, 109.833, 109.85600000000001, 110.678, 112.72399999999999, 113.264, 113.20200000000001, 113.446, 112.834, 113.589, 114.10700000000001, 114.25, 114.462, 114.288, 114.742, 113.91799999999999, 111.70100000000001, 113.095, 113.758, 113.64399999999999, 113.398, 113.39299999999999, 111.49, 111.23200000000001, 109.77799999999999, 110.491, 109.79, 107.912, 107.685, 106.47, 107.06200000000001, 107.305, 106.65, 107.01799999999999, 107.499, 107.405, 107.788, 109.552, 110.044, 109.406, 110.02600000000001, 110.461, 111.40299999999999, 109.84899999999999, 110.275, 110.85799999999999, 110.91, 110.765, 111.14399999999999, 112.80799999999999, 113.18700000000001]
Could anyone give me a possible solution or an explanation of the cause?

How to add a legend to matplotlib scatter plot

I'm attempting to plot a PCA and one of the colours is label 1 and the other should be label 2. When I want to add a legend with ax1.legend() I only get the label for the blue dot or no label at all. How can I add the legend with the correct labels for both the blue and purple dots?
sns.set(style = 'darkgrid')
fig, ax1 = sns.plt.subplots()
# The first two columns of X_bar are used as the plot coordinates.
x1, x2 = X_bar[:,0], X_bar[:,1]
# NOTE(review): all points are drawn by a single scatter() call with no
# label= argument, so a later ax1.legend() presumably has no per-colour
# entries to show — confirm against the matplotlib legend documentation.
ax1.scatter(x1, x2, 100, edgecolors='none', c = colors)
fig.set_figheight(8)
fig.set_figwidth(15)
It looks like you are plotting points that alternate between two colours. As per the answer to the question "subsampling every nth entry in a numpy array", you can use numpy's array slicing to plot two separate arrays, then add the legend as normal.
For some sample data:
import numpy as np
import numpy.random as nprnd
import matplotlib.pyplot as plt
# 100 random integers below 1000, reshaped in place into 50 rows x 2 columns.
A = nprnd.randint(1000, size=100)
A.shape = (50,2)
# Each column is sorted independently to give the two coordinate arrays.
x1, x2 = np.sort(A[:,0], axis=0), np.sort(A[:,1], axis=0)
x1
Out[50]:
array([ 46, 63, 84, 96, 118, 127, 137, 142, 181, 187, 187, 207, 210,
238, 238, 330, 334, 335, 346, 346, 350, 392, 400, 426, 467, 531,
550, 567, 569, 572, 583, 625, 637, 661, 671, 677, 698, 713, 777,
796, 837, 850, 866, 868, 874, 890, 919, 972, 992, 993])
x2
Out[51]:
array([ 2, 44, 49, 51, 72, 84, 86, 118, 120, 133, 150, 155, 156,
159, 199, 202, 250, 281, 289, 317, 317, 386, 405, 414, 427, 461,
507, 510, 543, 552, 553, 555, 559, 576, 618, 622, 633, 647, 665,
672, 682, 685, 745, 767, 776, 802, 808, 813, 847, 973])
fig, ax1 = plt.subplots()
# Draw the even- and odd-indexed points as two separate series. The legend
# label is taken from the colour itself so the two can never disagree.
for colour, start in (('red', 0), ('black', 1)):
    ax1.scatter(x1[start::2], x2[start::2], 100, edgecolors='none', c=colour, label=colour)
plt.legend()
plt.show()
For your code, you can do:
sns.set(style='darkgrid')
# Create the figure through matplotlib's pyplot directly rather than the
# sns.plt alias, which newer seaborn releases do not expose; figsize sets the
# same 15 x 8 inch size as the separate set_figwidth/set_figheight calls.
fig, ax1 = plt.subplots(figsize=(15, 8))
x1, x2 = X_bar[:, 0], X_bar[:, 1]
# Even- and odd-indexed points are drawn as two separate, labelled series so
# that the legend gets one entry per colour.
ax1.scatter(x1[0::2], x2[0::2], 100, edgecolors='none', c=colors[0], label='one')
ax1.scatter(x1[1::2], x2[1::2], 100, edgecolors='none', c=colors[1], label='two')
plt.legend()

Matplotlib: colorbar breaks when using PySAL natural breaks

I'm making a choropleth map based on this tutorial.
But instead of splitting the data into equal intervals, like this:
bins = np.linspace(values.min(), values.max(), 7)
I'm using PySAL's Jenks natural breaks because my data is unevenly distributed:
from pysal.esda.mapclassify import Natural_Breaks as nb
# values is a pandas Series
breaks = nb( values, initial=150, k = 7)
This makes the map colors look good, but it messes up the legend:
So I tried assigning Jenks colors to the map, and equal intervals to the legend, but this happens:
The colorbar is assigned the right tick labels, but at the wrong position. So my question is: how can I get the colorbar to be equal intervals but the tick labels to be the Natural Breaks values in the right position?
Here's the pertinent code for the legend:
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pysal.esda.mapclassify import Natural_Breaks as nb
values = pd.Series([71664, 65456, 60378, 50128, 46618, 44028, 42642, 41237, 35300, 34891, 34848, 33089, 29964, 25193, 25088, 23879, 23458, 18149, 16537, 15576, 15235, 14741, 11981, 11963, 11616, 10280, 9723, 9720, 9709, 9659, 9649, 9631, 9369, 8345, 8211, 7809, 7758, 7119, 7034, 6979, 6455, 5861, 5580, 5498, 5469, 5448, 5317, 4749, 4498, 4254, 4152, 3876, 3861, 3836, 3813, 3786, 3655, 3582, 3475, 2922, 2870, 2866, 2849, 2634, 2598, 2185, 1950, 1924, 1886, 1879, 1794, 1756, 1702, 1700, 1637, 1632, 1524, 1505, 1453, 1415, 1396, 1345, 1327, 1306, 1250, 1125, 1084, 1079, 1025, 976, 920, 903, 877, 868, 842, 815, 803, 799, 799, 792, 762, 725, 718, 714, 710, 660, 654, 647, 617, 616, 611, 600, 588, 572, 572, 567, 547, 536, 522, 482, 463, 439, 434, 428, 419, 415, 412, 410, 395, 390, 389, 386, 375, 374, 370, 345, 338, 325, 324, 285, 276, 272, 250, 236, 229, 227, 226, 216, 213, 209, 203, 200, 186, 186, 182, 182, 175, 173, 170, 169, 164, 164, 159, 155, 153, 148, 147, 140, 131, 129, 127, 127, 126, 124, 119, 117, 115, 114, 111, 109, 105, 103, 101, 97, 90, 89, 89, 85, 84, 77, 76, 74, 72, 71, 70, 70, 69, 62, 61, 61, 60, 57, 54, 53, 53, 51, 50, 50, 48, 44, 43, 42, 35, 34, 30, 29, 26, 23, 20, 19, 16, 15, 15, 12, 11, 9, 8, 8, 5, 3, 1])
num_colors = 7
# Jenks natural breaks for colormap (classified into num_colors - 1 classes)
breaks = nb( values, initial=150, k = num_colors - 1)
# presumably the class boundary values computed by PySAL — verify against its docs
bins = breaks.bins
# Orange-Red colormap
cm = plt.get_cmap('OrRd')
# Sample num_colors colours at evenly spaced positions 0, 1/n, ..., (n-1)/n
scheme = cm(1.*np.arange(num_colors)/num_colors)
fig = plt.figure(figsize=(19, 7))
# Small horizontal axes that will hold the colorbar legend
ax_legend = fig.add_axes([0.35, 0.15, 0.3, 0.03], zorder=3)
cmap = mpl.colors.ListedColormap(scheme)
# Round legend ticks to nearest 100
legend_bins = np.around(bins, decimals = -2)
# Split the data range into num_colors equally spaced values
legend_colors = np.linspace(values.min(), values.max(), num_colors)
# Ticks come from the rounded Jenks bins while the boundaries are the equally
# spaced values, i.e. the two arguments use different scales.
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=legend_bins,
boundaries=legend_colors,
orientation='horizontal' )
After much wrestling, I found the answer. It's all about setting the ticks and boundaries parameters to the same thing, i.e. the bins. Then set the ticks to legend_colors.
The relevant bit to make it work is:
# Build the colorbar with the Jenks bins as both ticks and boundaries...
cb = mpl.colorbar.ColorbarBase(ax_legend,
cmap=cmap,
ticks=bins,
boundaries=bins,
orientation='horizontal' )
# ...then move the ticks onto the equally spaced values (skipping the first).
cb.set_ticks(legend_colors[1:])

Categories

Resources