From 42a3e519d05f3946ad86e1c1a6dbd430324652b3 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Sun, 10 Nov 2024 15:28:18 +0000 Subject: [PATCH 1/2] Use subfigures and subplots for layout --- mpl_animators/base.py | 140 ++++++++----------- mpl_animators/tests/test_basefuncanimator.py | 2 +- mpl_animators/tests/test_wcs.py | 26 ++-- 3 files changed, 72 insertions(+), 96 deletions(-) diff --git a/mpl_animators/base.py b/mpl_animators/base.py index 15fe0ee..cac8d55 100644 --- a/mpl_animators/base.py +++ b/mpl_animators/base.py @@ -91,8 +91,11 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None, self.num_buttons = len(self.button_func) if not fig: - fig = plt.figure() - self.fig = fig + fig = plt.figure(layout="constrained") + self.parent_fig = fig + self.subfigs = None + self._setup_subfigure_grid() + self.fig = self.subfigs[0] self.data = data self.interval = interval @@ -266,90 +269,53 @@ def _dehighlight_slider(self, ind): # ============================================================================= # Build the figure and place the widgets # ============================================================================= + def _setup_subfigure_grid(self): + slider_height = 0.075 * self.num_sliders + button_height = 0.075 if self.num_buttons else 0 + main_figure_height = 1.0 - slider_height - button_height + + self.subfigs = self.parent_fig.subfigures( + nrows=3, + height_ratios=[main_figure_height, button_height, slider_height] + ) + def _setup_main_axes(self): """ Allow replacement of main axes by subclassing. This method must set the ``axes`` attribute. """ if self.axes is None: - self.axes = self.fig.add_subplot(111) + self.axes = self.subfigs[0].add_subplot(111) def _make_axes_grid(self): self._setup_main_axes() - # Split up the current axes so there is space for start & stop buttons - self.divider = make_axes_locatable(self.axes) - pad = 0.01 # Padding between axes - pad_size = Size.Fraction(pad, Size.AxesX(self.axes)) - large_pad_size = Size.Fraction(0.1, Size.AxesY(self.axes)) - - button_grid = max((7, self.num_buttons)) - - # Define size of useful axes cells, 50% each in x 20% for buttons in y. - ysize = Size.Fraction((1.-2.*pad)/15., Size.AxesY(self.axes)) - xsize = Size.Fraction((1.-2.*pad)/button_grid, Size.AxesX(self.axes)) - - # Set up grid, 3x3 with cells for padding. - if self.num_buttons > 0: - horiz = [xsize] + [pad_size, xsize]*(button_grid-1) - vert = [ysize, pad_size] * self.num_sliders + \ - [large_pad_size, large_pad_size, Size.AxesY(self.axes)] - else: - vert = [ysize, large_pad_size] * self.num_sliders + \ - [large_pad_size, Size.AxesY(self.axes)] - horiz = [Size.Fraction(0.1, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.05, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.65, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.1, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.1, Size.AxesX(self.axes))] - - self.divider.set_horizontal(horiz) - self.divider.set_vertical(vert) - self.button_ny = len(vert) - 3 - # If we are going to add a colorbar it'll need an axis next to the plot if self.if_colorbar: - nx1 = -3 - self.cax = self.fig.add_axes((0., 0., 0.141, 1.)) - locator = self.divider.new_locator(nx=-2, ny=len(vert)-1, nx1=-1) - self.cax.set_axes_locator(locator) - else: - # Main figure spans all horiz and is in the top (2) in vert. - nx1 = -1 - - self.axes.set_axes_locator( - self.divider.new_locator(nx=0, ny=len(vert)-1, nx1=nx1)) + self.cax = self.axes.inset_axes([1.05, 0.0, 0.05, 1.0]) def _add_widgets(self): - self.buttons = [] - for i in range(0, self.num_buttons): - x = i * 2 - # The i+1/10. is a bug that if you make two axes directly on top of - # one another then the divider doesn't work. - self.buttons.append(self.fig.add_axes((0., 0., 0.+i/10., 1.))) - locator = self.divider.new_locator(nx=x, ny=self.button_ny) - self.buttons[-1].set_axes_locator(locator) - self.buttons[-1]._button = widgets.Button(self.buttons[-1], - self.button_labels[i]) - self.buttons[-1]._button.on_clicked(partial(self.button_func[i], self)) - - self.sliders = [] - self.slider_buttons = [] - for i in range(self.num_sliders): - y = i * 2 - self.sliders.append(self.fig.add_axes((0., 0., 0.01+i/10., 1.))) - if self.num_buttons == 0: - nx1 = 3 - else: - nx1 = -2 - locator = self.divider.new_locator(nx=2, ny=y, nx1=nx1) - self.sliders[-1].set_axes_locator(locator) - self.sliders[-1].text(0.5, 0.5, self.slider_labels[i], - transform=self.sliders[-1].transAxes, - horizontalalignment="center", - verticalalignment="center") - - sframe = widgets.Slider(self.sliders[-1], "", + # Add the custom button row + if self.num_buttons: + self.buttons = self.subfigs[1].subplots(ncols=self.num_buttons) + for i, bax in enumerate(self.buttons): + bax._button = widgets.Button(bax, + self.button_labels[i]) + bax._button.on_clicked(partial(self.button_func[i], self)) + + controls_axes = self.subfigs[2].subplots( + ncols=2, + nrows=self.num_sliders, + width_ratios=[0.1, 0.9], + squeeze=False, + ) + + self.slider_buttons = controls_axes[:, 0].tolist() + self.sliders = controls_axes[:, 1].tolist() + + for i, saxis in enumerate(self.sliders): + sframe = widgets.Slider(saxis, + "", # We add label manually self.slider_ranges[i][0], self.slider_ranges[i][-1]-1, valinit=self.slider_ranges[i][0], @@ -357,17 +323,27 @@ def _add_widgets(self): sframe.on_changed(partial(self._slider_changed, slider=sframe)) sframe.slider_ind = i sframe.cval = sframe.val - self.sliders[-1]._slider = sframe - - self.slider_buttons.append( - self.fig.add_axes((0., 0., 0.05+y/10., 1.))) - locator = self.divider.new_locator(nx=0, ny=y) - - self.slider_buttons[-1].set_axes_locator(locator) - butt = widgets.Button(self.slider_buttons[-1], ">") - butt.on_clicked(partial(self._click_slider_button, button=butt, slider=sframe)) + saxis._slider = sframe + + saxis.text( + 0.5, + 0.5, + self.slider_labels[i], + transform=saxis.transAxes, + horizontalalignment="center", + verticalalignment="center", + fontsize="x-small", + ) + + for i, sbaxis in enumerate(self.slider_buttons): + butt = widgets.Button(sbaxis, ">") + butt.on_clicked(partial( + self._click_slider_button, + button=butt, + slider=self.sliders[i]._slider + )) butt.clicked = False - self.slider_buttons[-1]._button = butt + sbaxis._button = butt # ============================================================================= # Widget callbacks diff --git a/mpl_animators/tests/test_basefuncanimator.py b/mpl_animators/tests/test_basefuncanimator.py index 58dac45..89d6ee1 100644 --- a/mpl_animators/tests/test_basefuncanimator.py +++ b/mpl_animators/tests/test_basefuncanimator.py @@ -219,4 +219,4 @@ def test_lineanimator_figure(): xdata = np.tile(np.linspace( 0, 100, (data_shape0[plot_axis0] + 1)), (data_shape0[slider_axis0], 1)) ani = LineAnimator(data0, plot_axis_index=plot_axis0, axis_ranges=[None, xdata]) - return ani.fig + return ani.parent_fig diff --git a/mpl_animators/tests/test_wcs.py b/mpl_animators/tests/test_wcs.py index 6e98d1c..4d6173e 100644 --- a/mpl_animators/tests/test_wcs.py +++ b/mpl_animators/tests/test_wcs.py @@ -131,14 +131,14 @@ def test_constructor_errors(wcs_4d): def test_array_animator_wcs_2d_simple_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y']) - return a.fig + return a.parent_fig @figure_test def test_array_animator_wcs_2d_clip_interval(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], clip_interval=(1, 99)*u.percent) - return a.fig + return a.parent_fig def test_array_animator_wcs_2d_clip_interval_change(wcs_4d): @@ -158,7 +158,7 @@ def test_array_animator_wcs_2d_clip_interval_change(wcs_4d): def test_array_animator_wcs_2d_celestial_sliders(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, ['x', 'y', 0, 0]) - return a.fig + return a.parent_fig def test_to_axes(wcs_4d): @@ -172,7 +172,7 @@ def test_array_animator_wcs_2d_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y']) a.update_plot(1, a.im, a.sliders[0]._slider) - return a.fig + return a.parent_fig @figure_test @@ -180,7 +180,7 @@ def test_array_animator_wcs_2d_transpose_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True) a.update_plot(1, a.im, a.sliders[0]._slider) - return a.fig + return a.parent_fig @figure_test @@ -191,7 +191,7 @@ def test_array_animator_wcs_2d_colorbar_buttons(wcs_4d): a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True, button_func=bf, button_labels=bl) a.update_plot(1, a.im, a.sliders[0]._slider) - return a.fig + return a.parent_fig @figure_test @@ -200,7 +200,7 @@ def test_array_animator_wcs_2d_colorbar_buttons_default_labels(wcs_4d): bf = [lambda x: x] * 10 a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True, button_func=bf) a.update_plot(1, a.im, a.sliders[0]._slider) - return a.fig + return a.parent_fig @figure_test @@ -216,7 +216,7 @@ def vmax_slider(val, im, slider): slider_functions=[vmin_slider, vmax_slider], slider_ranges=[[0, 100], [0, 100]]) a.update_plot(1, a.im, a.sliders[0]._slider) - return a.fig + return a.parent_fig @figure_test @@ -224,7 +224,7 @@ def test_array_animator_wcs_1d_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 0], ylabel="Y axis!") a.sliders[0]._slider.set_val(1) - return a.fig + return a.parent_fig @figure_test @@ -244,7 +244,7 @@ def test_array_animator_wcs_1d_update_plot_masked(wcs_3d): a = ArrayAnimatorWCS(data, wcs_3d, ['x', 0, 0], ylabel="Y axis!") a.sliders[0]._slider.set_val(wcs_3d.array_shape[0] / 2) - return a.fig + return a.parent_fig @figure_test @@ -261,7 +261,7 @@ def test_array_animator_wcs_coord_params(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) - return a.fig + return a.parent_fig @figure_test @@ -278,7 +278,7 @@ def test_array_animator_wcs_coord_params_no_ticks(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) - return a.fig + return a.parent_fig @figure_test @@ -295,4 +295,4 @@ def test_array_animator_wcs_coord_params_grid(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) - return a.fig + return a.parent_fig From 0a5599eaaa12defec00e4ef156e4b00e56c67db5 Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Sun, 10 Nov 2024 21:02:47 +0000 Subject: [PATCH 2/2] More --- mpl_animators/base.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mpl_animators/base.py b/mpl_animators/base.py index cac8d55..b385628 100644 --- a/mpl_animators/base.py +++ b/mpl_animators/base.py @@ -90,18 +90,6 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None, self.button_labels = button_labels or [] self.num_buttons = len(self.button_func) - if not fig: - fig = plt.figure(layout="constrained") - self.parent_fig = fig - self.subfigs = None - self._setup_subfigure_grid() - self.fig = self.subfigs[0] - - self.data = data - self.interval = interval - self.if_colorbar = colorbar - self.imshow_kwargs = kwargs - if len(slider_functions) != len(slider_ranges): raise ValueError("slider_functions and slider_ranges must be the same length.") @@ -114,6 +102,18 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None, self.slider_ranges = slider_ranges self.slider_labels = slider_labels or [''] * len(slider_functions) + if not fig: + fig = plt.figure(layout="constrained") + self.parent_fig = fig + self.subfigs = None + self._setup_subfigure_grid() + self.fig = self.subfigs[0] + + self.data = data + self.interval = interval + self.if_colorbar = colorbar + self.imshow_kwargs = kwargs + # Set active slider self.active_slider = 0 @@ -130,8 +130,8 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None, # # Only do this if figure has a manager, so directly constructed figures # (ie. via matplotlib.figure.Figure()) work. - if hasattr(self.fig.canvas, "manager") and self.fig.canvas.manager is not None: - plt.sca(self.axes) + # if isinstance(self.axes, mpl.axes.Axes) and hasattr(self.fig.canvas, "manager") and self.fig.canvas.manager is not None: + # plt.sca(self.axes) # Do Plot self.im = self.plot_start_image(self.axes) @@ -325,6 +325,7 @@ def _add_widgets(self): sframe.cval = sframe.val saxis._slider = sframe + # Add the label as text in the middle of the axis saxis.text( 0.5, 0.5,