From 4041caaf2b67700e502609d73ac0db9bb61a1003 Mon Sep 17 00:00:00 2001 From: Ian Czekala Date: Mon, 2 Dec 2024 20:30:40 +0000 Subject: [PATCH] added plot_utils and imshow_two routine. --- test/images_test.py | 24 +++----------- test/plot_utils.py | 76 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 19 deletions(-) create mode 100644 test/plot_utils.py diff --git a/test/images_test.py b/test/images_test.py index a798bd76..b6615415 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -4,22 +4,22 @@ import torch from astropy.io import fits from mpol import coordinates, images, plot, utils +from plot_utils import imshow_two - -def test_single_chan(): +def test_instantiate_single_chan(): coords = coordinates.GridCoords(cell_size=0.015, npix=800) im = images.ImageCube(coords=coords) assert im.nchan == 1 -def test_basecube_grad(): +def test_basecube_apply_grad(): coords = coordinates.GridCoords(cell_size=0.015, npix=800) bcube = images.BaseCube(coords=coords) loss = torch.sum(bcube()) loss.backward() -def test_imagecube_grad(coords): +def test_imagecube_apply_grad(coords): bcube = images.BaseCube(coords=coords) # try passing through ImageLayer imagecube = images.ImageCube(coords=coords) @@ -67,21 +67,7 @@ def test_basecube_imagecube(coords, tmp_path): # the default softplus function should map everything to positive values output = basecube() - fig, ax = plt.subplots(ncols=2, nrows=1) - - im = ax[0].imshow( - np.squeeze(base_cube.detach().numpy()), origin="lower", interpolation="none" - ) - plt.colorbar(im, ax=ax[0]) - ax[0].set_title("input") - - im = ax[1].imshow( - np.squeeze(output.detach().numpy()), origin="lower", interpolation="none" - ) - plt.colorbar(im, ax=ax[1]) - ax[1].set_title("mapped") - - fig.savefig(tmp_path / "basecube_mapped.png") + imshow_two(tmp_path / "basecube_mapped.png", [base_cube, output], titles=["input", "mapped"]) # try passing through ImageLayer imagecube = images.ImageCube(coords=coords, nchan=nchan) diff --git a/test/plot_utils.py b/test/plot_utils.py new file mode 100644 index 00000000..25c4e574 --- /dev/null +++ b/test/plot_utils.py @@ -0,0 +1,76 @@ +import matplotlib.pyplot as plt +import torch +import numpy as np + + +def imshow_two(path, imgs, **kwargs): + """Plot two images side by side, with scalebars. + + imgs is a list + Parameters + ---------- + path : string + path and filename to save figure + imgs : list + length-2 list of images to plot. Arguments are designed to be very permissive. If the image is a PyTorch tensor, the routine converts it to numpy, and then numpy.squeeze is called. + titles: list + if provided, list of strings corresponding to title for each subplot. + + + Returns + ------- + None + """ + + xx = 7.1 # in + rmargin = 0.8 + lmargin = 0.8 + tmargin = 0.3 + bmargin = 0.5 + middle_sep = 1.2 + ax_width = (xx - rmargin - lmargin - middle_sep) / 2 + ax_height = ax_width + cax_width = 0.1 + cax_sep = 0.15 + cax_height = ax_height + yy = bmargin + ax_height + tmargin + + fig = plt.figure(figsize=(xx, yy)) + + ax = [] + cax = [] + for i in [0, 1]: + ax.append( + fig.add_axes( + [ + (lmargin + i * (ax_width + middle_sep)) / xx, + bmargin / yy, + ax_width / xx, + ax_height / yy, + ] + ) + ) + cax.append( + fig.add_axes( + ( + [ + (lmargin + (i + 1) * ax_width + i * middle_sep + cax_sep) / xx, + bmargin / yy, + cax_width / xx, + cax_height / yy, + ] + ) + ) + ) + + img = imgs[i] + img = img.detach().numpy() if torch.is_tensor(img) else img + + im = ax[i].imshow(np.squeeze(img), origin="lower", interpolation="none") + plt.colorbar(im, cax=cax[i]) + + if "titles" in kwargs: + ax[i].set_title(kwargs["titles"][i]) + + + fig.savefig(path)