Related
I have a plot consisting of multiple elements and I wish to have a selection enlarged with inset_axes. I have followed the manual and several other posts but it is only creating an empty square.
I have 1500 lines of code where I add elements to that plot at different places, thus I wish to create a zoom at the end to the whole plot.
Here is what is my output:
and here is the code that leads to this plot (data omitted, too large)
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
# this is done in a separate function (called only once)
# variables defined earlier, omitting - some variables are set manually, some are observed
for i_plot_buildings in range(0, len(Building_center_coords_main)):
ax.plot(x_building_corners_main[i_plot_buildings], y_building_corners_main[i_plot_buildings], 'k-')
plt.scatter(UEs_coordinates[:, 1], UEs_coordinates[:, 0], s=50, marker='.', c="c")
plt.scatter(BSs_coordinates[:, 0], BSs_coordinates[:, 1], marker='^', c="r", zorder=3)
ax.set_aspect('equal', adjustable='box')
Plot_last_drop_indicator = 0
arry = np.empty((1000, 1000), int) #this fills with info about wind in later code (omitted)
plt.imshow(arry, cmap=plt.cm.Greens, interpolation='nearest')
plt.ylabel("Y [m]")
plt.xlabel("X [m]")
bar = plt.colorbar()
bar.set_label(r'Wind speed $[ms^{-1}]$', rotation=270)
#and now I want the zoom, not working...
axins = ax.inset_axes([0.55, 0.55, 0.4, 0.4]) # set the area where to enlarge selection
# the selection
x1, x2, y1, y2 = 0,100,0,100
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)
axins.set_xticklabels([])
axins.set_yticklabels([])
ax.indicate_inset_zoom(axins, edgecolor="black")
In case anyone would ever try to do some complex stuff like me I keep this question and provide answer I have managed to run. As mentioned by comments, you have to add a brand new plot inside the original plot with new data. Here is the code:
fig, ax = plt.subplots()
# this is done in a separate function (called only once)
# variables defined earlier, omitting - some variables are set manually, some are observed
for i_plot_buildings in range(0, len(Building_center_coords_main)):
ax.plot(x_building_corners_main[i_plot_buildings], y_building_corners_main[i_plot_buildings], 'k-')
plt.scatter(UEs_coordinates[:, 1], UEs_coordinates[:, 0], s=50, marker='.', c="c")
plt.scatter(BSs_coordinates[:, 0], BSs_coordinates[:, 1], marker='^', c="r", zorder=3)
ax.set_aspect('equal', adjustable='box')
Plot_last_drop_indicator = 0
arry = np.empty((1000, 1000), int) #this fills with info about wind in later code (omitted)
plt.imshow(arry, cmap=plt.cm.Greens, interpolation='nearest')
plt.ylabel("Y [m]")
plt.xlabel("X [m]")
bar = plt.colorbar()
bar.set_label(r'Wind speed $[ms^{-1}]$', rotation=270)
# adding circles and crosses
plt.scatter(47, 64, marker='x',c="g")
circle1 = plt.Circle((47, 64), 8*4, color='g', fill=False)
plt.gca().add_patch(circle1)
plt.scatter(41, 56, marker='x',c="r")
circle1 = plt.Circle((41, 56), 6*4, color='r', fill=False)
plt.gca().add_patch(circle1)
plt.scatter(726, 672, marker='x',c="g")
plt.scatter(755, 658, marker='x',c="r")
circle1 = plt.Circle((726, 672), 9*4, color='g' ,fill=False)
plt.gca().add_patch(circle1)
circle1 = plt.Circle((755, 658), 7*4, color='r', fill=False)
plt.gca().add_patch(circle1)
plt.scatter(920, 51, marker='x',c="g")
circle1 = plt.Circle((920, 51), 8*4, color='g', fill=False)
plt.gca().add_patch(circle1)
plt.scatter(927, 41, marker='x',c="r")
circle1 = plt.Circle((927, 41), 5*4, color='r', fill=False)
plt.gca().add_patch(circle1)
#inverting y axis
plt.gca().invert_yaxis()
#selecting parent axis for later use
parent_axes = plt.gca()
axins = ax.inset_axes([0.1, 0.5, 0.4, 0.4]) # enlargement area
# area to zoom
x1, x2, y1, y2 = 680,780,620,720
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)
# new ax
ax3 = plt.gcf().add_axes([50,500,400,400])
# adding arry selection to the new plot
ax3.imshow(arry[680:780, 620:720], cmap=plt.cm.Greens, interpolation='nearest')
running through coordinates to select which to add (yes this is slow and can be done more smart but I am lazy)
for coord in UEcoords:
if coord[1] >= 680 and coord [1] < 780:
if coord[0] >= 620 and coord [0] < 720:
ax3.scatter(coord[1]-680, coord[0]-620, s=100, marker='.', c="c")
#show selected crosses and circles in area
ax3.scatter(726-680, 672-620, marker='x',c="g")
ax3.scatter(755-680, 658-620, marker='x',c="r")
circle1 = plt.Circle((726-680, 672-620), 9*4, color='g' ,fill=False)
ax3.add_patch(circle1)
circle1 =plt.Circle((755-680, 658-620), 7*4, color='r', fill=False)
ax3.add_patch(circle1)
#set limits and turn off labels
ax3.set_xlim(0,100)
ax3.set_ylim(0,100)
ax3.set_xticklabels([])
ax3.set_yticklabels([])
#set up inset position
ip = InsetPosition(parent_axes,[0.1, 0.5, 0.4, 0.4])
axins.set_axes_locator(ip)
axins.set_xticklabels([])
axins.set_yticklabels([])
# set the new axes (ax3) to the position of the linked axes
ax3.set_axes_locator(ip)
ax.indicate_inset_zoom(axins, edgecolor="black")
The final output looks like this:
I want to create subplots with Matplotlib by looping over my data. However, I don't get the annotations into the correct position, apparently not even into the correct subplot. Also, the common x- and y-axis labels don't work.
My real data is more complex but here is an example that reproduces the error:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
# create data
distributions = []
first_values = []
second_values = []
for i in range(4):
distributions.append(np.random.normal(0, 0.5, 100))
first_values.append(np.random.uniform(0.7, 1))
second_values.append(np.random.uniform(0.7, 1))
# create subplot
fig, axes = plt.subplots(2, 2, figsize = (15, 10))
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]
# loop over data and create subplots
for data in range(4):
if data == 0:
position = axes[0, 0]
if data == 1:
position = axes[0, 1]
if data == 2:
position = axes[1, 0]
if data == 3:
position = axes[1, 1]
dist = distributions[data]
first = first_values[data]
second = second_values[data]
sns.histplot(dist, alpha = 0.5, kde = True, stat = 'density', bins = 20, color = '#76A29F', ax = position)
sns.rugplot(dist, alpha = 0.5, color = '#76A29F', ax = position)
position.annotate(f'{np.mean(dist):.2f}', (np.mean(dist), 0.825), xycoords = ('data', 'figure fraction'), color = '#76A29F')
position.axvline(first, 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
position.axvline(second, 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
position.annotate(f'{first:.2f}', (first, 0.8), xycoords = ('data', 'figure fraction'), color = '#FEB302')
position.annotate(f'{second:.2f}', (second, 0.85), xycoords = ('data', 'figure fraction'), color = '#FF5D3E')
position.set_xticks(np.arange(round(min(dist), 1) - 0.1, round(max(max(dist), max([first]), max([second])), 1) + 0.1, 0.1))
plt.xlabel("x-axis name")
plt.ylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.5, 0.5))
plt.show()
The resulting plot looks like this:
What I want is to have
the annotations in the correct subplot next to the vertical lines / the mean of the distribution
shared x- and y-labels for all subplot or at least for each row / column
Any help is highly appreciated!
If you use the function to make the subplot a single array (axes.flatten()) and modify it to draw the graph sequentially, you can draw the graph. The colors of the annotations have been partially changed for testing purposes.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
np.random.seed(202000104)
# create data
distributions = []
first_values = []
second_values = []
for i in range(4):
distributions.append(np.random.normal(0, 0.5, 100))
first_values.append(np.random.uniform(0.7, 1))
second_values.append(np.random.uniform(0.7, 1))
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
legend_elements = [Line2D([0], [0], color = '#76A29F', lw = 2, label = 'distribution'),
Line2D([0], [0], color = '#FEB302', lw = 2, label = '1st value', linestyle = '--'),
Line2D([0], [0], color = '#FF5D3E', lw = 2, label = '2nd value')]
for i,ax in enumerate(axes.flatten()):
sns.histplot(distributions[i], alpha=0.5, kde=True, stat='density', bins=20, color='#76A29F', ax=ax)
sns.rugplot(distributions[i], alpha=0.5, color='#76A29F', ax=ax)
ax.annotate(f'{np.mean(distributions[i]):.2f}', (np.mean(distributions[i]), 0.825), xycoords='data', color='red')
ax.axvline(first_values[i], 0, 0.75, linestyle = '--', alpha = 0.75, color = '#FEB302')
ax.axvline(second_values[i], 0, 0.75, linestyle = '-', alpha = 0.75, color = '#FF5D3E')
ax.annotate(f'{first_values[i]:.2f}', (first_values[i], 0.8), xycoords='data', color='#FEB302')
ax.annotate(f'{second_values[i]:.2f}', (second_values[i], 0.85), xycoords='data', color = '#FF5D3E')
ax.set_xticks(np.arange(round(min(distributions[i]), 1) - 0.1, round(max(max(distributions[i]), max([first_values[i]]), max([second_values[i]])), 1) + 0.1, 0.1))
plt.xlabel("x-axis name")
plt.ylabel("y-axis name")
plt.legend(handles = legend_elements, bbox_to_anchor = (1.35, 0.5))
plt.show()
Is there a plot function available in Python that is same as MATLAB's stackedplot()?
stackedplot() in MATLAB can line plot several variables with the same X axis and are stacked vertically. Additionally, there is a scope in this plot that shows the value of all variables for a given X just by moving the cursor (please see the attached plot). I have been able to generate stacked subplots in Python with no issues, however, not able to add a scope like this that shows the value of all variables by moving the cursor. Is this feature available in Python?
This is a plot using MATLAB's stackedplot():
import pandas as pd
import numpy as np
from datetime import datetime, date, time
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.transforms as transforms
import mplcursors
from collections import Counter
import collections
def flatten(x):
result = []
for el in x:
if isinstance(x, collections.Iterable) and not isinstance(el, str):
result.extend(flatten(el))
else:
result.append(el)
return result
def shared_scope(sel):
sel.annotation.set_visible(False) # hide the default annotation created by mplcursors
x = sel.target[0]
for ax in axes:
for plot in plotStore:
da = plot.get_ydata()
if type(da[0]) is np.datetime64: #pd.Timestamp
yData = matplotlib.dates.date2num(da) # to numerical values
vals = np.interp(x, plot.get_xdata(), yData)
dates = matplotlib.dates.num2date(vals) # to matplotlib dates
y = datetime.strftime(dates,'%Y-%m-%d %H:%M:%S') # to strings
annot = ax.annotate(f'{y:.30s}', (x, vals), xytext=(15, 10), textcoords='offset points',
bbox=dict(facecolor='tomato', edgecolor='black', boxstyle='round', alpha=0.5))
sel.extras.append(annot)
else:
y = np.interp(x, plot.get_xdata(), plot.get_ydata())
annot = ax.annotate(f'{y:.2f}', (x, y), xytext=(15, 10), textcoords='offset points', arrowprops=dict(arrowstyle="->",connectionstyle="angle,angleA=0,angleB=90,rad=10"),
bbox=dict(facecolor='tomato', edgecolor='black', boxstyle='round', alpha=0.5))
sel.extras.append(annot)
vline = ax.axvline(x, color='k', ls=':')
sel.extras.append(vline)
trans = transforms.blended_transform_factory(axes[0].transData, axes[0].transAxes)
text1 = axes[0].text(x, 1.01, f'{x:.2f}', ha='center', va='bottom', color='blue', clip_on=False, transform=trans)
sel.extras.append(text1)
# Data to plot
data = pd.DataFrame(columns = ['timeOfSample','Var1','Var2'])
data.timeOfSample = ['2020-05-10 09:09:02','2020-05-10 09:09:39','2020-05-10 09:40:07','2020-05-10 09:40:45','2020-05-12 09:50:45']
data['timeOfSample'] = pd.to_datetime(data['timeOfSample'])
data.Var1 = [10,50,100,5,25]
data.Var2 = [20,55,70,60,50]
variables = ['timeOfSample',['Var1','Var2']] # variables to plot - Var1 and Var2 to share a plot
nPlot = len(variables)
dataPts = np.arange(0, len(data[variables[0]]), 1) # x values for plots
plotStore = [0]*len(flatten(variables)) # to store all the plots for annotation purposes later
fig, axes = plt.subplots(nPlot,1,sharex=True)
k=0
for i in range(nPlot):
if np.size(variables[i])==1:
yData = data[variables[i]]
line, = axes[i].plot(dataPts,yData,label = variables[i])
plotStore[k]=line
k = k+1
else:
for j in range(np.size(variables[i])):
yData = data[variables[i][j]]
line, = axes[i].plot(dataPts,yData,label = variables[i][j])
plotStore[k]=line
k = k+1
axes[i].set_ylabel(variables[i])
cursor = mplcursors.cursor(plotStore, hover=True)
cursor.connect('add', shared_scope)
plt.xlabel('Samples')
plt.show()
mplcursors can be used to create annotations while hovering, moving texts and vertical bars. sel.extras.append(...) helps to automatically hide the elements that aren't needed anymore.
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import mplcursors
import numpy as np
def shared_scope(sel):
x = sel.target[0]
annotation_text = f'x: {x:.2f}'
for ax, plot in zip(axes, all_plots):
y = np.interp(x, plot.get_xdata(), plot.get_ydata())
annotation_text += f'\n{plot.get_label()}: {y:.2f}'
vline = ax.axvline(x, color='k', ls=':')
sel.extras.append(vline)
sel.annotation.set_text(annotation_text)
trans = transforms.blended_transform_factory(axes[0].transData, axes[0].transAxes)
text1 = axes[0].text(x, 1.01, f'{x:.2f}', ha='center', va='bottom', color='blue', clip_on=False, transform=trans)
sel.extras.append(text1)
fig, axes = plt.subplots(figsize=(15, 10), nrows=3, sharex=True)
y1 = np.random.uniform(-1, 1, 100).cumsum()
y2 = np.random.uniform(-1, 1, 100).cumsum()
y3 = np.random.uniform(-1, 1, 100).cumsum()
all_y = [y1, y2, y3]
all_labels = ['Var1', 'Var2', 'Var3']
all_plots = [ax.plot(y, label=label)[0]
for ax, y, label in zip(axes, all_y, all_labels)]
for ax, label in zip(axes, all_labels):
ax.set_ylabel(label)
cursor = mplcursors.cursor(all_plots, hover=True)
cursor.connect('add', shared_scope)
plt.show()
Here is a version with separate annotations per subplot:
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import mplcursors
import numpy as np
def shared_scope(sel):
sel.annotation.set_visible(False) # hide the default annotation created by mplcursors
x = sel.target[0]
for ax, plot in zip(axes, all_plots):
y = np.interp(x, plot.get_xdata(), plot.get_ydata())
vline = ax.axvline(x, color='k', ls=':')
sel.extras.append(vline)
annot = ax.annotate(f'{y:.2f}', (x, y), xytext=(5, 0), textcoords='offset points',
bbox=dict(facecolor='tomato', edgecolor='black', boxstyle='round', alpha=0.5))
sel.extras.append(annot)
trans = transforms.blended_transform_factory(axes[0].transData, axes[0].transAxes)
text1 = axes[0].text(x, 1.01, f'{x:.2f}', ha='center', va='bottom', color='blue', clip_on=False, transform=trans)
sel.extras.append(text1)
fig, axes = plt.subplots(figsize=(15, 10), nrows=3, sharex=True)
y1 = np.random.uniform(-1, 1, 100).cumsum()
y2 = np.random.uniform(-1, 1, 100).cumsum()
y3 = np.random.uniform(-1, 1, 100).cumsum()
all_y = [y1, y2, y3]
all_labels = ['Var1', 'Var2', 'Var3']
all_plots = [ax.plot(y, label=label)[0]
for ax, y, label in zip(axes, all_y, all_labels)]
for ax, label in zip(axes, all_labels):
ax.set_ylabel(label)
cursor = mplcursors.cursor(all_plots, hover=True)
cursor.connect('add', shared_scope)
plt.show()
I am making this bar plot:
... using this code segment:
my_cmap = plt.get_cmap('copper')
plt.figure()
plt.set_cmap(my_cmap)
plt.pcolormesh(xx, yy, Z)
labels = ['Negative', 'Negative (doubtful)', 'Positive (doubtful)', 'Positive' ]
for i in [0, 1, 2, 3] :
plt.scatter(clustered_training_data[y==i, 0], clustered_training_data[y==i, 1], c=my_cmap(i / 3.0), label=labels[i], s=50, marker='o', edgecolor='white', alpha=0.7)
plt.scatter(lda_trans_eval[q == -1, 0], lda_trans_eval[q == -1, 1], c='green', label='Your patient', s=80, marker='h', edgecolor='white')
plt.legend(prop={'size':8})
Only one (second) color is always blue, regardless of chosen color map. Corresponding data points are correctly colored in the plot and I can't see the reason why pyplot colors the second label differently.
I can't reproduce it with dummy data. Does this have the problem when you run it?
import matplotlib.pyplot as plt
import numpy as np
my_cmap = plt.get_cmap('copper')
fig = plt.figure(figsize=(5,5))
plt.set_cmap(my_cmap)
X = np.linspace(-1,5,100)
Y = np.linspace(-1,5,100)
X,Y = np.meshgrid(X,Y)
Z = (X**2 + Y**2)
Z = Z.astype(int)
Z += (X**2 + Y**2) < .5
ax = plt.pcolormesh(X, Y, Z)
for i in [0,1,2,3]:
plt.scatter([i],[i],c=my_cmap(i / 3.0),label='i=%s'%str(i),
edgecolor='white', alpha=0.7)
plt.scatter([],[],c=my_cmap(1/3.0), label='empty data')
plt.scatter([3],[1],c='green',label='Force color')
plt.legend(loc=2, prop={'size':8})
from os.path import realpath, basename
s = basename(realpath(__file__))
fig.savefig(s.split('.')[0])
plt.show()
This happened to me. I fixed it by using color instead of c.
plt.scatter(clustered_training_data[y==i, 0], clustered_training_data[y==i, 1], color=my_cmap(i / 3.0), label=labels[i], s=50, marker='o', edgecolor='white', alpha=0.7)
Following my previous question that didn't get any answer, I tried to solve my problem of adding colorbar instead of legend to my plots. There are couple of problems that I couldn't solve yet.
Update:
I want to move the colorbar to the proper position on the right of the plot.
I generate two plots with the same instruction but the second one looks completely different and I couldn't understand what caused this problem.
Here is my code:
import numpy as np
import pylab as plt
from matplotlib import rc,rcParams
rc('text',usetex=True)
rcParams.update({'font.size':10})
import matplotlib.cm as cm
from matplotlib.ticker import NullFormatter
import matplotlib as mpl
def plot(Z_s,CWL,filter_id,spectral_type,model_mag,mag,plot_name):
f= ['U38','B','V','R','I','MB420','MB464','MB485','MB518','MB571','MB604','MB646','MB696','MB753','MB815','MB856','MB914']
wavetable=CWL/(1+Z_s)
dd=model_mag-mag
nplist=['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
minimum,maximum=(0.,16.)
Z = [[0,0],[0,0]]
levels = list(np.linspace(0, 1, len(f)))
NUM_COLORS = len(f)
cm = plt.get_cmap('gist_rainbow')
mycolor=[]
for i in range(NUM_COLORS):
mycolor.append( cm(1.*i/NUM_COLORS)) # color will now be an RGBA tuple
mymap = mpl.colors.LinearSegmentedColormap.from_list('mycolors',mycolor)
CS3 = plt.contourf(Z, levels, cmap=mymap)
plt.clf()
FILTER=filter_id
SED=spectral_type
for (j,d) in enumerate(nplist):
bf=(SED==j)
if (j<3):
k=j
i_subplot = k + 1
fig = plt.figure(1, figsize=(5,5))
ax = fig.add_subplot(3,1,i_subplot)
for i in range(len(f)):
bb=np.where(FILTER[bf]==i)[0]
r=mycolor[i][0]
g=mycolor[i][1]
b=mycolor[i][2]
ax.scatter(wavetable[bb], dd[bb], s=1, color=(r,g,b))
if (k<2):
ax.xaxis.set_major_formatter( NullFormatter() )
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
else:
ax.set_xlabel(r'WL($\AA$)',fontsize=10)
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
fig.subplots_adjust(wspace=0,hspace=0)
ax.axhline(y=0,color='k')
ax.set_xlim(1000,9000)
ax.set_ylim(-3,3)
ax.set_xticks(np.linspace(1000, 9000, 16, endpoint=False))
ax.set_yticks(np.linspace(-3, 3, 4, endpoint=False))
ax.text(8500,2.1,nplist[j], {'color': 'k', 'fontsize': 10})
fontsize=8
for tick in ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
for tick in ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
if (j==2):
cbar_ax = fig.add_axes([0.9, 0.15, 0.05, 0.7])
cbar=plt.colorbar(CS3, cax=cbar_ax, ticks=range(0,len(f)),orientation='vertical')
cbar.ax.get_yaxis().set_ticks([])
for s, lab in enumerate(f):
cbar.ax.text( 0.08,(0.95-0.01)/float(len(f)-1) * s, lab, fontsize=8,ha='left')
fname = plot_name+'.'+nplist[0]+'.'+nplist[1]+'.'+nplist[2]+'.pdf'
plt.savefig(fname)
plt.close()
else:
k=j-3
i_subplot = k + 1
fig = plt.figure(1, figsize=(5,5))
ax = fig.add_subplot(3,1,i_subplot)
for i in range(len(f)):
bb=np.where(FILTER[bf]==i)[0]
r=mycolor[i][0]
g=mycolor[i][1]
b=mycolor[i][2]
ax.scatter(wavetable[bb], dd[bb], s=1, color=(r,g,b))
if (k<2):
ax.xaxis.set_major_formatter( NullFormatter() )
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
else:
ax.set_xlabel(r'WL($\AA$)',fontsize=10)
ax.set_ylabel(r'$\Delta$ MAG',fontsize=10)
fig.subplots_adjust(wspace=0,hspace=0)
ax.axhline(y=0,color='k')
ax.set_xlim(1000,9000)
ax.set_ylim(-3,3)
ax.set_xticks(np.linspace(1000, 9000, 16, endpoint=False))
ax.set_yticks(np.linspace(-3, 3, 4, endpoint=False))
ax.text(8500,2.1,nplist[j], {'color': 'k', 'fontsize': 10})
fontsize=8
for tick in ax.xaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
for tick in ax.yaxis.get_major_ticks():
tick.label1.set_fontsize(fontsize)
if (j==5):
cbar_ax = fig.add_axes([0.9, 0.15, 0.05, 0.7])
cbar=plt.colorbar(CS3, cax=cbar_ax, ticks=range(0,len(f)),orientation='vertical')
cbar.ax.get_yaxis().set_ticks([])
for s, lab in enumerate(f):
cbar.ax.text( 0.08,(0.95-0.01)/float(len(f)-1) * s, lab , fontsize=8,ha='left')
fname = plot_name+'.'+nplist[3]+'.'+nplist[4]+'.'+nplist[5]+'.pdf'
plt.savefig(fname)
plt.close()
a=np.loadtxt('calibration.photometry.information.capak.cat')
Z_s=a[:,0]
CWL=a[:,1]
filter_id=a[:,2]
spectral_type=a[:,3]
model_mag=a[:,4]
mag=a[:,5]
plot_name='test'
plot(Z_s,CWL,filter_id,spectral_type,model_mag,mag,plot_name)
you can also download the data from here.
I will appreciate to get any help.
You can use plt.subplots() passing the gridspec_kw parameter to adjust the axes' aspect ratio in a very flexible way, and then select the top axes to include the colorbar.
I've worked on your code simplifying it quite a bit. Furthermore, I've changed many things in your code such as: PEP8, removed repeated calls to plt.savefig()and ax methods. The result is:
import numpy as np
import pylab as plt
from matplotlib import rc, rcParams, colors
rc('text', usetex=True)
rcParams['font.size'] = 10
rcParams['axes.labelsize'] = 8
def plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name):
f= ['U38', 'B', 'V', 'R', 'I', 'MB420', 'MB464', 'MB485', 'MB518',
'MB571', 'MB604', 'MB646', 'MB696', 'B753', 'MB815', 'MB856',
'MB914']
wavetable = CWL/(1+Z_s)
dd = model_mag-mag
nplist = ['E', 'Sbc', 'Scd', 'Irr', 'SB3', 'SB2']
minimum, maximum = (0., 16.)
Z = [[0, 0],[0, 0]]
levels = list(np.linspace(0, 1, len(f)+1))
NUM_COLORS = len(f)
cmap = plt.get_cmap('gist_rainbow')
mycolor = []
for i in range(NUM_COLORS):
mycolor.append(cmap(1.*i/NUM_COLORS))
mymap = colors.LinearSegmentedColormap.from_list('mycolors', mycolor)
CS3 = plt.contourf(Z, levels, cmap=mymap)
coords = CS3.get_array()
coords = coords[:-1] + np.diff(coords)/2.
FILTER = filter_id
SED = spectral_type
dummy = 2
xmin = 1000
xmax = 9000
ymin = -3
ymax = 3
fig, axes = plt.subplots(nrows=5, figsize=(5, 6),
gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
fig2, axes2 = plt.subplots(nrows=5, figsize=(5, 6),
gridspec_kw=dict(height_ratios=[0.35, 0.05, 1, 1, 1]))
fig.subplots_adjust(wspace=0, hspace=0)
fig2.subplots_adjust(wspace=0, hspace=0)
axes_all = np.concatenate((axes[dummy:], axes2[dummy:]))
dummy_axes = np.concatenate((axes[:dummy], axes2[:dummy]))
for ax in axes_all:
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.axhline(y=0, color='k')
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_xticks([])
ax.set_yticks(np.linspace(ymin, ymax, 4, endpoint=False))
ax.set_ylabel(r'$\Delta$ MAG', fontsize=10)
axes[-1].set_xticks(np.linspace(xmin, xmax, 16, endpoint=False))
axes2[-1].set_xticks(np.linspace(xmin, xmax, 16, endpoint=False))
plt.setp(axes[-1].xaxis.get_majorticklabels(), rotation=30)
plt.setp(axes2[-1].xaxis.get_majorticklabels(), rotation=30)
axes[-1].set_xlabel(r'WL($\AA$)', fontsize=10)
axes2[-1].set_xlabel(r'WL($\AA$)', fontsize=10)
for ax in dummy_axes:
for s in ax.spines.values():
s.set_visible(False)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
ax.set_xticks([])
ax.set_yticks([])
for axes_i in [axes, axes2]:
cbar = plt.colorbar(CS3, ticks=[], orientation='horizontal',
cax=axes_i[0])
for s, lab in enumerate(f):
cbar.ax.text(coords[s], 0.5, lab, fontsize=8, va='center',
ha='center', rotation=90,
transform=cbar.ax.transAxes)
for (j, d) in enumerate(nplist):
bf = (SED==j)
if (j<3):
k = j
ax = axes[k+dummy]
ax.text(8500, 2.1, nplist[j], {'color': 'k', 'fontsize': 10})
for i in range(len(f)):
bb = np.where(FILTER[bf]==i)[0]
ax.scatter(wavetable[bb], dd[bb], s=1, color=mycolor[i])
else:
k = j-3
ax = axes2[k+dummy]
ax.text(8500, 2.1, nplist[j], {'color': 'k', 'fontsize': 10})
for i in range(len(f)):
bb = np.where(FILTER[bf]==i)[0]
ax.scatter(wavetable[bb], dd[bb], s=1, color=mycolor[i])
fname = '.'.join([plot_name, nplist[0], nplist[1], nplist[2], 'png'])
fig.savefig(fname)
fname = '.'.join([plot_name, nplist[3], nplist[4], nplist[5], 'png'])
fig2.savefig(fname)
if __name__=='__main__':
a = np.loadtxt('calibration.photometry.information.capak.cat')
Z_s = a[:, 0]
CWL = a[:, 1]
filter_id = a[:, 2]
spectral_type = a[:, 3]
model_mag = a[:, 4]
mag = a[:, 5]
plot_name = 'test'
plot(Z_s, CWL, filter_id, spectral_type, model_mag, mag, plot_name)
which gives: