Animating and exporting the cartopy Nightshade feature - python

I am trying to plot the groundtrack of a satellite through a combination of packages, animate the satellite movement, mark a field of view from the subsatellite point (which is just arbitrary circles in this code) and then export the file as a video of some kind. So far, I have been able to do all of this except that when I try to export the video, the Nightshade feature doesn't animate so much as overlay and eventually blacks out most of the screen. Is there something I'm missing on how to properly animate the Nightshade feature? I know that I'm essentially recreating a new feature inside the update function everytime it runs a frame but I could not figure out how to update it as I do the scatter plots.
I've included my sample code below.
import pandas as pd
from sgp4.api import WGS72
from sgp4.api import Satrec
from skyfield.api import EarthSatellite, load, N, W, wgs84
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as animation
from cartopy.feature.nightshade import Nightshade
# CREATE THE SATELLITE DATA
epoch = datetime.date(1949, 12, 31)
sat = Satrec()
sat.sgp4init(
WGS72, # gravity model
'i', # 'a' = old AFSPC mode, 'i' = improved mode
5, # satnum: Satellite number
(datetime.date.today() - epoch).days, # epoch: days since 1949 December 31 00:00 UT
0, # bstar: drag coefficient (1/earth radii)
6.969196665e-13, # ndot (NOT USED): ballistic coefficient (revs/day)
0.0, # nddot (NOT USED): mean motion 2nd derivative (revs/day^3)
0.1, # ecco: eccentricity
280 * np.pi / 180, # argpo: argument of perigee (radians)
50 * np.pi / 180, # inclo: inclination (radians)
275 * np.pi / 180, # mo: mean anomaly (radians)
0.0472294454407, # no_kozai: mean motion (radians/minute)
50 * np.pi / 180, # nodeo: right ascension of ascending node (radians)
)
# DEFINE A FEW BASIC PARAMETERS FOR THE PROGRAM
P = sat.mo / sat.no_kozai # min, period of orbit. LEOs orbit between 84-127 minutes
ts = load.timescale()
sat1 = EarthSatellite.from_satrec(sat, ts)
hours = np.arange(0, 6, 0.05)
time = ts.utc(2021, 6, 31, hours)
pos = sat1.at(time).position.km
pos_ec = sat1.at(time).ecliptic_position().km
sp = wgs84.subpoint(sat1.at(time))
latitude = sp.latitude
longitude = sp.longitude
elev = sp.elevation
# CREATE A DATAFRAME OF THE DATA FOR REVIEW LATER IF NEEDED
df = pd.DataFrame([time.utc_datetime(), latitude.degrees, longitude.degrees, elev.km],
index=['DTS', 'lat', 'lon', 'elev']).T
df.lat = df.lat.astype('float32')
df.lon = df.lon.astype('float32')
df.elev = df.elev.astype('float32')
df.set_index('DTS', inplace=True)
# ASSIGN RELEVANT DATA FOR THE SUBSATELLITE POINT
ssp = np.transpose(np.array([longitude.degrees, latitude.degrees]))
line = ssp.copy()
pos = np.where(np.diff(np.abs(line[:, 0] >= 0)))[0]
line[pos, :] = np.nan
# CREATE DATE TIME RANGES FOR USE WITH THE NIGHTSHADE FEATURE
base = datetime.datetime(2000, 1, 1)
dates = np.array([base + datetime.timedelta(hours=i) for i in range(len(hours))])
shades = [Nightshade(date, alpha=0.2) for date in dates]
### CREATE FIGURE AND IMAGE
fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.stock_img()
# plot lines that will show the ground track that will be animated
ax.plot(line[:, 0], line[:, 1], '--k')
# create a blank scatter to start
scatter = ax.scatter(None, None, color='r', s=30)
# initiate the circles around the scatter point
circle1 = plt.Circle((longitude.degrees[0], latitude.degrees[0]), radius=30, color='blue', alpha=0.3)
circle2 = plt.Circle((longitude.degrees[0], latitude.degrees[0]), radius=40, color='yellow', alpha=0.3)
# add the circles to the axis
ax.add_patch(circle1)
ax.add_patch(circle2)
# Add the nightshade feature (but set it to be invisible so it doesn't stay through the whole animation)
ns = ax.add_feature(Nightshade(base, alpha=0.0))
# Create all the updates for the animation
def update(i):
lon = ssp[i, 0]
lat = ssp[i, 1]
scatter.set_offsets(np.c_[lon, lat])
# add a feature for the next Nightshade feature
ns = ax.add_feature(shades[i], alpha=0.2)
circle1.center = (lon, lat)
circle2.center = (lon, lat)
return scatter, circle1, circle2, ns
# Run the animation
anim = animation.FuncAnimation(plt.gcf(), update, frames=df.shape[0],init_func=None, interval=250, blit=True)
plt.show()
# WRITE THE VIDEO
Writer = animation.writers['ffmpeg']
writer = Writer(fps=10, metadata=dict(artist='Me'), bitrate=1800)
anim.save('gt.mp4', writer=writer)

Related

Scatter data not overlaying properly on radar data... cartopy issue

I'm trying to plot scatter data of storm reports on top of radar gridded data and I seem to be getting strange plotting issues related to mapping using cartopy. See example image attached. It appears that the scatter data plots on a separate axis than the radar data, but I'm not sure why given that the plotting module for the radar data uses the same user input for min/max lat/lon and the chosen projection. Additionally, the lat/lon range on the map is dynamic as I loop through time stamps. I know I can use an ax.set_extent to create fixed coordinates, but this does not solve my issue of the plotting being done on a separate axis. Does anyone have any suggestions? They should overlay on the same axis.
Here is the code:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
import cartopy.crs as ccrs
import pyart
import pandas as pd
import nexradaws
import tempfile
import pytz
templocation = tempfile.mkdtemp()
import cartopy.feature as cfeature
from metpy.plots import USCOUNTIES
### Define the radar, start time and end time
radar_id = 'KDVN'
start = pd.Timestamp(2020,8,10,16,30).tz_localize(tz='UTC')
end = pd.Timestamp(2020,8,10,21,0).tz_localize(tz='UTC')
### Bounds of map we want to plot
min_lon = -93.25
max_lon = -88.
min_lat = 40.35
max_lat = 43.35
# ### Bounds of map we want to plot
# min_lon = -80.8
# max_lon = -77.
# min_lat = 34
# max_lat = 37
#### and get the data
conn = nexradaws.NexradAwsInterface()
scans = conn.get_avail_scans_in_range(start, end, radar_id)
print("There are {} scans available between {} and {}\n".format(len(scans), start, end))
print(scans[0:4])
## download these files
#results = conn.download(scans[0:2], templocation)
results = conn.download(scans, templocation)
#%%
#Now get the severe reports from the SPC site. This assumes you're plotting a year far #enough in the past that
# SPC has official records available. If plotting a more recent time period, then the #local storm reports archive
#[![enter image description here][1]][1] at IEM is a good source
### wind reports
wind_rpts = pd.read_csv("https://www.spc.noaa.gov/wcm/data/"+str(start.year)+"_wind.csv")
wind_rpts['datetime'] = pd.to_datetime(wind_rpts.date + ' ' + wind_rpts.time) ## convert to datetime
wind_rpts.set_index("datetime",inplace=True)
### times in the file are given in central standard time (UTC+6). Localize, and convert to UTC
wind_rpts.index = wind_rpts.index.tz_localize("Etc/GMT+6",ambiguous='NaT',nonexistent='shift_forward').tz_convert("UTC")
## subset down to 30 minutes before/after the radar times we're plotting
wind_rpts = wind_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):((end+pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M"))]
wind_rpts
### repeat for tornado reports
tor_rpts = pd.read_csv("https://www.spc.noaa.gov/wcm/data/"+str(start.year)+"_torn.csv")
tor_rpts['datetime'] = pd.to_datetime(tor_rpts.date + ' ' + tor_rpts.time) ## convert to datetime
tor_rpts.set_index("datetime",inplace=True)
### times in the file are given in central standard time (UTC+6). Localize, and convert to UTC
tor_rpts.index = tor_rpts.index.tz_localize("Etc/GMT+6",ambiguous='NaT',nonexistent='shift_forward').tz_convert("UTC")
## subset down to 30 minutes before/after the radar times we're plotting
tor_rpts = tor_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):((end+pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M"))]
tor_rpts
### repeat for hail
hail_rpts = pd.read_csv("https://www.spc.noaa.gov/wcm/data/"+str(start.year)+"_hail.csv")
hail_rpts['datetime'] = pd.to_datetime(hail_rpts.date + ' ' + hail_rpts.time) ## convert to datetime
hail_rpts.set_index("datetime",inplace=True)
### times in the file are given in central standard time (UTC+6). Localize, and convert to UTC
hail_rpts.index = hail_rpts.index.tz_localize("Etc/GMT+6",ambiguous='NaT',nonexistent='shift_forward').tz_convert("UTC")
## subset down to 30 minutes before/after the radar times we're plotting
hail_rpts = hail_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):((end+pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M"))]
hail_rpts
#%%
'''Now we plot the maps and animate'''
### loop over the radar images that have been downloaded
for i,scan in enumerate(results.iter_success(),start=1):
#for i in range(0,1):
## skip the files ending in "MDM"
if scan.filename[-3:] != "MDM":
print(str(i))
print("working on "+scan.filename)
this_time = pd.to_datetime(scan.filename[4:17], format="%Y%m%d_%H%M").tz_localize("UTC")
radar = scan.open_pyart()
#display = pyart.graph.RadarDisplay(radar)
fig = plt.figure(figsize=[15, 7])
map_panel_axes = [0.05, 0.05, .4, .80]
x_cut_panel_axes = [0.55, 0.10, .4, .25]
y_cut_panel_axes = [0.55, 0.50, .4, .25]
projection = ccrs.PlateCarree()
## apply gatefilter (see here: https://arm-doe.github.io/pyart/notebooks/masking_data_with_gatefilters.html)
#gatefilter = pyart.correct.moment_based_gate_filter(radar)
gatefilter = pyart.filters.GateFilter(radar)
# Lets remove reflectivity values below a threshold.
gatefilter.exclude_below('reflectivity', -2.5)
display = pyart.graph.RadarMapDisplay(radar)
### set up plot
ax1 = fig.add_axes(map_panel_axes, projection=projection)
# Add some various map elements to the plot to make it recognizable.
ax1.add_feature(USCOUNTIES.with_scale('500k'), edgecolor="gray", linewidth=0.4)
#ax1.coastlines('50m', edgecolor='black', linewidth=0.75)
ax1.add_feature(cfeature.STATES.with_scale('10m'), linewidth=1.0)
cf = display.plot_ppi_map('reflectivity', 0, vmin=-7.5, vmax=65,
min_lon=min_lon, max_lon=max_lon, min_lat=min_lat, max_lat=max_lat,
title=radar_id+" reflectivity and severe weather reports, "+this_time.strftime("%H%M UTC %d %b %Y"),
projection=projection, resolution='10m',
gatefilter=gatefilter,
cmap='pyart_HomeyerRainbow',
colorbar_flag=False,
lat_lines=[0,0], lon_lines=[0,0]) ## turns off lat/lon grid lines
#display.plot_crosshairs(lon=lon, lat=lat)
## plot horizontal colorbar
display.plot_colorbar(cf,orient='horizontal', pad=0.07)
# Plot range rings if desired
#display.plot_range_ring(25., color='gray', linestyle='dashed')
#display.plot_range_ring(50., color='gray', linestyle='dashed')
#display.plot_range_ring(100., color='gray', linestyle='dashed')
ax1.set_xticks(np.arange(min_lon, max_lon, .5), crs=ccrs.PlateCarree())
ax1.set_yticks(np.arange(min_lat, max_lat, .5), crs=ccrs.PlateCarree())
## add marker points for severe reports
wind_rpts_now = wind_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):this_time.strftime("%Y-%m-%d %H:%M")]
ax1.scatter(wind_rpts_now.slon.values.tolist(), wind_rpts_now.slat.values.tolist(), s=20, facecolors='none', edgecolors='mediumblue', linewidths=1.8)
tor_rpts_now = tor_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):this_time.strftime("%Y-%m-%d %H:%M")]
ax1.scatter(tor_rpts_now.slon.values.tolist(), tor_rpts_now.slat.values.tolist(), s=20, facecolors='red', edgecolors='black', marker="v",linewidths=1.5)
hail_rpts_now = hail_rpts[((start-pd.Timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M")):this_time.strftime("%Y-%m-%d %H:%M")]
ax1.scatter(hail_rpts_now.slon.values.tolist(), hail_rpts_now.slat.values.tolist(), s=20, facecolors='none', edgecolors='green', linewidths=1.8)
plt.savefig(scan.radar_id+"_"+scan.filename[4:17]+"_dz_rpts.png",bbox_inches='tight',dpi=300,
facecolor='white', transparent=False)
#plt.show()
plt.close('all')

How to Plot 2 Lines on Log X-axis in Python?

I'm trying to plot blackbody wavelength vs flux for 288 Kelvin temperature (the Earth) and 6000 Kelvin temperature (the sun). I want both of these to be on the same plot and know I will need a log x-axis but I keep having issues having both lines appear. This is the code I have so far:
# Import libraries
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
# Constants
c = 3.0e8 # m/s
h = 6.626e-34 # Js
k = 1.38e-23 # J/K
c1 = 2*np.pi*h*c**2
c2 = (h*c)/k
T1 = 6000
T2 = 288
lam = np.logspace(-8,-3,2000) # Generate x-axis values
F1 = c1/(lam**5*(np.exp(c2/(lam*T1))-1)) # Calculate y-values
F1 = F1/1e9
F2 = c1/(lam**5*(np.exp(c2/(lam*T2))-1)) # Calculate y-values
F2 = F2/1e9
# Create plot
ax = plt.gca()
plt.xlabel(r'$\lambda$ (nm)')
plt.ylabel(r'$F_{BB\lambda}(W\/m^{-2}nm^{-1})$')
plt.text(0.05,.8, 'T = {0:d}K'.format(T1), transform = ax.transAxes, size = 'small')
plt.text(0.05,.3, 'T = {0:d}K'.format(T2), transform = ax.transAxes, size = 'small')
plt.xticks(), plt.yticks()
plt.semilogx(lam*1e9, F1, lam*1e9, F2, color= 'black') # Create figure and axis objects
plt.xlim(10,1e6)
plt.ylim(0,)
plt.show() # Display plot to screen
This plots the attached picture which is correct for 6000K but for some reason it's not plotting the 288K curve and I'm not sure how to fix it.

How to plot 2 animated graphics with x axis fixed in matplotlib with subplots?

The purpose of the program: I need to plot a signal graphic on the top and a spectrum graphic of this signal on the bottom, only the y data in both cases varies.
I generate a sine wave with a random noise on the input and plot it on the top, that's working perfecly.
The problem is when I try to plot the spectrum graph. It's not updating for some reason and I didn't understand very well the functioning of matplotlib.animation.FuncAnimation.
The code:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
dt = 0.01
Fs = 44000.0 # sample rate
timestep = 1.0/Fs # sample spacing (1/sample rate)
t = np.arange(0, 10, dt) # t range
n = 256 # size of the array data
w = 10000 # frequency of the input
data = np.sin(2*np.pi*w*t)
def update(data):
# update the curves with the incoming data
line.set_ydata(data)
#line2.set_ydata(magnitude)
return line,
def generateData():
# simulate new data coming in
while True:
nse = np.random.randn(len(t))
r = np.exp(-t/0.05)
cnse = np.convolve(nse, r)*dt
cnse = cnse[:len(t)]
data = np.sin(2*np.pi*w*(t)) + cnse
magnitude = np.fft.fft(data)/n
magnitude = np.abs(magnitude[range(n//2)])
yield data
fig = plt.figure()
# plot time graph axis
timeGraph = plt.subplot(2, 1, 1)
timeGraph.set_ylim(-0.2, 0.2)
timeGraph.set_xlabel('Time')
timeGraph.set_ylabel('Amplitude')
# plot frequency graph axis
freqGraph = plt.subplot(2, 1, 2)
freqGraph.set_xlabel('Freq (Hz)')
freqGraph.set_ylabel('|Y(freq)|')
# get frequency range
n = len(data) # length of the signal
print(len(data))
k = np.arange(n)
T = n/Fs
freq = k/T # two sides frequency range
freq = freq[range(n//2)] # one side frequency range
# fft computing and normalization
magnitude = np.fft.fft(data)/n
magnitude = np.abs(magnitude[range(n//2)])
line, = timeGraph.plot(np.linspace(0, 1, len(t)), 'b')
line2, = freqGraph.plot(freq, magnitude, 'g')
# animate the curves
ani = animation.FuncAnimation(fig, update, generateData,
interval=10, blit=True)
plt.show() # open window
Bonus: how do I initialize data and magnitude correctly?
In order for both the time and frequency graph to update, you need to set the data from both to the respective plots in the update function. Of course you also need to provide this data in the generating function. So the generating function should yield (data, magnitude) and the updating function should accept this tuple as input.
It is also a good idea to set some limits for the frequency graph, freqGraph.set_ylim([0, 0.006]) such that it will not stay empty.
I do not know what you mean by how do i initialize data and magnitude correctly?. I think they are initialized correctly in the sense that they are calculated for every frame including the very first one.
Here is a working code.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
dt = 0.01
Fs = 44000.0 # sample rate
timestep = 1.0/Fs # sample spacing (1/sample rate)
t = np.arange(0, 10, dt) # t range
n = 256 # size of the array data
w = 10000 # frequency of the input
data = np.sin(2*np.pi*w*t)
def update(data):
# update the curves with the incoming data
line.set_ydata(data[0])
line2.set_ydata(data[1])
return line, line2,
def generateData():
# simulate new data coming in
while True:
nse = np.random.randn(len(t))
r = np.exp(-t/0.05)
cnse = np.convolve(nse, r)*dt
cnse = cnse[:len(t)]
data = np.sin(2*np.pi*w*(t)) + cnse
magnitude = np.fft.fft(data)/n
magnitude = np.abs(magnitude[range(n//2)])
yield (data, magnitude)
fig = plt.figure()
# plot time graph axis
timeGraph = plt.subplot(2, 1, 1)
timeGraph.set_ylim(-0.2, 0.2)
timeGraph.set_xlabel('Time')
timeGraph.set_ylabel('Amplitude')
# plot frequency graph axis
freqGraph = plt.subplot(2, 1, 2)
freqGraph.set_ylim([0, 0.006])
freqGraph.set_xlabel('Freq (Hz)')
freqGraph.set_ylabel('|Y(freq)|')
# get frequency range
n = len(data) # length of the signal
print(len(data))
k = np.arange(n)
T = n/Fs
freq = k/T # two sides frequency range
freq = freq[range(n//2)] # one side frequency range
# fft computing and normalization
magnitude = np.fft.fft(data)/n
magnitude = np.abs(magnitude[range(n//2)])
line, = timeGraph.plot(np.linspace(0, 1, len(t)),'b')
line2, = freqGraph.plot(freq, magnitude,'g')
# animate the curves
ani = animation.FuncAnimation(fig, update, generateData,
interval = 10, blit=True)
plt.show() # open window

How do I reliably scale matplotlib pcolormesh plots for large data sets?

I'm trying to plot some data using a pcolormesh from the matplotlib.pyplot but I'm having some difficulty when saving the output (specifically, in scaling the image appropriately).
I'm using Python v3.4 with matplotlib v1.51 if that makes a difference.
This is what my code currently looks like:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def GetData(data_entries, num_of_channels):
data_dict = {'timestamp' : np.linspace(1, data_entries*21, data_entries, endpoint=True)}
for chan in range(0, num_of_channels, 1):
data_dict['random%03d'%chan] = np.random.rand(data_entries, 1).flatten()
num_at_each_end_to_highlight = 10
data_dict['random%03d'%chan][0:num_at_each_end_to_highlight] = 1.5
data_dict['random%03d'%chan][-num_at_each_end_to_highlight:] = 1.5
for chan in range(0, num_of_channels, 1):
data_dict['periodic%03d' % chan] = np.zeros(data_entries)#.flatten()
data_dict['periodic%03d' % chan][::65] = 5000
return pd.DataFrame(data_dict)
def GetSubPlotIndex(totalRows, totalCols, row):
return totalRows*100+totalCols*10+row
def PlotData(df, num_of_channels, field_names):
# Calculate the range of data to plot
data_entries = len(df.index)
# Create the x/y mesh that the data will be plotted on
x = df['timestamp']
y = np.linspace(0, num_of_channels - 1, num_of_channels)
X,Y = np.meshgrid(x,y)
# Iterate through all of the field types and produce one plot for each but share the X axis
for idx, field_name in enumerate(field_names):
# Create this sub-plot
subPlotIndex = GetSubPlotIndex(len(field_names), 1, idx + 1)
ax = plt.subplot(subPlotIndex)
if idx is 0:
ax.set_title('Raw Data Time Series')
# Set the axis scale to exactly meet the limits of the data set.
ax.set_autoscale_on(False)
plt.axis([x[0], x[data_entries-1], 0, num_of_channels - 1])
# Set up the colour palette used to render the data.
# Make bad results (those that are masked) invisible so the background shows instead.
palette = plt.cm.get_cmap('autumn')
palette.set_bad(alpha=0.0)
ax.set_axis_bgcolor('black') # Set the background to zero
# Grab the data and transpose it so we can stick it in the time series running along the X axis.
firstFftCol = df.columns.get_loc(field_name + "%03d"%(0))
lastFftCol = df.columns.get_loc(field_name + "%03d"%(num_of_channels - 1))
data = df.ix[:,firstFftCol:lastFftCol]
data = data.T # Transpose so that time runs along the X axis and bin index is on the Y
# Mask off data with zero's so that it doesn't obscure the data we're actually interested in.
data = np.ma.masked_where(data == 0.0, data)
# Actually create the data mesh so we can plot it
z_min, z_max = data.min().min(), data.max().max()
p = ax.pcolormesh(X,Y, data, cmap=palette, vmin=z_min, vmax=z_max)
# Render it
plt.plot()
# Label the plot and add a key
plt.ylabel(field_name)
plt.colorbar(p)
# Label the plot
plt.xlabel('Time (ms)')
# Record the result
plt.savefig('test.png', edgecolor='none', transparent=False)
if __name__ == '__main__':
data_entries = 30000 # Large values here cause issues
num_of_channels = 255
fields_to_plot = ('random', 'periodic')
data = GetData(data_entries, num_of_channels)
width_in_pixels = len(data.index)+200
additional_vertical_space_per_plot = 50
num_of_plots = len(fields_to_plot)
height_in_pixels = (num_of_channels+additional_vertical_space_per_plot)*num_of_plots
dpi = 80 # The default according to the documentation.
fig = plt.figure(1,figsize=(width_in_pixels/dpi, height_in_pixels/dpi), dpi=dpi)
PlotData(data, num_of_channels, fields_to_plot)
With 1000 entries, the result looks fine:
If I increase the number of samples to the sort of size I actually want to plot (30000), the image is the correct size (30200 pixels wide) but I see a lot of dead space. This is a zoomed-out summary of the issues I see:
Is there a way to more accurately fill the image with the data?
Thanks to the prompt from #Dusch, this seems to solve things rather neatly:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def GetData(data_entries, num_of_channels):
data_dict = {'timestamp' : np.linspace(1, data_entries*21, data_entries, endpoint=True)}
for chan in range(0, num_of_channels, 1):
data_dict['random%03d'%chan] = np.random.rand(data_entries, 1).flatten()
num_at_each_end_to_highlight = 10
data_dict['random%03d'%chan][0:num_at_each_end_to_highlight] = 1.5
data_dict['random%03d'%chan][-num_at_each_end_to_highlight:] = 1.5
for chan in range(0, num_of_channels, 1):
data_dict['periodic%03d' % chan] = np.zeros(data_entries)#.flatten()
data_dict['periodic%03d' % chan][::65] = 5000
return pd.DataFrame(data_dict)
def GetSubPlotIndex(totalRows, totalCols, row):
return totalRows*100+totalCols*10+row
def PlotData(df, num_of_channels, field_names):
# Calculate the range of data to plot
data_entries = len(df.index)
# Create the x/y mesh that the data will be plotted on
x = df['timestamp']
y = np.linspace(0, num_of_channels - 1, num_of_channels)
X,Y = np.meshgrid(x,y)
# Iterate through all of the field types and produce one plot for each but share the X axis
for idx, field_name in enumerate(field_names):
# Create this sub-plot
subPlotIndex = GetSubPlotIndex(len(field_names), 1, idx + 1)
ax = plt.subplot(subPlotIndex)
if idx is 0:
ax.set_title('Raw Data Time Series')
# Set the axis scale to exactly meet the limits of the data set.
ax.set_autoscale_on(False)
plt.axis([x[0], x[data_entries-1], 0, num_of_channels - 1])
# Set up the colour palette used to render the data.
# Make bad results (those that are masked) invisible so the background shows instead.
palette = plt.cm.get_cmap('autumn')
palette.set_bad(alpha=0.0)
ax.set_axis_bgcolor('black') # Set the background to zero
# Grab the data and transpose it so we can stick it in the time series running along the X axis.
firstFftCol = df.columns.get_loc(field_name + "%03d"%(0))
lastFftCol = df.columns.get_loc(field_name + "%03d"%(num_of_channels - 1))
data = df.ix[:,firstFftCol:lastFftCol]
data = data.T # Transpose so that time runs along the X axis and bin index is on the Y
# Mask off data with zero's so that it doesn't obscure the data we're actually interested in.
data = np.ma.masked_where(data == 0.0, data)
# Actually create the data mesh so we can plot it
z_min, z_max = data.min().min(), data.max().max()
p = ax.pcolormesh(X,Y, data, cmap=palette, vmin=z_min, vmax=z_max)
# Render it
plt.plot()
# Label this sub-plot
plt.ylabel(field_name)
# Sort out the color bar
fig = plt.gcf()
image_width = fig.get_size_inches()[0] * fig.dpi # size in pixels
colorbar_padding_width_in_pixels = 20
colorbar_padding = colorbar_padding_width_in_pixels/image_width
plt.colorbar(p, pad=colorbar_padding)
# Label the plot
plt.xlabel('Time (ms)')
# Record the result
plt.savefig('test.png', edgecolor='none', transparent=False, bbox_inches='tight')
plt.tight_layout()
if __name__ == '__main__':
data_entries = 30000 # Large values here cause issues
num_of_channels = 255
fields_to_plot = ('random', 'periodic')
data = GetData(data_entries, num_of_channels)
width_in_pixels = len(data.index)+200
additional_vertical_space_per_plot = 50
num_of_plots = len(fields_to_plot)
height_in_pixels = (num_of_channels+additional_vertical_space_per_plot)*num_of_plots
dpi = 80 # The default according to the documentation.
fig = plt.figure(1,figsize=(width_in_pixels/dpi, height_in_pixels/dpi), dpi=dpi)
PlotData(data, num_of_channels, fields_to_plot)
The secret sauce in the end was:
Add plt.tight_layout() immediately before the plt.savefig call.
Add bbox_inches='tight' to the plt.savefig call.
Add , pad=colorbar_padding after calculating colorbar_padding by checking what proportion of the overall image width a 20 pixel padding equates to.

matplotlib - clip image using line(s)

Is it possible to clip an image generated by imshow() to the area under a line/multiple lines? I think Clip an image using several patches in matplotlib may have the solution, but I'm not sure how to apply it here.
I just want the coloring (from imshow()) under the lines in this plot:
Here is my plotting code:
from __future__ import division
from matplotlib.pyplot import *
from numpy import *
# wavelength array
lambd = logspace(-3.8, -7.2, 1000)
# temperatures
T_earth = 300
T_sun = 6000
# planck's law constants
h = 6.626069e-34
c = 2.997925e8
k = 1.380648e-23
# compute power using planck's law
power_earth = 2*h*c**2/lambd**5 * 1/(exp(h*c/(lambd*k*T_earth)) - 1)
power_sun = 2*h*c**2/lambd**5 * 1/(exp(h*c/(lambd*k*T_sun)) - 1)
# set up color array based on "spectrum" colormap
colors = zeros((1000,1000))
colors[:,:1000-764] = 0.03
for x,i in enumerate(range(701,765)):
colors[:,1000-i] = 1-x/(765-701)
colors[:,1000-701:] = 0.98
figure(1,(4,3),dpi=100)
# plot normalized planck's law graphs
semilogx(lambd, power_earth/max(power_earth), 'b-', lw=4, zorder=5); hold(True)
semilogx(lambd, power_sun/max(power_sun), 'r-', lw=4, zorder=5); hold(True)
# remove ticks (for now)
yticks([]); xticks([])
# set axis to contain lines nicely
axis([min(lambd), max(lambd), 0, 1.1])
# plot colors, shift extent to match graph
imshow(colors, cmap="spectral", extent=[min(lambd), max(lambd), 0, 1.1])
# reverse x-axis (longer wavelengths to the left)
ax = gca(); ax.set_xlim(ax.get_xlim()[::-1])
tight_layout()
show()
What you can do in this case is using the area under the curve as a Patch to apply set_clip_path. All you have to do is call fill_between and extract the corresponding path, like this:
semilogx(lambd, power_earth/max(power_earth), 'b-', lw=4, zorder=5)
# Area under the curve
fillb_earth = fill_between(lambd, power_earth/max(power_earth), color='none', lw=0)
# Get the path
path_earth, = fillb_earth.get_paths()
# Create a Patch
mask_earth = PathPatch(path_earth, fc='none')
# Add it to the current axes
gca().add_patch(mask_earth)
# Add the image
im_earth = imshow(colors, cmap="spectral", extent=[min(lambd), max(lambd), 0, 1.1])
# Clip the image with the Patch
im_earth.set_clip_path(mask_earth)
And then repeat the same lines for the Sun. Here is the result.

Categories

Resources