How to set a shared secondary axes using subplots in matplotlib.
Here is the minimal code to display the issue:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
def countour_every(ax, every, x_data, y_data,
color='black', linestyle='-', marker='o', **kwargs):
"""Draw a line with countour marks at each every points"""
line, = ax.plot(x_data, y_data, linestyle)
return line
def prettify_axes(ax, data):
"""Makes my plot pretty"""
if 'title' in data:
ax.set_title(data['title'])
if 'y_lim' in data:
ax.set_ylim(data['y_lim'])
if 'x_lim' in data:
ax.set_xlim(data['x_lim'])
# Draw legend only if labels were set (HOW TO DO IT?)
# if ax("has_some_label_set"):
ax.legend(loc='upper right', prop={'size': 6})
ax.title.set_fontsize(7)
ax.xaxis.set_tick_params(labelsize=6)
ax.xaxis.set_tick_params(direction='in')
ax.xaxis.label.set_size(7)
ax.yaxis.set_tick_params(labelsize=6)
ax.yaxis.set_tick_params(direction='in')
ax.yaxis.label.set_size(7)
def prettify_second_axes(ax):
ax.yaxis.set_tick_params(labelsize=7)
ax.yaxis.set_tick_params(labelcolor='red')
ax.yaxis.label.set_size(7)
def compare_plot(ax, data):
line1 = countour_every(ax, 10, **data[0])
if 'label' in data[0]:
line1.set_label(data[0]['label'])
line2 = countour_every(ax, 10, **data[1])
if 'label' in data[1]:
line2.set_label(data[1]['label'])
ax2 = ax.twinx()
line3 = ax.plot(
data[0]['x_data'],
data[0]['y_data']-data[1]['y_data'], '-',
color='red', alpha=.2, zorder=1)
prettify_axes(ax, data[0])
prettify_second_axes(ax2)
d0 = {'x_data': np.arange(0, 10), 'y_data': abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-', 'label': 'd0'}
d1 = {'x_data': np.arange(0, 10), 'y_data': -abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '--', 'label': 'd1'}
d2 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
d3 = {'x_data': np.arange(0, 10), 'y_data': -np.ones(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)
compare_plot(axes[0][0], [d0, d1])
compare_plot(axes[0][1], [d0, d2])
compare_plot(axes[1][0], [d1, d0])
compare_plot(axes[1][1], [d3, d2])
fig.suptitle('A comparison chart')
fig.set_tight_layout({'rect': [0, 0.03, 1, 0.95]})
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')
fig.savefig('demo.png', dpi=300)
That generates the following image
We can see that the X axis and the Y axis is correctly shared, but the secondary twin axis, is repeated in all subplots.
Also the secondary axis isn't scaling correctly to fit the data. (that should occurs independently of the principal y axis being limited).
You will need to share the twin axes manually and also remove the ticklabels
def compare_plot(ax, data):
# ...
ax2 = ax.twinx()
# ...
return ax2
sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])
for sax in [sax2, sax3, sax4]:
sax1.get_shared_y_axes().join(sax1, sax)
sax1.autoscale()
for sax in [sax1,sax3]:
sax.yaxis.set_tick_params(labelright=False)
Related
I have the code below which produces the output I want.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
plt.style.use('ggplot')
%matplotlib inline
data = dict({'Variable_Grouping':['Type_A', 'Type_A', 'Type_A', 'Type_C', 'Type_C', 'Type_C', 'Type_C', 'Type_D', 'Type_D', 'Type_E', 'Type_E', 'Type_E', 'Type_H', 'Type_H'], 'Variable':['a1', 'a2', 'a3', 'c1', 'c2', 'c3', 'c4', 'd1', 'd2', 'e1', 'e2', 'e3', 'h1', 'h2'], 'Count':[5, 3, 8, 4, 3, 9, 5, 3, 8, 5, 3, 8, 5, 3],'Percent':[0.0625, 0.125, 0.4375, 0.0, 0.125, 0.5, 0.02, 0.125, 0.03, 0.0625, 0.05, 0.44, 0.07, 0.023]})
to_plot = pd.DataFrame(data)
g = sns.FacetGrid(to_plot, col='Variable_Grouping', col_wrap = 2, sharex=False, sharey = False, height = 5, aspect = 1, margin_titles=True)
g=g.map(plt.bar, "Variable","Count").add_legend()
for ax, (_, subdata) in zip(g.axes, to_plot.groupby('Variable_Grouping')):
ax2=ax.twinx()
subdata.plot(x='Variable',y='Percent', ax = ax2, legend=True, color='g', label = 'Percent')
ax2.set_ylabel('Percent')
ax2.grid(False)
for ax in g.axes.flatten():
ax.tick_params(labelbottom=True, labelrotation = 90)
g.fig.suptitle('Analysis', fontsize=16, fontweight = 'demibold', y = 1.02)
g.fig.subplots_adjust(hspace=0.3, wspace=0.7, right = 0.9)
plt.show();
Now I am using matplotlib.backends.backend_pdf to plot the figures in pdf. I want 4 figures per page.
with PdfPages('Analysis.pdf') as pdf:
g = sns.FacetGrid(to_plot, col='Variable_Grouping', col_wrap = 2, sharex=False, sharey = False, height = 5, aspect = 1, margin_titles=True)
g=g.map(plt.bar, "Variable","Count").add_legend()
for ax, (_, subdata) in zip(g.axes, to_plot.groupby('Variable_Grouping')):
ax2=ax.twinx()
subdata.plot(x='Variable',y='Percent', ax = ax2, legend=True, color='g', label = 'Percent')
ax2.set_ylabel('Percent')
ax2.grid(False)
for ax in g.axes.flatten():
ax.tick_params(labelbottom=True, labelrotation = 90)
g.fig.suptitle('Analysis', fontsize=16, fontweight = 'demibold', y = 1.02)
g.fig.subplots_adjust(hspace=0.3, wspace=0.7, right = 0.9)
pdf.savefig(bbox_inches = 'tight')
plt.close();
The code above gives me all the plots in a single page as expected.
def grouper(iterable, n, fillvalue=None):
from itertools import zip_longest
args = [iter(iterable)] * n
return zip_longest(*args, fillvalue=fillvalue)
if len(to_plot['Variable_Grouping'].unique()) < 4:
N_plots_per_page =len(to_plot['Variable_Grouping'].unique())
elif len(to_plot['Variable_Grouping'].unique()) >= 4:
N_plots_per_page = 4
with PdfPages('Analysis.pdf') as pdf:
for cols in grouper(to_plot['Variable_Grouping'].unique(), N_plots_per_page):
g = sns.FacetGrid(to_plot, col='Variable_Grouping', col_wrap = 2, sharex=False, sharey = False, height = 5, aspect = 1, margin_titles=True)
g=g.map(plt.bar, "Variable","Count").add_legend()
for ax, (_, subdata) in zip(g.axes, to_plot.groupby('Variable_Grouping')):
ax2=ax.twinx()
subdata.plot(x='Variable',y='Percent', ax = ax2, legend=True, color='g', label = 'Percent')
ax2.set_ylabel('Percent')
ax2.grid(False)
for ax in g.axes.flatten():
ax.tick_params(labelbottom=True, labelrotation = 90)
g.fig.suptitle('Analysis', fontsize=16, fontweight = 'demibold', y = 1.02)
g.fig.subplots_adjust(hspace=0.3, wspace=0.7, right = 0.9)
pdf.savefig(bbox_inches = 'tight')
plt.show()
plt.close();
In the code above I have tried using the grouper function (https://docs.python.org/3/library/itertools.html#itertools-recipes). This was also mentioned in Export huge seaborn chart into pdf with multiple pages and this repeats all the graphs in all the pages.
I wanted to enquire if there is an easy way to get 4 graphs per page or what's wrong with the above code I used using the grouper function which is repeating the graphs. Any help will be appreciated. Thanks.
The problem is, even you try to get the number of plots per page, you take the whole data inside the loop to plot with to_plot. You need to filter your to_plot with the cols you get by your grouper and your code will work.
The only changes I made is create the variable data_per_page and replace that with to_plot inside of sns.FaceGrid and in for ax, (_,subdata) in zip(...).
with PdfPages('Analysis.pdf') as pdf:
for cols in grouper(to_plot['Variable_Grouping'].unique(), N_plots_per_page):
data_per_page = to_plot.loc[to_plot['Variable_Grouping'].isin(cols)]
g = sns.FacetGrid(data_per_page, col='Variable_Grouping', col_wrap = 2, sharex=False, sharey = False, height = 5, aspect = 1, margin_titles=True)
g=g.map(plt.bar, "Variable","Count").add_legend()
for ax, (_,subdata) in zip(g.axes, data_per_page.groupby(['Variable_Grouping'])):
ax2=ax.twinx()
subdata.plot(x='Variable',y='Percent', ax = ax2, legend=True, color='g', label = 'Percent')
ax2.set_ylabel('Percent')
ax2.grid(False)
for ax in g.axes.flatten():
ax.tick_params(labelbottom=True, labelrotation = 90)
g.fig.suptitle('Analysis', fontsize=16, fontweight = 'demibold', y = 1.02)
g.fig.subplots_adjust(hspace=0.3, wspace=0.7, right = 0.9)
pdf.savefig(bbox_inches='tight')
plt.show()
plt.close()
As a result I get a pdf with 2 pages, on the first there are 4 plots, and on the second only 1.
I'm plotting a graph with 3 y-axis, and two of them have scietific notations. However, they overlap on the top-right of the graph. I'd like to have them separated, and if possible on top of their axis. Here's how i plot the graph, and a picture of the result, where you can clearly see the overlapping :
https://i.stack.imgur.com/G2K9A.png (The y-axis on the right overlap a bit too, but I know how to correct that)
import numpy as np, matplotlib.pyplot as plt
a = np.arange(-1*10**-5, 10**-5, (10**-5+10**-5)/10)
b = np.arange(-2*10**-7, 2*10**-7, (2*10**-7+2*10**-7)/10)
c = np.arange(-3*10**-6, 3*10**-6, (3*10**-6+3*10**-6)/10)
x = np.arange(0, 100, 100/10)
fig, ax = plt.subplots(num=1, figsize = (15, 10))
fig.subplots_adjust(right=0.75)
twin1 = ax.twinx()
twin2 = ax.twinx()
twin2.spines['right'].set_position(("axes", 1.1))
p1, = ax.plot(x, a, color = 'r', linewidth = 2, label="y1")
p2, = twin1.plot(x, b, color = 'g', linewidth = 2, label="y2")
p3, = twin2.plot(x, c, color = 'b', linewidth = 2, label="y3")
ax.set_xlabel("Time (s)", fontsize=35)
ax.set_ylabel("y1", fontsize=35)
twin1.set_ylabel("y2", fontsize=35)
twin2.set_ylabel("y3", fontsize=35)
ax.yaxis.label.set_color(p1.get_color())
twin1.yaxis.label.set_color(p2.get_color())
twin2.yaxis.label.set_color(p3.get_color())
ax.tick_params(axis='y', colors=p1.get_color(), labelsize=30)
twin1.tick_params(axis='y', colors=p2.get_color(), labelsize=30)
twin1.yaxis.offsetText.set_fontsize(30)
twin2.tick_params(axis='y', colors=p3.get_color(), labelsize=30)
twin2.yaxis.offsetText.set_fontsize(30)
ax.tick_params(axis='x', labelsize=30)
min_axis_x, max_axis_x = x.min(), x.max()
min_axis_y, max_axis_y = a.min(), a.max()
min_axis_y1, max_axis_y1 = b.min(), b.max()
min_axis_y2, max_axis_y2 = c.min(), c.max()
ax.legend(handles=[p1, p2, p3], fontsize=35)
plt.title("y1, y2, y3 = f(t)", fontsize=45)
plt.show()
I want to create subplots with Matplotlib by looping over my data. However, I don't get the annotations into the correct position, apparently not even into the correct subplot. Also, the common x- and y-axis labels don't work.
My real data is more complex but here is an example that reproduces the error:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
# create data
distributions = []
first_values = []
second_values = []
for i in range(4):
distributions.append(np.random.normal(0, 0.5, 100))
first_values.append(np.random.uniform(0.7, 1))
second_values.append(np.random.uniform(0.7, 1))
# create subplot
fig, axes = plt.subplots(2, 2, figsize = (15, 10))
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]
# loop over data and create subplots
for data in range(4):
if data == 0:
position = axes[0, 0]
if data == 1:
position = axes[0, 1]
if data == 2:
position = axes[1, 0]
if data == 3:
position = axes[1, 1]
dist = distributions[data]
first = first_values[data]
second = second_values[data]
sns.histplot(dist, alpha = 0.5, kde = True, stat = 'density', bins = 20, color = '#76A29F', ax = position)
sns.rugplot(dist, alpha = 0.5, color = '#76A29F', ax = position)
position.annotate(f'{np.mean(dist):.2f}', (np.mean(dist), 0.825), xycoords = ('data', 'figure fraction'), color = '#76A29F')
position.axvline(first, 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
position.axvline(second, 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
position.annotate(f'{first:.2f}', (first, 0.8), xycoords = ('data', 'figure fraction'), color = '#FEB302')
position.annotate(f'{second:.2f}', (second, 0.85), xycoords = ('data', 'figure fraction'), color = '#FF5D3E')
position.set_xticks(np.arange(round(min(dist), 1) - 0.1, round(max(max(dist), max([first]), max([second])), 1) + 0.1, 0.1))
plt.xlabel("x-axis name")
plt.ylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.5, 0.5))
plt.show()
The resulting plot looks like this:
What I want is to have
the annotations in the correct subplot next to the vertical lines / the mean of the distribution
shared x- and y-labels for all subplot or at least for each row / column
Any help is highly appreciated!
If you use the function to make the subplot a single array (axes.flatten()) and modify it to draw the graph sequentially, you can draw the graph. The colors of the annotations have been partially changed for testing purposes.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
np.random.seed(202000104)
# create data
distributions = []
first_values = []
second_values = []
for i in range(4):
distributions.append(np.random.normal(0, 0.5, 100))
first_values.append(np.random.uniform(0.7, 1))
second_values.append(np.random.uniform(0.7, 1))
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]
for i,ax in enumerate(axes.flatten()):
sns.histplot(distributions[i], alpha=0.5, kde=True, stat='density', bins=20, color='#76A29F', ax=ax)
sns.rugplot(distributions[i], alpha=0.5, color='#76A29F', ax=ax)
ax.annotate(f'{np.mean(distributions[i]):.2f}', (np.mean(distributions[i]), 0.825), xycoords='data', color='red')
ax.axvline(first_values[i], 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
ax.axvline(second_values[i], 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
ax.annotate(f'{first_values[i]:.2f}', (first_values[i], 0.8), xycoords='data', color='#FEB302')
ax.annotate(f'{second_values[i]:.2f}', (second_values[i], 0.85), xycoords='data', color = '#FF5D3E')
ax.set_xticks(np.arange(round(min(distributions[i]), 1) - 0.1, round(max(max(distributions[i]), max([first_values[i]]), max([second_values[i]])), 1) + 0.1, 0.1))
plt.xlabel("x-axis name")
plt.ylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.35, 0.5))
plt.show()
I have a graph with 2 lines and used the check button to make visible/invisible lines. However, I could not manage to make the same thing to annotations. I wanna when I check label2 make visible/invisible annotations too. I couldn't find any solution from here or web. Any idea?
Here is example of my code :
import matplotlib.pyplot as plt
from matplotlib.widgets import CheckButtons
kansekeri=[333, 111, 111]
inshiz=[3.0, 3.0, 2.5]
zaman=['01-04-2021 04:02', '01-04-2021 04:02', '01-04-2021 04:02']
fig = plt.figure(figsize=(10, 5))
fig.canvas.set_window_title('-')
ax1 = fig.add_subplot()
l0,=ax1.plot(zaman, inshiz, color='tab:blue', marker='.', label='İnsülin Hızı')
ax1.set_xlabel('Tarih - Saat')
ax1.set_ylabel('İnsülin Hızı (u/h)', color='tab:blue')
ax1.tick_params(axis='y',labelcolor='tab:blue')
ax1.tick_params(axis='x', labelsize=7)
ax2 = ax1.twinx()
l1,=ax2.plot(zaman, kansekeri, color='tab:red', marker='.', label="Kan Şekeri")
ax2.set_ylabel('Kan Şekeri',color='tab:red')
ax2.tick_params(labelcolor='tab:red')
for x,y in zip(zaman,kansekeri):
label = "{:.1f}".format(y)
ann=plt.annotate(label, # this is the text
(x,y),
textcoords="offset points",
xytext=(0,7),
ha='center',fontsize=7)
lines = [l0, l1]
rax = plt.axes([0.05, 0.7, 0.1, 0.15])
labels = [str(line.get_label()) for line in lines]
visibility = [line.get_visible() for line in lines]
check = CheckButtons(rax, labels, visibility)
def func(label):
index = labels.index(label)
lines[index].set_visible(not lines[index].get_visible())
plt.draw()
check.on_clicked(func)
plt.subplots_adjust(left=0.2)
plt.grid()
plt.show()
You could also save the annotations in a list: annotations = [plt.annotate(f"{y:.1f}", (x, y), ...) for x, y in zip(zaman, kansekeri)] and then do something similar as with the lines.
Here is some example code:
import matplotlib.pyplot as plt
from matplotlib.widgets import CheckButtons
kansekeri = [333, 111, 111]
inshiz = [3.0, 3.0, 2.5]
zaman = ['01-04-2021 04:02', '01-04-2021 04:02', '01-04-2021 04:02']
fig = plt.figure(figsize=(10, 5))
fig.canvas.set_window_title('-')
ax1 = fig.add_subplot()
l0, = ax1.plot(zaman, inshiz, color='tab:blue', marker='.', label='İnsülin Hızı')
ax1.set_xlabel('Tarih - Saat')
ax1.set_ylabel('İnsülin Hızı (u/h)', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax1.tick_params(axis='x', labelsize=7)
ax2 = ax1.twinx()
l1, = ax2.plot(zaman, kansekeri, color='tab:red', marker='.', label="Kan Şekeri")
ax2.set_ylabel('Kan Şekeri', color='tab:red')
ax2.tick_params(labelcolor='tab:red')
annotations = [plt.annotate(f"{y:.1f}",
(x, y),
textcoords="offset points",
xytext=(0, 7),
ha='center', fontsize=7)
for x, y in zip(zaman, kansekeri)]
lines = [l0, l1]
rax = plt.axes([0.05, 0.7, 0.1, 0.15])
labels = [str(line.get_label()) for line in lines]
visibility = [line.get_visible() for line in lines]
check = CheckButtons(rax, labels, visibility)
def func(label):
index = labels.index(label)
new_line_visibility = not lines[index].get_visible()
lines[index].set_visible(new_line_visibility)
if index == 1:
for ann in annotations:
ann.set_visible(new_line_visibility)
plt.draw()
check.on_clicked(func)
plt.subplots_adjust(left=0.2)
plt.grid()
plt.show()
Following my previous question that didn't get any answer, I tried to solve my problem of adding colorbar instead of legend to my plots. There are couple of problems that I couldn't solve yet.
Update:
I want to move the colorbar to the proper position on the right of the plot.
I generate two plots with the same instruction but the second one looks completely different and I couldn't understand what caused this problem.
Here is my code:
import numpy as np
import pylab as plt
from matplotlib import rc,rcParams
rc('text',usetex=True)
rcParams.update({'font.size':10})
import matplotlib.cm as cm
from matplotlib.ticker import NullFormatter
import matplotlib as mpl
def plot(Z_s,CWL,filter_id,spectral_type,model_mag,mag,plot_name):
f= ['U38','B','V','R','I','MB420','MB464','MB485','MB518','MB571','MB604','MB646','MB696','MB753','MB815','MB856','MB914']
wavetable=CWL/(1+Z_s)
dd=model_mag-mag
nplist=['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
minimum,maximum=(0.,16.)
Z = [[0,0],[0,0]]
levels = list(np.linspace(0, 1, len(f)))
NUM_COLORS = len(f)
cm = plt.get_cmap('gist_rainbow')
mycolor=[]
for i in range(NUM_COLORS):
mycolor.append( cm(1.*i/NUM_COLORS)) # color will now be an RGBA tuple
mymap = mpl.colors.LinearSegmentedColormap.from_list('mycolors',mycolor)
CS3 = plt.contourf(Z, levels, cmap=mymap)
plt.clf()
FILTER=filter_id
SED=spectral_type
for (j,d) in enumerate(nplist):
bf=(SED==j)
if (j<3):
k=j
i_subplot = k + 1
fig = plt.figure(1, figsize=(5,5))
ax = fig.add_subplot(3,1,i_subplot)
for i in range(len(f)):
bb=np.where(FILTER[bf]==i)[0]
r=mycolor[i][0]
g=mycolor[i][1]
b=mycolor[i][2]
ax.scatter(wavetable[bb], dd[bb], s=1, color=(r,g,b))
if (k<2):
ax.xaxis.set_major_formatter( NullFormatter() )
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
else:
ax.set_xlabel(r'WL($\AA$)',fontsize=10)
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
fig.subplots_adjust(wspace=0,hspace=0)
ax.axhline(y=0,color='k')
ax.set_xlim(1000,9000)
ax.set_ylim(-3,3)
ax.set_xticks(np.linspace(1000, 9000, 16, endpoint=False))
ax.set_yticks(np.linspace(-3, 3, 4, endpoint=False))
ax.text(8500,2.1,nplist[j], {'color': 'k', 'fontsize': 10})
fontsize=8
for tick in ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
for tick in ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
if (j==2):
cbar_ax = fig.add_axes([0.9, 0.15, 0.05, 0.7])
cbar=plt.colorbar(CS3, cax=cbar_ax, ticks=range(0,len(f)),orientation='vertical')
cbar.ax.get_yaxis().set_ticks([])
for s, lab in enumerate(f):
cbar.ax.text( 0.08,(0.95-0.01)/float(len(f)-1) * s, lab, fontsize=8,ha='left')
fname = plot_name+'.'+nplist[0]+'.'+nplist[1]+'.'+nplist[2]+'.pdf'
plt.savefig(fname)
plt.close()
else:
k=j-3
i_subplot = k + 1
fig = plt.figure(1, figsize=(5,5))
ax = fig.add_subplot(3,1,i_subplot)
for i in range(len(f)):
bb=np.where(FILTER[bf]==i)[0]
r=mycolor[i][0]
g=mycolor[i][1]
b=mycolor[i][2]
ax.scatter(wavetable[bb], dd[bb], s=1, color=(r,g,b))
if (k<2):
ax.xaxis.set_major_formatter( NullFormatter() )
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
else:
ax.set_xlabel(r'WL($\AA$)',fontsize=10)
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
fig.subplots_adjust(wspace=0,hspace=0)
ax.axhline(y=0,color='k')
ax.set_xlim(1000,9000)
ax.set_ylim(-3,3)
ax.set_xticks(np.linspace(1000, 9000, 16, endpoint=False))
ax.set_yticks(np.linspace(-3, 3, 4, endpoint=False))
ax.text(8500,2.1,nplist[j], {'color': 'k', 'fontsize': 10})
fontsize=8
for tick in ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
for tick in ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
if (j==5):
cbar_ax = fig.add_axes([0.9, 0.15, 0.05, 0.7])
cbar=plt.colorbar(CS3, cax=cbar_ax, ticks=range(0,len(f)),orientation='vertical')
cbar.ax.get_yaxis().set_ticks([])
for s, lab in enumerate(f):
cbar.ax.text( 0.08,(0.95-0.01)/float(len(f)-1) * s, lab , fontsize=8,ha='left')
fname = plot_name+'.'+nplist[3]+'.'+nplist[4]+'.'+nplist[5]+'.pdf'
plt.savefig(fname)
plt.close()
a=np.loadtxt('calibration.photometry.information.capak.cat')
Z_s=a[:,0]
CWL=a[:,1]
filter_id=a[:,2]
spectral_type=a[:,3]
model_mag=a[:,4]
mag=a[:,5]
plot_name='test'
plot(Z_s,CWL,filter_id,spectral_type,model_mag,mag,plot_name)
you can also download the data from here.
I will appreciate to get any help.
You can use plt.subplots() passing the gridspec_kw parameter to adjust the axes' aspect ratio in a very flexible way, and then select the top axes to include the colorbar.
I've worked on your code simplifying it quite a bit. Furthermore, I've changed many things in your code such as: PEP8, removed repeated calls to plt.savefig()and ax methods. The result is:
import numpy as np
import pylab as plt
from matplotlib import rc, rcParams, colors
rc('text', usetex=True)
rcParams['font.size'] = 10
rcParams['axes.labelsize'] = 8
def plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name):
f= ['U38', 'B', 'V', 'R', 'I', 'MB420', 'MB464', 'MB485', 'MB518',
'MB571', 'MB604', 'MB646', 'MB696', 'B753', 'MB815', 'MB856',
'MB914']
wavetable = CWL/(1+Z_s)
dd = model_mag-mag
nplist = ['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
minimum, maximum = (0., 16.)
Z = [[0, 0],[0, 0]]
levels = list(np.linspace(0, 1, len(f)+1))
NUM_COLORS = len(f)
cmap = plt.get_cmap('gist_rainbow')
mycolor = []
for i in range(NUM_COLORS):
mycolor.append(cmap(1.*i/NUM_COLORS))
mymap = colors.LinearSegmentedColormap.from_list('mycolors', mycolor)
CS3 = plt.contourf(Z, levels, cmap=mymap)
coords = CS3.get_array()
coords = coords[:-1] + np.diff(coords)/2.
FILTER = filter_id
SED = spectral_type
dummy = 2
xmin = 1000
xmax = 9000
ymin = -3
ymax = 3
fig, axes = plt.subplots(nrows=5, figsize=(5, 6),
gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
fig2, axes2 = plt.subplots(nrows=5, figsize=(5, 6),
gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
fig.subplots_adjust(wspace=0, hspace=0)
fig2.subplots_adjust(wspace=0, hspace=0)
axes_all = np.concatenate((axes[dummy:], axes2[dummy:]))
dummy_axes = np.concatenate((axes[:dummy], axes2[:dummy]))
for ax in axes_all:
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.axhline(y=0, color='k')
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_xticks([])
ax.set_yticks(np.linspace(ymin, ymax, 4, endpoint=False))
ax.set_ylabel(r'$\Delta$ MAG', fontsize=10)
axes[-1].set_xticks(np.linspace(xmin, xmax, 16, endpoint=False))
axes2[-1].set_xticks(np.linspace(xmin, xmax, 16, endpoint=False))
plt.setp(axes[-1].xaxis.get_majorticklabels(), rotation=30)
plt.setp(axes2[-1].xaxis.get_majorticklabels(), rotation=30)
axes[-1].set_xlabel(r'WL($\AA$)', fontsize=10)
axes2[-1].set_xlabel(r'WL($\AA$)', fontsize=10)
for ax in dummy_axes:
for s in ax.spines.values():
s.set_visible(False)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
ax.set_xticks([])
ax.set_yticks([])
for axes_i in [axes, axes2]:
cbar = plt.colorbar(CS3, ticks=[], orientation='horizontal',
cax=axes_i[0])
for s, lab in enumerate(f):
cbar.ax.text(coords[s], 0.5, lab, fontsize=8, va='center',
ha='center', rotation=90,
transform=cbar.ax.transAxes)
for (j, d) in enumerate(nplist):
bf = (SED==j)
if (j<3):
k = j
ax = axes[k+dummy]
ax.text(8500, 2.1, nplist[j], {'color': 'k', 'fontsize': 10})
for i in range(len(f)):
bb = np.where(FILTER[bf]==i)[0]
ax.scatter(wavetable[bb], dd[bb], s=1, color=mycolor[i])
else:
k = j-3
ax = axes2[k+dummy]
ax.text(8500, 2.1, nplist[j], {'color': 'k', 'fontsize': 10})
for i in range(len(f)):
bb = np.where(FILTER[bf]==i)[0]
ax.scatter(wavetable[bb], dd[bb], s=1, color=mycolor[i])
fname = '.'.join([plot_name, nplist[0], nplist[1], nplist[2], 'png'])
fig.savefig(fname)
fname = '.'.join([plot_name, nplist[3], nplist[4], nplist[5], 'png'])
fig2.savefig(fname)
if __name__=='__main__':
a = np.loadtxt('calibration.photometry.information.capak.cat')
Z_s = a[:, 0]
CWL = a[:, 1]
filter_id = a[:, 2]
spectral_type = a[:, 3]
model_mag = a[:, 4]
mag = a[:, 5]
plot_name = 'test'
plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name)
which gives: