How to combine two time series plots using python seaborn? - python

I have a training and testing time series dataset that I would like to combine togther, to show how well the forecast did predicting the testing dataset.
Here is the toy code to reproduce the data:
import pandas as pd
import seaborn as sns
train_date = ['2017-01-01T00:00:00.000000000', '2017-02-01T00:00:00.000000000',
'2017-03-01T00:00:00.000000000', '2017-04-01T00:00:00.000000000',
'2017-05-01T00:00:00.000000000', '2017-06-01T00:00:00.000000000',
'2017-07-01T00:00:00.000000000', '2017-08-01T00:00:00.000000000',
'2017-09-01T00:00:00.000000000', '2017-10-01T00:00:00.000000000',
'2017-11-01T00:00:00.000000000', '2017-12-01T00:00:00.000000000',
'2018-01-01T00:00:00.000000000', '2018-02-01T00:00:00.000000000',
'2018-03-01T00:00:00.000000000', '2018-04-01T00:00:00.000000000',
'2018-05-01T00:00:00.000000000', '2018-06-01T00:00:00.000000000',
'2018-07-01T00:00:00.000000000', '2018-08-01T00:00:00.000000000',
'2018-09-01T00:00:00.000000000', '2018-10-01T00:00:00.000000000',
'2018-11-01T00:00:00.000000000', '2018-12-01T00:00:00.000000000',
'2019-01-01T00:00:00.000000000', '2019-02-01T00:00:00.000000000',
'2019-03-01T00:00:00.000000000', '2019-04-01T00:00:00.000000000',
'2019-05-01T00:00:00.000000000', '2019-06-01T00:00:00.000000000',
'2019-07-01T00:00:00.000000000', '2019-08-01T00:00:00.000000000',
'2019-09-01T00:00:00.000000000', '2019-10-01T00:00:00.000000000',
'2019-11-01T00:00:00.000000000', '2019-12-01T00:00:00.000000000',
'2020-01-01T00:00:00.000000000', '2020-02-01T00:00:00.000000000',
'2020-03-01T00:00:00.000000000', '2020-04-01T00:00:00.000000000',
'2020-05-01T00:00:00.000000000', '2020-06-01T00:00:00.000000000',
'2020-07-01T00:00:00.000000000', '2020-08-01T00:00:00.000000000',
'2020-09-01T00:00:00.000000000', '2020-10-01T00:00:00.000000000',
'2020-11-01T00:00:00.000000000', '2020-12-01T00:00:00.000000000']
test_date = ['2021-01-01T00:00:00.000000000', '2021-02-01T00:00:00.000000000',
'2021-03-01T00:00:00.000000000', '2021-04-01T00:00:00.000000000',
'2021-05-01T00:00:00.000000000', '2021-06-01T00:00:00.000000000',
'2021-07-01T00:00:00.000000000', '2021-08-01T00:00:00.000000000',
'2021-09-01T00:00:00.000000000', '2021-10-01T00:00:00.000000000',
'2021-11-01T00:00:00.000000000', '2021-12-01T00:00:00.000000000']
train_eaches = [1915.0, 1597.0, 1533.0, 1601.0, 1585.0, 1675.0, 1760.0, 1910.0, 1886.0, 1496.0, 1545.0, 1538.0, 1565.0, 1350.0,1686.0, 1535.0, 1629.0, 1589.0, 1605.0, 1560.0, 1353.0,1366.0, 1246.0, 1423.0, 1579.0, 1368.0, 1727.0, 1687.0, 1872.0, 1824.0, 2161.0, 1065.0, 727.0, 1567.0, 1509.0, 1687.0, 1647.0,1476.0, 1231.0, 1165.0, 1341.0, 1425.0, 1502.0, 1450.0, 1497.0, 1259.0, 1207.0, 1132.0]
test_eaches = [1252.0, 1038.0, 1184.0, 1200.0, 1219.0, 1339.0, 1504.0, 2652.0, 1724.0, 1029.0,
711.0, 1530.0]
test_predictions = [1914.7225, 1490.4715, 1317.4765, 1341.263375, 1459.5875, 1534.2375, 1467.208875, 1306.2145, 1171.652625, 1120.641, 1138.912, 1171.914125]
test_credibility_down = [1805. , 1303. , 1017. , 915.975, 870.975, 797. ,
657. , 507. , 392. , 320. , 272. , 235. ]
test_credibility_up = [2029.025, 1702. , 1681.025, 1908. , 2329.05 , 2695.025,
2867.075, 2835. , 2815.075, 2949. , 3278.025, 3679. ]
train_df = pd.DataFrame.from_dict({'date':train_date, 'eaches':train_eaches})
test_df = pd.DataFrame.from_dict({'date':test_date, 'eaches':test_eaches, '2.5% Credibilty':test_credibility_down,
'97.5% Credibility':test_credibility_up})
Here are the two plots (train and test) and code that produces those plots:
fig = plt.figure(figsize=(15,4))
c=sns.scatterplot(x =train_df['date'], y = train_df['eaches'], label = 'Train Eaches',
color = 'black')
fig = plt.figure(figsize=(15,4))
a=sns.lineplot(x =test_df['date'], y = test_df['predictions'], label = 'Posterior Prediction', color = 'red')
b=sns.lineplot(x =test_df['date'], y = test_df['2.5% Credibilty'], label = 'Credibilty Interval',
color = 'skyblue', alpha=.3)
c=sns.lineplot(x =test_df['date'], y = test_df['97.5% Credibility'], label = 'Credibilty Interval',
color = 'skyblue', alpha=.3)
line = c.get_lines()
plt.fill_between(line[0].get_xdata(), line[1].get_ydata(), line[2].get_ydata(), color='skyblue', alpha=.3)
sns.scatterplot(x =test_df['date'], y = test_df['eaches'], label = 'True Value', color='black')
plt.legend()
I would like to basically add the two x axis as a continuation and maybe add a vertical line to the start of the test period.

Put them on the same axes and use axvline to mark the prediction start. Also, you can fix the overlapping dates on the x-axis by casting the date columns as proper datetimes (train_df["date"] = pd.to_datetime(train_df.date)).
import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots(1, 1, figsize=(15,4))
c_train= sns.scatterplot(x =train_df['date'], y = train_df['eaches'], label = 'Train Eaches',
color = 'black', ax=ax)
a = sns.lineplot(x =test_df['date'], y = test_df['predictions'], label = 'Posterior Prediction', color = 'red', ax=ax)
b = sns.lineplot(x =test_df['date'], y = test_df['2.5% Credibilty'], label = 'Credibilty Interval',
color = 'skyblue', alpha=.3, ax=ax)
c = sns.lineplot(x =test_df['date'], y = test_df['97.5% Credibility'], label = 'Credibilty Interval',
color = 'skyblue', alpha=.3)
line = c.get_lines()
ax.fill_between(line[0].get_xdata(), line[1].get_ydata(), line[2].get_ydata(), color='skyblue', alpha=.3)
sns.scatterplot(x =test_df['date'], y = test_df['eaches'], label = 'True Value', color='black', ax=ax)
ax.legend()
ax.axvline(test_df['date'][0])

Related

interactive plot python for DBSCAN clustering

I'm trying to build a code for interactive DBSCAN clustering method. But when I run it I get some errors. But when I'm getting an error says "NameError: name 'train_test_split' is not defined
even I tried to import it using : from sklearn.model_selection import train_test_split
How can I solve to have the interactive plot working proberly?
df_mv = pd.read_csv(r"https://raw.githubusercontent.com/HanaBachi/MachineLearning/main/multishape.csv")
text_trap = io.StringIO()
sys.stdout = text_trap
l = widgets.Text(value=' DBSCAN, Hana Bachi, The University of Texas at Austin',
layout=Layout(width='950px', height='30px'))
eps = widgets.FloatSlider(min=0, max = 2, value=0.1, step = 0.1, description = 'eps',orientation='horizontal', style = {'description_width': 'initial'}, continuous_update=False)
minPts = widgets.FloatSlider(min=0, max = 5, value=1, step = 1, description = 'minPts %',orientation='horizontal',style = {'description_width': 'initial'}, continuous_update=False)
color = ['blue','red','green','yellow','orange','white','magenta','cyan']
style = {'description_width': 'initial'}
ui = widgets.HBox([eps,minPts],)
ui2 = widgets.VBox([l,ui],)
# create activation function plots
def DBSCAN_plot(eps, minPts):
db = DBSCAN(eps=0.155, min_samples=5).fit(df_mv)
labels = db.labels_
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
x = df_mv.values[:,0]
y = df_mv.values[:,1]
cmap = plt.cm.rainbow
#norm = mc.BoundaryNorm(labels, cmap.N)
plt.figure(figsize=(14,7))
plt.scatter(x, y, c=labels, cmap='tab10', s=50)
plt.scatter(x[np.where(labels==-1)], y[np.where(labels==-1)], c='k', marker='x', s=100)
plt.title('DBSCAN of non-spherical data with noise', fontsize = 20)
plt.colorbar()
plt.show()
plt.subplots_adjust(left=0.0, bottom=0.0, right=2.0, top=1.0, wspace=0.2, hspace=0.3)
plt.show()
interactive_plot1 = widgets.interactive_output(DBSCAN_plot, {'eps': eps})
interactive_plot1 = widgets.interactive_output(DBSCAN_plot, {'minPts': minPts})
interactive_plot1.clear_output(wait = True) # reduce flickering by delaying plot updating
# create dashboard/formatting
uia = widgets.HBox([interactive_plot1],)
uia2 = widgets.VBox([eps, uia],)
uib = widgets.HBox([interactive_plot1],)
uib2 = widgets.VBox([minPts, uib],)
display(uib2, interactive_plot) # display the interactive plot
I'm trying to build an interactive plot for DBSCAN clustering

Matplotlib shows <ErrorbarContainer object of 3 artists> in x axes legend

I'm trying to create an errorbar graph showing some data points and their standard deviation. I was able to create exactly what I wanted, the only problem is that <ErrorbarContainer object of 3 artists> is showing up in my legend.
Graph showing <ErrorbarContainer object of 3 artists> instead of nothing
(The data is replicated 3 times because I still haven't finished collecting it)
This is my code:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
dfs_busca_div = [
pd.read_csv(r'./out/busca_hash_aberto_divisao.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_divisao.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_divisao.csv', usecols=[3])]
dfs_busca_mul = [
pd.read_csv(r'./out/busca_hash_aberto_multiplicacao.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_multiplicacao.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_multiplicacao.csv', usecols=[3])]
dfs_busca_pri = [
pd.read_csv(r'./out/busca_hash_aberto_primos.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_primos.csv', usecols=[3]),
pd.read_csv(r'./out/busca_hash_aberto_primos.csv', usecols=[3])]
df_types = [
'Progressive Overflow',
'Duplo',
'Aberto'
]
df_names = [
'Divisão',
'Multiplicação',
'Primo'
]
def med_dsvp(dfs, key):
medias_busca = np.array([])
dsvp_busca = np.array([])
for df in dfs:
media = df[key].mean()
medias_busca = np.append(medias_busca, media)
desv_pad = df[key].std()
dsvp_busca = np.append(dsvp_busca, desv_pad)
return medias_busca, dsvp_busca
medias_busca_div, dsvp_busca_div = med_dsvp(dfs_busca_div, 'TempoBusca')
medias_busca_mul, dsvp_busca_mul = med_dsvp(dfs_busca_mul, 'TempoBusca')
medias_busca_pri, dsvp_busca_pri = med_dsvp(dfs_busca_pri, 'TempoBusca')
x = np.arange(len(df_types))
width = 0.3
fig, ax = plt.subplots()
rects1 = ax.errorbar(x - width/2, medias_busca_div, dsvp_busca_div, fmt = 'o', capsize=3, label='Divisão')
rects2 = ax.errorbar(x, medias_busca_mul, dsvp_busca_mul, fmt = 'o', capsize=3, label='Multiplicação')
rects3 = ax.errorbar(x + width/2, medias_busca_pri, dsvp_busca_pri, fmt = 'o', capsize=3, label='Primo')
# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Tempo(s)')
ax.set_title('Tempo de busca hash')
ax.set_xticks(x);ax.set_xticklabels(df_types, minor=False)
ax.set_yticks(np.arange(0.000, 0.250, 0.010))
plt.grid(which='major', axis='y')
ax.legend()
ax.set_xlabel(rects1)
ax.set_xlabel(rects2)
ax.set_xlabel(rects3)
fig.tight_layout()
plt.show()
What can I do in order to stop this weird legend from showing up?

Encountering time out error in the middle of a matplotlib for loop

I have a code which will go through three dictionaries, and make some plots if the keys all match. I've been running into an odd issue due to the use of the matplotlib table.
When I first got this code to run, I had no issues finishing the whole loop. Now I am encountering a time out error by the second iteration
I tried moving the the table out of the for loop.
I added plt.close('all')
I also try importing matplotlib again at the end of the loop in hopes of resetting something in the backend.
for k, v in oct_dict.items():
for k2, v2 in stu_dict.items():
for k3, v3 in oct2_dict.items():
if k == k2 and k == k3:
with PdfPages('{}.pdf'.format(k)) as pdf:
#rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
#v = v[v['a_1920'] != 0]
rc('text', usetex=True)
fig = plt.figure(figsize = (8,10.5))
gs=GridSpec(2,2) # 2 rows, 3 columns
ax0 = fig.add_subplot(gs[0,0])
ax0.bar(x=np.arange(2), height = [float(v['a_1920'])*100, mean_a_1920*100], color = nice)
plt.xticks(np.arange(2), ['{}'.format(k), 'D75'])
for p in ax0.patches:
a =p.get_height()
ax0.annotate('{:0.2f}'.format(float(a)), (p.get_x()+.1, p.get_height() * .75), weight = 'bold')
ax1 = fig.add_subplot(gs[0,1])
c = str(len(v2['student_id']))
c2 = int(v['c_1920'])
props = dict(boxstyle='round', facecolor='white', alpha=0.0)
c3 = int(v['b_1920'])
# place a text box in upper left in axes coords
c4 = int(v['d_1920'])
ax1.text(0.0, 0.95, 'Number of Age : {}'.format(c3), transform=ax1.transAxes, fontsize=12,
verticalalignment='top')
ax1.text(0.0, 0.85, 'Number of Incomplete : {}'.format(c2), transform=ax1.transAxes, fontsize=12,
verticalalignment='top')
ax1.text(0.0, 0.75, 'Number of Invalid : {}'.format(c4), transform = ax1.transAxes, fontsize = 12,
verticalalignment = 'top' )
ax1.text(0.0, 0.65, 'Number of who will reach Age:\n{}'.format(c), transform=ax1.transAxes, fontsize=12,
verticalalignment='top' )
#ax1.table(cellLoc = 'center', cellText = [] , loc = 'upper center')
ax1.axis('off')
ax1.axis('tight')
#fig.suptitle('Monthly Summary', va = 'top', ha= 'center')
fig.text(0.3, 1, 'Monthly Summary '+ dt.date.today().strftime("%b %d, %Y"), fontsize=12, verticalalignment='top', bbox=props)
#plt.subplots_adjust(top = .75)
#plt.tight_layout()
#gs.constrained_layout()
#print(float(v3['inc']))
#print(float(v3['com']))
ax2 = fig.add_subplot(gs[1,0])
plt.sca(ax2)
p1 = plt.bar(np.arange(1), int(v3['com']), width=.25,color = 'b',label = 'Complete')
p2 = plt.bar(np.arange(1), int(v3['inc']), width = .25, bottom = int(v3['com']), color = 'r', label = 'Incomplete')
plt.legend()
for p in ax2.patches:
ax2.annotate((p.get_height()), (p.get_x()+.1, p.get_height() * .75), weight = 'bold')
ax2.set_xticks([])
# # #ax2.set_xlabel='Students Who Will Turn 15'
ax2.set_title('Students who will turn 15 later in the school year')
ax2.set_xticks([])
ax3 = fig.add_subplot(gs[1,1])
a = int(v3['com'])+int(v3['inc'])
ax3.axis('off')
plt.tight_layout()
pdf.savefig()
plt.close('all')
fig = plt.figure(figsize = (8,11.5))
gs=GridSpec(1,1)
axs = fig.add_subplot(gs[0])
cell_text = []
v2 = v2.drop(['Grand Total','birth_dte','loc'],axis = 1)
binarymap = {0:'No',1:'Yes'}
v2['Plan Not Complete'] = v2['Plan Not Complete'].map(binarymap)
v2['Plan Already Complete'] = v2['Plan Already Complete'].map(binarymap)
labels = [six column titles here]
for row in range(len(v2)):
try:
cell_text.append(v2.iloc[row])
except:
pass
table = axs.table(cellLoc = 'center', cellText = cell_text, colLabels = labels,
rowLoc = 'center', colLoc = 'center',loc = 'upper center',fontsize = 32)
table.set_fontsize(32)
table.scale(1, 1.5)
#axs.text(0.0,0.5,'For the column')
axs.axis('off')
pdf.savefig()
#plt.savefig('{}_list.pdf'.format(k))
plt.show()
plt.close('all')
import matplotlib.pyplot as plt
TimeoutError: Lock error: Matplotlib failed to acquire the following lock file:
C:\Users\myusername.matplotlib\tex.cache\23c95fa5c37310802233a994d78d178d.tex.matplotlib-lock
NOTE: If some of the key names dont match in this code it is on purpose, I had to change them up for this post since it is public. The error is thrown by the second iteration once the code reaches the axs.table line.
I got everything to run properly after using the conda command prompt to clean the environments
conda clean --all
Something that works but I would have liked to avoid was just removing the use of tex for this script. rc param tex set to False, code finished running pretty quickly as well

How to manually change the tick labels of the margin plots on a Seaborn jointplot

I am trying to use a log scale as the margin plots for my seaborn jointplot. I am usings set_xticks() and set_yticks(), but my changes do not appear. Here is my code below and the resulting graph:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import seaborn as sns
import pandas as pd
tips = sns.load_dataset('tips')
female_waiters = tips[tips['sex']=='Female']
def graph_joint_histograms(df1):
g=sns.jointplot(x = 'total_bill',y = 'tip', data = tips, space = 0.3,ratio = 3)
g.ax_joint.cla()
g.ax_marg_x.cla()
g.ax_marg_y.cla()
for xlabel_i in g.ax_marg_x.get_xticklabels():
xlabel_i.set_visible(False)
for ylabel_i in g.ax_marg_y.get_yticklabels():
ylabel_i.set_visible(False)
x_labels = g.ax_joint.get_xticklabels()
x_labels[0].set_visible(False)
x_labels[-1].set_visible(False)
y_labels = g.ax_joint.get_yticklabels()
y_labels[0].set_visible(False)
y_labels[-1].set_visible(False)
g.ax_joint.set_xlim(0,200)
g.ax_marg_x.set_xlim(0,200)
g.ax_joint.scatter(x = df1['total_bill'],y = df1['tip'],data = df1,c = 'y',edgecolors= '#080808',zorder = 2)
g.ax_joint.scatter(x = tips['total_bill'],y = tips['tip'],data = tips, c= 'c',edgecolors= '#080808')
ax1 =g.ax_marg_x.get_axes()
ax2 = g.ax_marg_y.get_axes()
ax1.set_yscale('log')
ax2.set_xscale('log')
ax1.set_yscale('log')
ax2.set_xscale('log')
ax2.set_xlim(1e0, 1e4)
ax1.set_ylim(1e0, 1e3)
ax2.xaxis.set_ticks([1e0,1e1,1e2,1e3])
ax2.xaxis.set_ticklabels(("1","10","100","1000"), visible = True)
plt.setp(ax2.get_xticklabels(), visible = True)
colors = ['y','c']
ax1.hist([df1['total_bill'],tips['total_bill']],bins = 10, stacked=True,log = True,color = colors, ec='black')
ax2.hist([df1['tip'],tips['tip']],bins = 10,orientation = 'horizontal', stacked=True,log = True,color = colors, ec='black')
ax2.set_ylabel('')
Any ideas would be much appreciated.
Here is the resulting graph:
You should actually get an error from the line g.ax_marg_y.get_axes() since an axes does not have a get_axes() method.
Correcting for that
ax1 =g.ax_marg_x
ax2 = g.ax_marg_y
should give you the desired plot. The ticklabels for the log axis are unfortunately overwritten by the histogram's log=True argument. So you can either leave that out (since you already set the axes to log scale anyways) or you need to set the labels after calling hist.
import matplotlib.pyplot as plt
import seaborn as sns
tips = sns.load_dataset('tips')
def graph_joint_histograms(tips):
g=sns.jointplot(x = 'total_bill',y = 'tip', data = tips, space = 0.3,ratio = 3)
g.ax_joint.cla()
g.ax_marg_x.cla()
g.ax_marg_y.cla()
for xlabel_i in g.ax_marg_x.get_xticklabels():
xlabel_i.set_visible(False)
for ylabel_i in g.ax_marg_y.get_yticklabels():
ylabel_i.set_visible(False)
x_labels = g.ax_joint.get_xticklabels()
x_labels[0].set_visible(False)
x_labels[-1].set_visible(False)
y_labels = g.ax_joint.get_yticklabels()
y_labels[0].set_visible(False)
y_labels[-1].set_visible(False)
g.ax_joint.set_xlim(0,200)
g.ax_marg_x.set_xlim(0,200)
g.ax_joint.scatter(x = tips['total_bill'],y = tips['tip'],data = tips,
c = 'y',edgecolors= '#080808',zorder = 2)
g.ax_joint.scatter(x = tips['total_bill'],y = tips['tip'],data = tips,
c= 'c',edgecolors= '#080808')
ax1 =g.ax_marg_x
ax2 = g.ax_marg_y
ax1.set_yscale('log')
ax2.set_xscale('log')
ax2.set_xlim(1e0, 1e4)
ax1.set_ylim(1e0, 1e3)
ax2.xaxis.set_ticks([1e0,1e1,1e2,1e3])
ax2.xaxis.set_ticklabels(("1","10","100","1000"), visible = True)
plt.setp(ax2.get_xticklabels(), visible = True)
colors = ['y','c']
ax1.hist([tips['total_bill'],tips['total_bill']],bins = 10,
stacked=True, color = colors, ec='black')
ax2.hist([tips['tip'],tips['tip']],bins = 10,orientation = 'horizontal',
stacked=True, color = colors, ec='black')
ax2.set_ylabel('')
graph_joint_histograms(tips)
plt.show()

Matplotlib tilted text on log scale?

Either I cannot figure it out, or there is a bug in matplotlib when drawing text() with rotation to loglog or xscale('log'); yscale('log'); plot.
my code looks like this:
from pylab import *
# =================== deltaV-vs-trust
figure( figsize=( 12, 8 ) )
times = array( [ 1.0, 60.0, 3600.0, 86400.0, 604800, 2592000, 31556926, 315569260, 3155692600 ] )
timeText = [ 'sec','min','hour', 'day', 'week', 'month', 'year', '10years', '100years' ]
dists = array( [ 1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 6371e+3, 42164e+3, 384400e+3, 1e+9, 1e+10, 5.790918E+010, 1.082089E+011, 1.495979E+011, 2.279366E+011, 7.784120E+011, 1.426725E+012, 2.870972E+012, 4.498253E+012, 1.40621998e+13, 2.99195741e+14, 7.47989354e+15, 4.13425091e+16 ] )
distText = [ '10m','100m', '1km','10km', '100km', '1000km', 'LEO', 'GEO', 'Moon', r'10$^6$km',r'10$^7$km', 'Mercury', 'Venus', 'Earth', 'Mars', 'Jupiter', 'Satrun', 'Uranus', 'Neptune', 'Heliopause', 'Inner Oorth', 'Outer Oorth', 'Alpha Centauri' ]
vMin = 1e+0; vMax = 1e+8;
aMin = 1e-4; aMax = 1e+2;
As = linspace(aMin,aMax,2); print As
Vs = linspace(vMin,vMax,2); print Vs
As_ = As.copy(); Vs_ = Vs.copy()
for i in range(len(dists)):
dist = dists[i]
ts = sqrt( 2*dist / As )
As_[:] = As [:]
Vs_ = As * ts
if( Vs_[0] < Vs[0] ):
Vs_[0] = Vs[0]
As_[0] = Vs_[0]**2 / (2*dist)
plot( Vs_, As_, 'b-', alpha=0.5 )
plt.text( Vs_[0], As_[0], distText[i], rotation=60, color='b', horizontalalignment='center', verticalalignment='bottom') # this does not work properly
#plt.text( Vs_[0], As_[0], distText[i], rotation=60, color='b', horizontalalignment='center', verticalalignment='center') # this works but does not look nice
#plt.text( Vs_[0], 1.5*As_[0], distText[i], rotation=60, color='b', horizontalalignment='center', verticalalignment='center') # a bit better
for i in range(len(times)):
time = times[i]
As_[:] = As[:]
Vs_ = As * time
if( Vs_[1] > Vs[1] ):
Vs_[1] = Vs[1]
As_[1] = Vs_[1] / time
plot( Vs_, As_, 'r-', alpha=0.5 )
plt.text( Vs_[1], As_[1], timeText[i]+" ", rotation=40, color='r', horizontalalignment='right', verticalalignment='baseline')
ylabel( r" acceleration [m/s$^2$] " )
xlabel( r" delta-v [m/s ] " )
yscale('log')
xscale('log')
grid()
ylim( aMin, aMax )
xlim( vMin, vMax )
show()
the result looks like this ( you can see how text is not possitioned correctly on the corresponding line; very visible for Inner Oorth,Outer Oorth and Alpha Centauri ) :
I think this problem is visible only for large rotation angles. If I use horizontalalignment='center', verticalalignment='center' it works properly, however it does not looks nice ( because the line cross the text and the text cross the border of image )
just for context - what I'm trying to make is plot like this:
http://www.projectrho.com/public_html/rocket/images/enginelist/torchChart.jpg
The key is the 'rotation_mode' kwarg to text (doc) which I did not even know existed until tonight. What this controls is if the text is rotated, then aligned (the default) or aligned, then rotated.
In trying to understand your question I cleaned your code up to (more or less) conform to pep8 and simplified some of the computation.
import matplotlib.pyplot as plt
import numpy as np
# =================== deltaV-vs-trust
times = np.array([1.0, 60.0, 3600.0, 86400.0, 604800, 2592000, 31556926, 315569260, 3155692600])
timeText = ['sec','min','hour', 'day', 'week', 'month', 'year', '10years', '100years']
dists= np.array([1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 6371e+3, 42164e+3,
384400e+3, 1e+9, 1e+10, 5.790918E+010, 1.082089E+011, 1.495979E+011,
2.279366E+011, 7.784120E+011, 1.426725E+012, 2.870972E+012,
4.498253E+012, 1.40621998e+13, 2.99195741e+14, 7.47989354e+15,
4.13425091e+16])
distText = ['10m','100m', '1km','10km', '100km', '1000km', 'LEO',
'GEO', 'Moon', r'10$^6$km',r'10$^7$km', 'Mercury', 'Venus', 'Earth',
'Mars', 'Jupiter', 'Satrun', 'Uranus', 'Neptune', 'Heliopause',
'Oorth', 'Outer Oorth', 'Alpha Centauri']
vMin, vMax = 1e+0, 1e+8
aMin, aMax = 1e-4, 1e+2
As = np.linspace(aMin, aMax, 2)
fig, ax = plt.subplots(figsize=(12, 8))
for dist, text in zip(dists, distText):
# compute the line
v = np.sqrt(2*dist * As)
ax.plot(v, As, 'b-', alpha=0.5)
# sort out where the label should be
txt_y = aMin
txt_x = v[0]
# clip to the edges
if (txt_x < vMin):
txt_x = vMin
txt_y = vMin**2 / (2*dist)
ax.text(txt_x, txt_y, text,
rotation=60, color='b', rotation_mode='anchor',
horizontalalignment='left',
verticalalignment='bottom')
for time, txt in zip(times, timeText):
# compute the line
x = As * time
ax.plot(x, As, 'r-', alpha=0.5)
# sort out where the label should be
txt_x = x[-1]
txt_y = aMax
# clip to the edges
if(txt_x > vMax):
txt_x = vMax
txt_y = vMax / time
ax.text(txt_x, txt_y, txt,
rotation=40, color='r',
horizontalalignment='right', rotation_mode='anchor',
verticalalignment='baseline')
ax.set_ylabel(r"acceleration [m/s$^2$]")
ax.set_xlabel(r"delta-v [m/s]")
ax.set_yscale('log')
ax.set_xscale('log')
ax.grid()
ax.set_ylim(aMin, aMax)
ax.set_xlim(vMin, vMax)

Categories

Resources