3D plots and legends issue when plotting some dimensions of PCA - python

I have the following data and labels I am transforming through PCA.
The labels are only 0 or 1.
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import numpy as np
fields = ["Occupancy", "Temperature", "Humidity", "Light", "CO2", "HumidityRatio", "NSM", "WeekStatus"]
df = pd.read_csv('datatraining-updated.csv', skipinitialspace=True, usecols=fields, sep=',')
#Get the output from pandas as a numpy matrix
final_data=df.values
#Data
X = final_data[:,1:8]
#Labels
y = final_data[:,0]
#Normalize features
X_scaled = StandardScaler().fit_transform(X)
#Apply PCA on them
pca = PCA(n_components=7).fit(X_scaled)
#Transform them with PCA
X_reduced = pca.transform(X_scaled)
Then I want to show, in a 3D graph, the three PCA components with the highest variance. I can find them as follows:
#Show variable importance
# explained_variance_ratio_ holds one entry per fitted principal component
importance = pca.explained_variance_ratio_
# The line break inside the message is written as the escape '\n'; a literal
# line break inside a single-quoted string is a SyntaxError.
print('Explained variation per principal component:\n{}'.format(importance))
After that, I want to plot only the three highest-variance components, so I first select them with the code below:
X_reduced=X_reduced[:, [0, 4, 5]]
Ok, here is my problem: I can plot them without the legend. When I try to plot them using the following code
# Create plot
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax = fig.gca(projection='3d')
colors = ("red", "gray")
for data, color, group in zip(X_reduced, colors, y):
dim1,dim2,dim3=data
ax.scatter(dim1, dim2, dim3, c=color, edgecolors='none',
label=group)
plt.title('Matplot 3d scatter plot')
plt.legend(y)
plt.show()
I get the following error:
plot_data-3d-pca.py:56: UserWarning: Requested projection is different from current axis projection, creating new axis with requested projection.
ax = fig.gca(projection='3d')
plot_data-3d-pca.py:56: MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance. In a future version, a new instance will always be created and returned. Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.
ax = fig.gca(projection='3d')
Traceback (most recent call last):
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/backends/backend_gtk3.py", line 307, in idle_draw
self.draw()
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/backends/backend_gtk3agg.py", line 76, in draw
self._render_figure(allocation.width, allocation.height)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/backends/backend_gtk3agg.py", line 20, in _render_figure
backend_agg.FigureCanvasAgg.draw(self)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/backends/backend_agg.py", line 388, in draw
self.figure.draw(self.renderer)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/artist.py", line 38, in draw_wrapper
return draw(artist, renderer, *args, **kwargs)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/figure.py", line 1709, in draw
renderer, self, artists, self.suppressComposite)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/image.py", line 135, in _draw_list_compositing_images
a.draw(renderer)
File "/home/unica-server/.local/lib/python3.6/site-packages/matplotlib/artist.py", line 38, in draw_wrapper
return draw(artist, renderer, *args, **kwargs)
File "/home/unica-server/.local/lib/python3.6/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 292, in draw
reverse=True)):
File "/home/unica-server/.local/lib/python3.6/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 291, in <lambda>
key=lambda col: col.do_3d_projection(renderer),
File "/home/unica-server/.local/lib/python3.6/site-packages/mpl_toolkits/mplot3d/art3d.py", line 545, in do_3d_projection
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
File "/home/unica-server/.local/lib/python3.6/site-packages/mpl_toolkits/mplot3d/art3d.py", line 847, in _zalpha
rgba = np.broadcast_to(mcolors.to_rgba_array(colors), (len(zs), 4))
File "<__array_function__ internals>", line 6, in broadcast_to
File "/home/unica-server/.local/lib/python3.6/site-packages/numpy/lib/stride_tricks.py", line 182, in broadcast_to
return _broadcast_to(array, shape, subok=subok, readonly=True)
File "/home/unica-server/.local/lib/python3.6/site-packages/numpy/lib/stride_tricks.py", line 127, in _broadcast_to
op_flags=['readonly'], itershape=shape, order='C')
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (0,4) and requested shape (1,4)
My y's shape is (8143,) and my X_reduced's shape is (8143,3)
What is my mistake?
EDIT: The data I am using can be found here

The first warning, "Requested projection is different from current axis projection", appears because ax = fig.gca(projection='3d') tries to change the projection of an axis after it has been created, which is not possible. Set the projection when the axis is created instead.
To fix the second error, replace edgecolors='none' by edgecolors=None.
The following corrected code works for me.
# Create plot
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')  # set projection at creation of axis
# ax = fig.gca(projection='3d')  # you cannot change the projection after creation
colors = ("red", "gray")
# Draw one scatter per class. Zipping the rows of X_reduced with the two
# colours would stop after two rows, because zip ends with its shortest input.
for group, color in zip(np.unique(y), colors):
    # rows of this class, transposed so the three columns unpack as x, y, z
    dim1, dim2, dim3 = X_reduced[y == group].T
    # replace 'none' by None
    ax.scatter(dim1, dim2, dim3, c=color, edgecolors=None, label=group)
plt.title('Matplot 3d scatter plot')
# the legend entries come from the label= given to each scatter call
plt.legend()
plt.show()
EDIT : Above is my answer to what I understood of the original question. Below is a looped version of mad's own answer.
class_values = [0, 1]
labels = ['Empty', 'Full']
n_class = len(class_values)
# row indices of the samples carrying each label (the 0s and the 1s)
index_class = [np.where(np.isin(y, class_i)) for class_i in class_values]
# PCA-transformed rows for each label
X_reduced_class = [X_reduced[index] for index in index_class]
colors = ['blue', 'red']
# To get a better understanding of the interaction of the dimensions,
# plot the three selected PCA dimensions
fig = plt.figure(1, figsize=(8, 6))
# Create the 3D axes through the figure: an Axes3D built directly with
# Axes3D(fig, ...) is no longer added to the figure by recent matplotlib.
ax = fig.add_subplot(111, projection='3d', elev=-150, azim=110)
# NOTE(review): these column indices assume X_reduced still holds all 7
# components; if it was already cut down to 3 columns, use [0, 1, 2].
ids_plot = [0, 4, 5]
for i in range(n_class):
    # get the three interesting columns
    data = X_reduced_class[i][:, ids_plot]
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors[i], edgecolor='k', s=40, label=labels[i])
ax.set_title("Data Visualization with 3 highest variance dimensions with PCA")
# ax.xaxis / ax.yaxis / ax.zaxis replace the removed w_xaxis / w_yaxis / w_zaxis aliases
ax.set_xlabel("1st eigenvector")
ax.xaxis.set_ticklabels([])
ax.set_ylabel("2nd eigenvector")
ax.yaxis.set_ticklabels([])
ax.set_zlabel("3rd eigenvector")
ax.zaxis.set_ticklabels([])
ax.legend()
plt.show()

I solved the error in a different way.
I did not know that, for each label, I had to do a different scatterplot. Thanks to this post I found the answer.
My solution was first to separate the labels and data from one class, and then do the same for the other class. Finally, I plot them separately with different scatterplots. So, firstly I identify the different labels (I have only two labels, 0 or 1) and their data (their corresponding Xs).
#Get where are the 0s and 1s labels
index_class1 = np.where(np.isin(y, 0))
index_class2 = np.where(np.isin(y, 1))
#Get reduced PCA for each label
X_reducedclass1=X_reduced[index_class1][:]
X_reducedclass2=X_reduced[index_class2][:]
Then, I will plot each PCA reduced vectors from each class in different scatterplots
colors = ['blue', 'red']
# To get a better understanding of the interaction of the dimensions,
# plot the three selected PCA dimensions
fig = plt.figure(1, figsize=(8, 6))
# Create the 3D axes through the figure: an Axes3D built directly with
# Axes3D(fig, ...) is no longer added to the figure by recent matplotlib.
ax = fig.add_subplot(111, projection='3d', elev=-150, azim=110)
# One scatter per class. cmap is not passed: it has no effect when c is a
# single named colour.
scatter1 = ax.scatter(X_reducedclass1[:, 0], X_reducedclass1[:, 4], X_reducedclass1[:, 5], c=colors[0], edgecolor='k', s=40)
scatter2 = ax.scatter(X_reducedclass2[:, 0], X_reducedclass2[:, 4], X_reducedclass2[:, 5], c=colors[1], edgecolor='k', s=40)
ax.set_title("Data Visualization with 3 highest variance dimensions with PCA")
# ax.xaxis / ax.yaxis / ax.zaxis replace the removed w_xaxis / w_yaxis / w_zaxis aliases
ax.set_xlabel("1st eigenvector")
ax.xaxis.set_ticklabels([])
ax.set_ylabel("2nd eigenvector")
ax.yaxis.set_ticklabels([])
ax.set_zlabel("3rd eigenvector")
ax.zaxis.set_ticklabels([])
# explicit handles and names, so each scatter gets its own legend entry
ax.legend([scatter1, scatter2], ['Empty', 'Full'], loc="upper right")
plt.show()
Which gives me this beautiful image
Of course, this code could be simplified with a for loop too (although I have no idea how to do that).

Related

Python - Unable to use plot_trisurf to plot a 2D array in Matplotlib

I have a 2D array which I am trying to plot using plot_trisurf and I can't seem to make it work no matter what I try.
Here follows a minimally reproducible example where I am able to use plot_surface
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
frames = []
for _ in range(5):
matrix = np.array([
[np.random.normal(20, 3) for _ in range(6)]
for _ in range(6)
],
dtype = np.float32
)
frames.append(matrix)
fig= plt.figure(figsize = [6, 5], dpi = 100)
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.set_zlim(10, 30)
data = frames[0]
X, Y = np.meshgrid(
range(len(frames[0])),
range(len(frames[0][0]))
)
ax.plot_surface(X, Y, data, cmap=cm.Spectral, linewidth=1)
plt.show()
yielding
But when I try
ax.plot_trisurf(X, Y, data, cmap=cm.Spectral, linewidth=1)
I get an error that X and Y must be equal-length 1D arrays, so then I tried
ax.plot_trisurf(X.flatten(), Y.flatten(), data, cmap=cm.Spectral, linewidth=1)
But then I get the stacktrace
Command failed: python3 main.py
Traceback (most recent call last):
File "main.py", line 25, in
ax.plot_trisurf(X.flatten(), Y.flatten(), data, cmap=cm.Spectral, linewidth=1)
File "/usr/local/lib64/python3.8/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 2051, in plot_trisurf
zt = z[triangles]
IndexError: index 6 is out of bounds for axis 0 with size 6
Does anyone know what I am doing wrong? I have tried many things and looked at the documentation, but can't seem to understand what I am missing.

How do I colour a line chart based on Log10 of a column in my data?

According to a problem statement I have, I need to create a line chart from a dataset with 3 columns. But the colour needs to be log10(col_3). I'm attempting to do this via matplotlib:
data.plot(x = 'col_1', y = 'col_2')
I have no idea how I could convert a column to colormap values using log10.
Edit: Based on the suggestion by @JohanC in the comments,
This example shows how multicolored line plots can be created with matplotlib.
I tried out the following code:-
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
x = [20160101, 20160102, 20160103, 20160104, 20160105, 20160106, 20160107, 20160108]
y = [66.2, 71.76, 68.89, 64.63, 63.89, 67.09, 65.67, 65.42]
z = np.log10([710, 850, 900, 820, 720, 790, 670, 1070])
frame = pd.DataFrame({'col_1' : x,
'col_2' : y,
'col_3' : z})
points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis = 1)
fig, axs = plt.subplots(2, 1, sharex = True, sharey = True)
norm = plt.Normalize(z.min(), z.max())
lc = LineCollection(segments, cmap = 'viridis', norm = norm)
lc.set_array(z)
line = axs[0].add_collection(lc)
fig.colorbar(line, ax = axs[0])
plt.show()
I still get nothing but a blank graph, with a colourbar. If I replace the x,y,z values with what's provided in the matplotlib documentation example, I get the graph to show.
I'm really failing to understand what I'm doing wrong.

Matplotlib - How to set colorbar for line plot with log scale

I'm having a problem adding a colorbar to a plot of many lines corresponding to a power-law.
To create the color-bar for a non-image plot, I added a dummy plot (from answers here: Matplotlib - add colorbar to a sequence of line plots).
The colorbar ticks do not correspond to the colors of the plot.
I have tried changing the norm of the colorbar, and I can fine-tune it to be roughly accurate for a particular case, but I can't do that in general.
def plot_loglog_gauss():
from matplotlib import cm as color_map
import matplotlib as mpl
"""Creating the data"""
time_vector = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256]
amplitudes = [t ** 2 * np.exp(-t * np.power(np.linspace(-0.5, 0.5, 100), 2)) for t in time_vector]
"""Getting the non-zero minimum of the data"""
data = np.concatenate(amplitudes).ravel()
data_min = np.min(data[np.nonzero(data)])
"""Creating K-space data"""
k_vector = np.linspace(0,1,100)
"""Plotting"""
number_of_plots = len(time_vector)
color_map_name = 'jet'
my_map = color_map.get_cmap(color_map_name)
colors = my_map(np.linspace(0, 1, number_of_plots, endpoint=True))
# plt.figure()
# dummy_plot = plt.contourf([[0, 0], [0, 0]], time_vector, cmap=my_map)
# plt.clf()
norm = mpl.colors.Normalize(vmin=time_vector[0], vmax=time_vector[-1])
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=color_map_name)
cmap.set_array([])
for i in range(number_of_plots):
plt.plot(k_vector, amplitudes[i], color=colors[i], label=time_vector[i])
c = np.arange(1, number_of_plots + 1)
plt.xlabel('Frequency')
plt.ylabel('Amplitude')
plt.yscale('symlog', linthreshy=data_min)
plt.xscale('log')
plt.legend(loc=3)
ticks = time_vector
plt.colorbar(cmap, ticks=ticks, shrink=1.0, fraction=0.1, pad=0)
plt.show()
By comparing with the legend you see the ticks values don't match the actual colors. For example, 128 is shown in green in the colormap while red in the legend.
The desired result is a linear-color colorbar with ticks at regular intervals on the colorbar (corresponding to irregular time intervals), and of course the correct color for each tick value.
(Eventually the plot contains many plots (len(time_vector) ~ 100), I lowered the number of plots to illustrate and to be able to show the legend.)
To clarify, this is what I want the result to look like.
The most important principle is to keep the colors from the line plots and the ScalarMappable in sync. This means, the color of the line should not be taken from an independent list of colors, but rather from the same colormap and using the same normalization as the colorbar to be shown.
One major problem is then to decide what to do with 0, which cannot be part of a logarithmic normalization. The following is a workaround assuming a linear scale between 0 and 2, and a log scale above, using a SymLogNorm.
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
"""Creating the data"""
time_vector = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256]
amplitudes = [t ** 2 * np.exp(-t * np.power(np.linspace(-0.5, 0.5, 100), 2)) for t in time_vector]
"""Getting the non-zero minimum of the data"""
data = np.concatenate(amplitudes).ravel()
data_min = np.min(data[np.nonzero(data)])
"""Creating K-space data"""
k_vector = np.linspace(0,1,100)
"""Plotting"""
cmap = plt.cm.get_cmap("jet")
norm = mpl.colors.SymLogNorm(2, vmin=time_vector[0], vmax=time_vector[-1])
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
for i in range(len(time_vector)):
plt.plot(k_vector, amplitudes[i], color=cmap(norm(time_vector[i])), label=time_vector[i])
#c = np.arange(1, number_of_plots + 1)
plt.xlabel('Frequency')
plt.ylabel('Amplitude')
plt.yscale('symlog', linthreshy=data_min)
plt.xscale('log')
plt.legend(loc=3)
cbar = plt.colorbar(sm, ticks=time_vector, format=mpl.ticker.ScalarFormatter(),
shrink=1.0, fraction=0.1, pad=0)
plt.show()

Plot data from large sparse numpy.ndarray in 3D bar plot with different array length

My question is similar to this question: I want to plot my spectroscopic data in a 3D plot, but
1) My data is matrix in np.ndarray
2) It has a large dimension of 1201*5001 (result.shape = (1201,5001)), therefore the hard code labeling manually is not suitable.
3) The data is not continuous and sparse. The final plot may look like mplot3d bar3d.
Can I use 3D bar plot from Matplotlib in this case? If possible, how to define the different length for every axes?
This is my code in progress (3rd update)
if __name__ == '__main__':
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from array, x is time, y is mz, z is intensity
# in graph x is mz, y is time, z is intensity
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
zs = np.arange(0, 50.01, 0.01)
for z in zs:
xs = np.arange(300, 1500.01, 1)
ys = result
ax.bar(xs,ys,zs=z,zdir='y')
plt.show()
Error (3rd)
Traceback (most recent call last):
File "prelimnmf_importcsv3.py", line 70, in <module>
ax.bar(xs,ys,zs=z,zdir='y')
File "/Users/pp/anaconda/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 2394, in bar
patches = Axes.bar(self, left, height, *args, **kwargs)
File "/Users/pp/anaconda/lib/python2.7/site-packages/matplotlib/__init__.py", line 1892, in inner
return func(ax, *args, **kwargs)
File "/Users/pp/anaconda/lib/python2.7/site-packages/matplotlib/axes/_axes.py", line 2115, in bar
if h < 0:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
While I doubt that a bar plot with 1200*5000 can give any visual insight into the data, it should still be possible to use it.
So here is an example
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)
#Assume you have arrays like this
x = np.arange(300,1500,100)
y = np.arange(4)*10
# one row of bar heights per entry of y, one column per entry of x
Z = np.random.rand(len(y), len(x))*33
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# rows are drawn from the last y value down to the first
for i in reversed(range(len(y))):
    # one colormap colour per row; it has to be passed to bar() to take effect
    c = plt.cm.jet(i/float(len(y)))
    ax.bar(x, Z[i,:], zs=y[i], zdir='y', width=80, color=c, alpha=1)
ax.set_xlabel('time')
ax.set_ylabel('mz')
ax.set_zlabel('intensity')
plt.show()

error plotting with slider (python matplotlib)

I searched the internet for how to use a slider with 3D data and found this algorithm, which plots 3D data in 2D with a slider. I copy-pasted it and tried to run it in order to adapt it (my real problem is plotting 3D+time data and using a slider to interact with the time).
This is my complete code :
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
import scipy.ndimage as ndi
data = np.zeros((10, 10, 10))
data[5, 5, 5] = 10.
data = ndi.filters.gaussian_filter(data, sigma=1)
print(data.max())
def cube_show_slider(cube, axis=0, **kwargs):
"""
Display a 3d ndarray with a slider to move along the third dimension.
Extra keyword arguments are passed to imshow
"""
# check dim
if not cube.ndim == 3:
raise ValueError("cube should be an ndarray with ndim == 3")
# generate figure
fig = plt.figure()
ax = plt.subplot(111)
fig.subplots_adjust(left=0.25, bottom=0.25)
# select first image
s = [slice(0, 1) if i == axis else slice(None) for i in range(3)]
im = cube[s].squeeze()
# display image
l = ax.matshow(im, **kwargs)
cb = plt.colorbar(l)
cb.set_clim(vmin=data.min(), vmax=data.max())
cb.draw_all()
# define slider
axcolor = 'lightgoldenrodyellow'
ax = fig.add_axes([0.25, 0.1, 0.65, 0.03], axisbg=axcolor)
slideryo = Slider(ax, 'Axis %i index' % axis, 0, cube.shape[axis] - 1, valinit=0, valfmt='%i')
slideryo.on_changed(update)
plt.show()
def update(val):
ind = int(slider.val)
s = [slice(ind, ind + 1) if i == axis else slice(None) for i in range(3)]
im = cube[s].squeeze()
l.set_data(im, **kwargs)
cb.set_clim(vmin=data.min(), vmax=data.max())
cb.formatter.set_powerlimits((0, 0))
cb.update_ticks()
cb.draw_all()
fig.canvas.draw()
cube_show_slider(data)
A window with the axis and the slider are on my screen but no data is plotted. The plot is just a big blue square and when I interact with the slider I have this error :
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/backend_bases.py", line 1952, in motion_notify_event
self.callbacks.process(s, event)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/cbook.py", line 563, in process
proxy(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/cbook.py", line 430, in __call__
return mtd(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/widgets.py", line 434, in _update
self.set_val(val)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/widgets.py", line 448, in set_val
func(val)
File "<stdin>", line 2, in update
NameError: global name 'slider' is not defined
I don't understand why it doesn't work. All the functions and files that the console cite were added by the importation. And I know that the code written by mmensing is ok, so I missed something but what? I'm sure that I did a stupid error, but I don't know where.
To check if the data I created are ok, I write this code to see the 3d plot in 3D without slider :
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
mpl.rcParams['legend.fontsize'] = 10
fig = plt.figure()
ax = fig.gca(projection='3d')
data = np.zeros((10, 10, 10))
data[5, 5, 5] = 10.
data = ndi.filters.gaussian_filter(data, sigma=1)
ax.plot(data[0,:,:], data[1,:,:], data[2,:,:], label='my data')
ax.legend()
plt.show()
But it returns this error :
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 1541, in plot
lines = Axes.plot(self, xs, ys, *args[argsi:], **kwargs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/__init__.py", line 1812, in inner
return func(ax, *args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/axes/_axes.py", line 1424, in plot
for line in self._get_lines(*args, **kwargs):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/axes/_base.py", line 386, in _grab_next_args
for seg in self._plot_args(remaining, kwargs):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/axes/_base.py", line 339, in _plot_args
raise ValueError('third arg must be a format string')
ValueError: third arg must be a format string
/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/matplotlib/axes/_axes.py:519: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.
warnings.warn("No labelled objects found. ")
What can I do ?
I have corrected your code. It had a few errors, which you can find by comparing the two versions:
- The update function must be defined inside cube_show_slider (the indentation was wrong) so that it can access that function's local variables.
- Your slider had two different names in different places (slideryo and slider).
- Refer to the update function only after it has been defined.
Hope it works now.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
import scipy.ndimage as ndi
data = np.zeros((10, 10, 10))
data[5, 5, 5] = 10.
data = ndi.filters.gaussian_filter(data, sigma=1)
print(data.max())
print data.shape
def cube_show_slider(cube, axis=0, **kwargs):
    """
    Display a 3d ndarray as 2d slices, with a slider to move along `axis`.

    Parameters:
        cube: array with ndim == 3.
        axis: index (0, 1 or 2) of the dimension the slider moves along.
        **kwargs: extra keyword arguments, passed to ax.matshow.

    Raises:
        ValueError: if cube.ndim is not 3.

    Calls plt.show() before returning.
    """
    # check dim
    if cube.ndim != 3:
        raise ValueError("cube should be an ndarray with ndim == 3")

    def take_slice(ind):
        # Index with a tuple: indexing with a list of slices is deprecated
        # and rejected by recent NumPy.
        s = tuple(slice(ind, ind + 1) if i == axis else slice(None)
                  for i in range(3))
        return cube[s].squeeze()

    # generate figure
    fig = plt.figure()
    ax = plt.subplot(111)
    fig.subplots_adjust(left=0.25, bottom=0.25)

    # display first image
    l = ax.matshow(take_slice(0), **kwargs)
    # Take the colour limits from the whole cube (not from a global array) and
    # set them on the image, so they stay the same for every slice shown.
    l.set_clim(vmin=cube.min(), vmax=cube.max())
    cb = plt.colorbar(l)

    # define slider
    axcolor = 'lightgoldenrodyellow'
    # 'facecolor' is the current name of the former 'axisbg' argument
    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
    slideryo = Slider(slider_ax, 'Axis %i index' % axis, 0,
                      cube.shape[axis] - 1, valinit=0, valfmt='%i')

    # update is nested so that it can see slideryo, l, cb and fig
    def update(val):
        ind = int(slideryo.val)
        # set_data takes only the new image; the display options were
        # already applied when the image was created
        l.set_data(take_slice(ind))
        cb.formatter.set_powerlimits((0, 0))
        cb.update_ticks()
        fig.canvas.draw_idle()

    # register the callback only after update has been defined
    slideryo.on_changed(update)
    plt.show()
cube_show_slider(data)

Categories

Resources