Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use subfigures and subplots for layout #67

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 68 additions & 91 deletions mpl_animators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None,
self.button_labels = button_labels or []
self.num_buttons = len(self.button_func)

if not fig:
fig = plt.figure()
self.fig = fig

self.data = data
self.interval = interval
self.if_colorbar = colorbar
self.imshow_kwargs = kwargs

if len(slider_functions) != len(slider_ranges):
raise ValueError("slider_functions and slider_ranges must be the same length.")

Expand All @@ -111,6 +102,18 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None,
self.slider_ranges = slider_ranges
self.slider_labels = slider_labels or [''] * len(slider_functions)

if not fig:
fig = plt.figure(layout="constrained")
self.parent_fig = fig
self.subfigs = None
self._setup_subfigure_grid()
self.fig = self.subfigs[0]

self.data = data
self.interval = interval
self.if_colorbar = colorbar
self.imshow_kwargs = kwargs

# Set active slider
self.active_slider = 0

Expand All @@ -127,8 +130,8 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None,
#
# Only do this if figure has a manager, so directly constructed figures
# (ie. via matplotlib.figure.Figure()) work.
if hasattr(self.fig.canvas, "manager") and self.fig.canvas.manager is not None:
plt.sca(self.axes)
# if isinstance(self.axes, mpl.axes.Axes) and hasattr(self.fig.canvas, "manager") and self.fig.canvas.manager is not None:
# plt.sca(self.axes)

# Do Plot
self.im = self.plot_start_image(self.axes)
Expand Down Expand Up @@ -266,108 +269,82 @@ def _dehighlight_slider(self, ind):
# =============================================================================
# Build the figure and place the widgets
# =============================================================================
def _setup_subfigure_grid(self):
slider_height = 0.075 * self.num_sliders
button_height = 0.075 if self.num_buttons else 0
main_figure_height = 1.0 - slider_height - button_height

self.subfigs = self.parent_fig.subfigures(
nrows=3,
height_ratios=[main_figure_height, button_height, slider_height]
)

def _setup_main_axes(self):
"""
Allow replacement of main axes by subclassing.
This method must set the ``axes`` attribute.
"""
if self.axes is None:
self.axes = self.fig.add_subplot(111)
self.axes = self.subfigs[0].add_subplot(111)

def _make_axes_grid(self):
self._setup_main_axes()

# Split up the current axes so there is space for start & stop buttons
self.divider = make_axes_locatable(self.axes)
pad = 0.01 # Padding between axes
pad_size = Size.Fraction(pad, Size.AxesX(self.axes))
large_pad_size = Size.Fraction(0.1, Size.AxesY(self.axes))

button_grid = max((7, self.num_buttons))

# Define size of useful axes cells, 50% each in x 20% for buttons in y.
ysize = Size.Fraction((1.-2.*pad)/15., Size.AxesY(self.axes))
xsize = Size.Fraction((1.-2.*pad)/button_grid, Size.AxesX(self.axes))

# Set up grid, 3x3 with cells for padding.
if self.num_buttons > 0:
horiz = [xsize] + [pad_size, xsize]*(button_grid-1)
vert = [ysize, pad_size] * self.num_sliders + \
[large_pad_size, large_pad_size, Size.AxesY(self.axes)]
else:
vert = [ysize, large_pad_size] * self.num_sliders + \
[large_pad_size, Size.AxesY(self.axes)]
horiz = [Size.Fraction(0.1, Size.AxesX(self.axes))] + \
[Size.Fraction(0.05, Size.AxesX(self.axes))] + \
[Size.Fraction(0.65, Size.AxesX(self.axes))] + \
[Size.Fraction(0.1, Size.AxesX(self.axes))] + \
[Size.Fraction(0.1, Size.AxesX(self.axes))]

self.divider.set_horizontal(horiz)
self.divider.set_vertical(vert)
self.button_ny = len(vert) - 3

# If we are going to add a colorbar it'll need an axis next to the plot
if self.if_colorbar:
nx1 = -3
self.cax = self.fig.add_axes((0., 0., 0.141, 1.))
locator = self.divider.new_locator(nx=-2, ny=len(vert)-1, nx1=-1)
self.cax.set_axes_locator(locator)
else:
# Main figure spans all horiz and is in the top (2) in vert.
nx1 = -1

self.axes.set_axes_locator(
self.divider.new_locator(nx=0, ny=len(vert)-1, nx1=nx1))
self.cax = self.axes.inset_axes([1.05, 0.0, 0.05, 1.0])

def _add_widgets(self):
self.buttons = []
for i in range(0, self.num_buttons):
x = i * 2
# The i+1/10. is a bug that if you make two axes directly on top of
# one another then the divider doesn't work.
self.buttons.append(self.fig.add_axes((0., 0., 0.+i/10., 1.)))
locator = self.divider.new_locator(nx=x, ny=self.button_ny)
self.buttons[-1].set_axes_locator(locator)
self.buttons[-1]._button = widgets.Button(self.buttons[-1],
self.button_labels[i])
self.buttons[-1]._button.on_clicked(partial(self.button_func[i], self))

self.sliders = []
self.slider_buttons = []
for i in range(self.num_sliders):
y = i * 2
self.sliders.append(self.fig.add_axes((0., 0., 0.01+i/10., 1.)))
if self.num_buttons == 0:
nx1 = 3
else:
nx1 = -2
locator = self.divider.new_locator(nx=2, ny=y, nx1=nx1)
self.sliders[-1].set_axes_locator(locator)
self.sliders[-1].text(0.5, 0.5, self.slider_labels[i],
transform=self.sliders[-1].transAxes,
horizontalalignment="center",
verticalalignment="center")

sframe = widgets.Slider(self.sliders[-1], "",
# Add the custom button row
if self.num_buttons:
self.buttons = self.subfigs[1].subplots(ncols=self.num_buttons)
for i, bax in enumerate(self.buttons):
bax._button = widgets.Button(bax,
self.button_labels[i])
bax._button.on_clicked(partial(self.button_func[i], self))

controls_axes = self.subfigs[2].subplots(
ncols=2,
nrows=self.num_sliders,
width_ratios=[0.1, 0.9],
squeeze=False,
)

self.slider_buttons = controls_axes[:, 0].tolist()
self.sliders = controls_axes[:, 1].tolist()

for i, saxis in enumerate(self.sliders):
sframe = widgets.Slider(saxis,
"", # We add label manually
self.slider_ranges[i][0],
self.slider_ranges[i][-1]-1,
valinit=self.slider_ranges[i][0],
valfmt='%4.1f')
sframe.on_changed(partial(self._slider_changed, slider=sframe))
sframe.slider_ind = i
sframe.cval = sframe.val
self.sliders[-1]._slider = sframe

self.slider_buttons.append(
self.fig.add_axes((0., 0., 0.05+y/10., 1.)))
locator = self.divider.new_locator(nx=0, ny=y)

self.slider_buttons[-1].set_axes_locator(locator)
butt = widgets.Button(self.slider_buttons[-1], ">")
butt.on_clicked(partial(self._click_slider_button, button=butt, slider=sframe))
saxis._slider = sframe

# Add the label as text in the middle of the axis
saxis.text(
0.5,
0.5,
self.slider_labels[i],
transform=saxis.transAxes,
horizontalalignment="center",
verticalalignment="center",
fontsize="x-small",
)

for i, sbaxis in enumerate(self.slider_buttons):
butt = widgets.Button(sbaxis, ">")
butt.on_clicked(partial(
self._click_slider_button,
button=butt,
slider=self.sliders[i]._slider
))
butt.clicked = False
self.slider_buttons[-1]._button = butt
sbaxis._button = butt

# =============================================================================
# Widget callbacks
Expand Down
2 changes: 1 addition & 1 deletion mpl_animators/tests/test_basefuncanimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,4 @@ def test_lineanimator_figure():
xdata = np.tile(np.linspace(
0, 100, (data_shape0[plot_axis0] + 1)), (data_shape0[slider_axis0], 1))
ani = LineAnimator(data0, plot_axis_index=plot_axis0, axis_ranges=[None, xdata])
return ani.fig
return ani.parent_fig
26 changes: 13 additions & 13 deletions mpl_animators/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def test_constructor_errors(wcs_4d):
def test_array_animator_wcs_2d_simple_plot(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'])
return a.fig
return a.parent_fig


@figure_test
def test_array_animator_wcs_2d_clip_interval(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], clip_interval=(1, 99)*u.percent)
return a.fig
return a.parent_fig


def test_array_animator_wcs_2d_clip_interval_change(wcs_4d):
Expand All @@ -158,7 +158,7 @@ def test_array_animator_wcs_2d_clip_interval_change(wcs_4d):
def test_array_animator_wcs_2d_celestial_sliders(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, ['x', 'y', 0, 0])
return a.fig
return a.parent_fig


def test_to_axes(wcs_4d):
Expand All @@ -172,15 +172,15 @@ def test_array_animator_wcs_2d_update_plot(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'])
a.update_plot(1, a.im, a.sliders[0]._slider)
return a.fig
return a.parent_fig


@figure_test
def test_array_animator_wcs_2d_transpose_update_plot(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True)
a.update_plot(1, a.im, a.sliders[0]._slider)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -191,7 +191,7 @@ def test_array_animator_wcs_2d_colorbar_buttons(wcs_4d):
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'],
colorbar=True, button_func=bf, button_labels=bl)
a.update_plot(1, a.im, a.sliders[0]._slider)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -200,7 +200,7 @@ def test_array_animator_wcs_2d_colorbar_buttons_default_labels(wcs_4d):
bf = [lambda x: x] * 10
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True, button_func=bf)
a.update_plot(1, a.im, a.sliders[0]._slider)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -216,15 +216,15 @@ def vmax_slider(val, im, slider):
slider_functions=[vmin_slider, vmax_slider],
slider_ranges=[[0, 100], [0, 100]])
a.update_plot(1, a.im, a.sliders[0]._slider)
return a.fig
return a.parent_fig


@figure_test
def test_array_animator_wcs_1d_update_plot(wcs_4d):
data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 0], ylabel="Y axis!")
a.sliders[0]._slider.set_val(1)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -244,7 +244,7 @@ def test_array_animator_wcs_1d_update_plot_masked(wcs_3d):
a = ArrayAnimatorWCS(data, wcs_3d, ['x', 0, 0], ylabel="Y axis!")
a.sliders[0]._slider.set_val(wcs_3d.array_shape[0] / 2)

return a.fig
return a.parent_fig


@figure_test
Expand All @@ -261,7 +261,7 @@ def test_array_animator_wcs_coord_params(wcs_4d):

data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -278,7 +278,7 @@ def test_array_animator_wcs_coord_params_no_ticks(wcs_4d):

data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params)
return a.fig
return a.parent_fig


@figure_test
Expand All @@ -295,4 +295,4 @@ def test_array_animator_wcs_coord_params_grid(wcs_4d):

data = np.arange(120).reshape((5, 4, 3, 2))
a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params)
return a.fig
return a.parent_fig
Loading