From a1b053562759fd8138f7995169549a8a4af4209c Mon Sep 17 00:00:00 2001 From: Ian Czekala Date: Wed, 4 Dec 2024 14:45:02 +0000 Subject: [PATCH] fleshed out imshow_two, parameterized GaussConv tests. --- src/mpol/tests.mplstyle | 2 +- test/images_test.py | 221 +++++++++++++++------------------------- test/plot_utils.py | 130 +++++++++++++++++------ 3 files changed, 183 insertions(+), 170 deletions(-) diff --git a/src/mpol/tests.mplstyle b/src/mpol/tests.mplstyle index ed4059c6..429e682b 100644 --- a/src/mpol/tests.mplstyle +++ b/src/mpol/tests.mplstyle @@ -1,4 +1,4 @@ image.cmap: inferno figure.figsize: 7.1, 5.0 figure.autolayout: True -savefig.dpi: 200 \ No newline at end of file +savefig.dpi: 300 \ No newline at end of file diff --git a/test/images_test.py b/test/images_test.py index 7d300b5b..81f81b72 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -6,146 +6,97 @@ from mpol import coordinates, images, plot, utils from plot_utils import imshow_two -def test_instantiate_single_chan(): - coords = coordinates.GridCoords(cell_size=0.015, npix=800) - im = images.ImageCube(coords=coords) - assert im.nchan == 1 +def test_BaseCube_map(coords, tmp_path): + # create a mock cube that includes negative values + nchan = 1 + mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5) + std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5) + + bcube = torch.normal(mean=mean, std=std) + blayer = images.BaseCube(coords=coords, nchan=nchan, base_cube=bcube) + + # the default softplus function should map everything to positive values + blayer_output = blayer() + + imshow_two( + tmp_path / "BaseCube_mapped.png", + [bcube, blayer_output], + title=["BaseCube input", "BaseCube output"], + xlabel=["pixel"], + ylabel=["pixel"], + ) + + assert torch.all(blayer_output >= 0) -def test_basecube_apply_grad(): + +def test_instantiate_ImageCube(): coords = coordinates.GridCoords(cell_size=0.015, npix=800) - bcube = images.BaseCube(coords=coords) - loss = torch.sum(bcube()) - loss.backward() + im = images.ImageCube(coords=coords) + assert im.nchan == 1 -def test_imagecube_apply_grad(coords): +def test_ImageCube_apply_grad(coords): bcube = images.BaseCube(coords=coords) - # try passing through ImageLayer imagecube = images.ImageCube(coords=coords) - - # send things through this layer loss = torch.sum(imagecube(bcube())) - loss.backward() -# test for proper fits scale -def test_imagecube_tofits(coords, tmp_path): - # creating base cube +def test_to_FITS_pixel_scale(coords, tmp_path): + """Test whether the FITS scale was written correctly.""" bcube = images.BaseCube(coords=coords) - - # try passing through ImageLayer imagecube = images.ImageCube(coords=coords) - - # sending the basecube through the imagecube imagecube(bcube()) - # creating output fits file with name 'test_cube_fits_file39.fits' - # file will be deleted after testing + # write FITS to file imagecube.to_FITS(fname=tmp_path / "test_cube_fits_file39.fits", overwrite=True) - # inputting the header from the previously created fits file + # read file and check pixel scale is correct fits_header = fits.open(tmp_path / "test_cube_fits_file39.fits")[0].header assert (fits_header["CDELT1"] and fits_header["CDELT2"]) == pytest.approx( coords.cell_size / 3600 ) -def test_basecube_imagecube(coords, tmp_path): - # create a mock cube that includes negative values - nchan = 1 - mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5) - std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5) - - # tensor - base_cube = torch.normal(mean=mean, std=std) - - # layer - basecube = images.BaseCube(coords=coords, nchan=nchan, base_cube=base_cube) - - # the default softplus function should map everything to positive values - output = basecube() - - imshow_two(tmp_path / "basecube_mapped.png", [base_cube, output], titles=["input", "mapped"]) - - # try passing through ImageLayer - imagecube = images.ImageCube(coords=coords, nchan=nchan) - - # send things through this layer - imagecube(basecube()) - - fig, ax = plt.subplots(ncols=1) - ax.imshow( - np.squeeze(imagecube.sky_cube.detach().numpy()), - extent=imagecube.coords.img_ext, - origin="lower", - interpolation="none", - ) - fig.savefig(tmp_path / "imagecube.png") - - plt.close("all") - - -def test_base_cube_conv_cube(coords, tmp_path): - # test whether the HannConvCube functions appropriately - +def test_HannConvCube(coords, tmp_path): # create a mock cube that includes negative values nchan = 1 mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5) # The HannConvCube expects to function on a pre-packed ImageCube, - # so in order to get the plots looking correct on this test image, - # we need to faff around with packing - - # tensor test_cube = torch.normal(mean=mean, std=std) test_cube_packed = utils.sky_cube_to_packed_cube(test_cube) - # layer conv_layer = images.HannConvCube(nchan=nchan) conv_output_packed = conv_layer(test_cube_packed) conv_output = utils.packed_cube_to_sky_cube(conv_output_packed) - fig, ax = plt.subplots(ncols=2, nrows=1) - - im = ax[0].imshow( - np.squeeze(test_cube.detach().numpy()), origin="lower", interpolation="none" + imshow_two( + tmp_path / "convcube.png", + [test_cube, conv_output], + title=["input", "convolved"], + xlabel=["pixel"], + ylabel=["pixel"], ) - plt.colorbar(im, ax=ax[0]) - ax[0].set_title("input") - - im = ax[1].imshow( - np.squeeze(conv_output.detach().numpy()), origin="lower", interpolation="none" - ) - plt.colorbar(im, ax=ax[1]) - ax[1].set_title("convolved") - fig.savefig(tmp_path / "convcube.png") - - plt.close("all") - - -def test_multi_chan_conv(coords, tmp_path): - # create a mock channel cube that includes negative values - # and make sure that the HannConvCube works across channels +def test_HannConvCube_multi_chan(coords): + """Make sure HannConvCube functions with multi-channeled input""" nchan = 10 mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5) - # tensor test_cube = torch.normal(mean=mean, std=std) - # layer conv_layer = images.HannConvCube(nchan=nchan) - conv_layer(test_cube) -def test_image_flux(coords): +def test_flux(coords): + """Make sure we can read the flux attribute.""" nchan = 20 bcube = images.BaseCube(coords=coords, nchan=nchan) im = images.ImageCube(coords=coords, nchan=nchan) @@ -167,7 +118,7 @@ def test_plot_test_img(packed_cube, coords, tmp_path): plt.close("all") -def test_taper(coords, tmp_path): +def test_uv_gaussian_taper(coords, tmp_path): for r in np.arange(0.0, 0.2, step=0.04): fig, ax = plt.subplots(ncols=1) @@ -189,7 +140,7 @@ def test_taper(coords, tmp_path): plt.close("all") -def test_gaussian_kernel(coords, tmp_path): +def test_GaussConvImage_kernel(coords, tmp_path): rs = np.array([0.02, 0.06, 0.10]) nchan = 3 fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10, 10)) @@ -204,7 +155,7 @@ def test_gaussian_kernel(coords, tmp_path): plt.close("all") -def test_gaussian_kernel_rotate(coords, tmp_path): +def test_GaussConvImage_kernel_rotate(coords, tmp_path): r = 0.04 Omegas = [0, 20, 40] # degrees nchan = 3 @@ -221,62 +172,56 @@ def test_gaussian_kernel_rotate(coords, tmp_path): fig.savefig(tmp_path / "filter.png") plt.close("all") - -def test_GaussConvImage(sky_cube, coords, tmp_path): - # show only the first channel +@pytest.mark.parametrize("r", [0.02, 0.06, 0.1]) +def test_GaussConvImage(sky_cube, coords, tmp_path, r): chan = 0 nchan = sky_cube.size()[0] - for r in np.arange(0.02, 0.11, step=0.04): - layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r) - - print("Kernel size", layer.m.weight.size()) - - fig, ax = plt.subplots(ncols=2) + layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r) + c_sky = layer(sky_cube) - im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower") - flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) - ax[0].set_title(f"tot flux: {flux:.3f} Jy") - plt.colorbar(im, ax=ax[0]) - - c_sky = layer(sky_cube) - im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower") - flux = coords.cell_size**2 * torch.sum(c_sky[chan]) - ax[1].set_title(f"tot flux: {flux:.3f} Jy") - - plt.colorbar(im, ax=ax[1]) - fig.savefig(tmp_path / f"convolved_{r:.2f}.png") - - plt.close("all") + imgs = [sky_cube[chan], c_sky[chan]] + fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] + title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] + imshow_two( + tmp_path / f"convolved_{r:.2f}.png", + imgs, + sky=True, + suptitle=f"Image Plane Gauss Convolution FWHM={r}", + title=title, + extent=[coords.img_ext] + ) -def test_GaussConvImage_rotate(sky_cube, coords, tmp_path): - # show only the first channel + assert pytest.approx(fluxes[0]) == fluxes[1] + +@pytest.mark.parametrize("Omega", [0, 30]) +def test_GaussConvImage_rotate(sky_cube, coords, tmp_path, Omega): chan = 0 nchan = sky_cube.size()[0] - for Omega in [0, 30]: - layer = images.GaussConvImage( - coords, nchan=nchan, FWHM_maj=0.10, FWHM_min=0.05, Omega=Omega - ) - - fig, ax = plt.subplots(ncols=2) - - im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower") - flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) - ax[0].set_title(f"tot flux: {flux:.3f} Jy") - plt.colorbar(im, ax=ax[0]) + FWHM_maj = 0.10 + FWHM_min = 0.05 - c_sky = layer(sky_cube) - im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower") - flux = coords.cell_size**2 * torch.sum(c_sky[chan]) - ax[1].set_title(f"tot flux: {flux:.3f} Jy") - - plt.colorbar(im, ax=ax[1]) - fig.savefig(tmp_path / f"convolved_{Omega:.0f}_deg.png") - - plt.close("all") + layer = images.GaussConvImage( + coords, nchan=nchan, FWHM_maj=FWHM_maj, FWHM_min=FWHM_min, Omega=Omega + ) + c_sky = layer(sky_cube) + + imgs = [sky_cube[chan], c_sky[chan]] + fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] + title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] + + imshow_two( + tmp_path / f"convolved_{Omega:.0f}_deg.png", + imgs, + sky=True, + suptitle=r'Image Plane Gauss Convolution: $\Omega$=' + f'{Omega}, {FWHM_maj}", {FWHM_min}"', + title=title, + extent=[coords.img_ext], + ) + assert pytest.approx(fluxes[0], abs=4e-7) == fluxes[1] def test_GaussFourier(packed_cube, coords, tmp_path): chan = 0 diff --git a/test/plot_utils.py b/test/plot_utils.py index 25c4e574..39294f7a 100644 --- a/test/plot_utils.py +++ b/test/plot_utils.py @@ -1,9 +1,55 @@ +import matplotlib as mpl import matplotlib.pyplot as plt import torch import numpy as np -def imshow_two(path, imgs, **kwargs): +def extend_list(l, num=2): + """ + Duplicate or extend a list to two items. + + l: list + the list of items to potentially duplicate or truncate. + num: int + the final length of the list + + Returns + ------- + list + Length num list of items. + + Examples + -------- + >>> extend_list(["L Plot", "R Plot"]) + ["L Plot", "R Plot"] + >>> extend_list({["Plot"]) # both L and R will have "Plot" + ["Plot", "Plot"] + >>> extend_list({["L Plot", "R Plot", "Z Plot"]}) # "Z Plot" is ignored + ["L Plot", "R Plot"] + """ + if len(l) == 1: + return num * l + else: + return l[:num] + +def extend_kwargs(kwargs): + """ + This is a helper routine for imshow_two, designed to flexibly consume a variety + of options for each of the two plots. + + kwargs: dict + the kwargs dict provided from the function call + + Returns + ------- + dict + Updated kwargs with length 2 lists of items. + """ + + for key, item in kwargs.items(): + kwargs[key] = extend_list(item) + +def imshow_two(path, imgs, sky=False, suptitle=None, **kwargs): """Plot two images side by side, with scalebars. imgs is a list @@ -13,8 +59,12 @@ def imshow_two(path, imgs, **kwargs): path and filename to save figure imgs : list length-2 list of images to plot. Arguments are designed to be very permissive. If the image is a PyTorch tensor, the routine converts it to numpy, and then numpy.squeeze is called. - titles: list - if provided, list of strings corresponding to title for each subplot. + sky: bool + If True, treat images as sky plots and label with offset arcseconds. + title: list + if provided, list of strings corresponding to title for each subplot. If only one provided, + xlabel: list + if provided, list of strings Returns @@ -22,12 +72,12 @@ def imshow_two(path, imgs, **kwargs): None """ - xx = 7.1 # in + xx = 7.5 # in rmargin = 0.8 lmargin = 0.8 - tmargin = 0.3 + tmargin = 0.3 if suptitle is None else 0.5 bmargin = 0.5 - middle_sep = 1.2 + middle_sep = 1.3 ax_width = (xx - rmargin - lmargin - middle_sep) / 2 ax_height = ax_width cax_width = 0.1 @@ -35,23 +85,29 @@ def imshow_two(path, imgs, **kwargs): cax_height = ax_height yy = bmargin + ax_height + tmargin - fig = plt.figure(figsize=(xx, yy)) + with mpl.rc_context({'figure.autolayout': False}): + fig = plt.figure(figsize=(xx, yy)) - ax = [] - cax = [] - for i in [0, 1]: - ax.append( - fig.add_axes( - [ - (lmargin + i * (ax_width + middle_sep)) / xx, - bmargin / yy, - ax_width / xx, - ax_height / yy, - ] - ) - ) - cax.append( - fig.add_axes( + ax = [] + cax = [] + + extend_kwargs(kwargs) + + if "extent" not in kwargs: + kwargs["extent"] = [None, None] + + for i in [0, 1]: + a = fig.add_axes( + [ + (lmargin + i * (ax_width + middle_sep)) / xx, + bmargin / yy, + ax_width / xx, + ax_height / yy, + ] + ) + ax.append(a) + + ca = fig.add_axes( ( [ (lmargin + (i + 1) * ax_width + i * middle_sep + cax_sep) / xx, @@ -61,16 +117,28 @@ def imshow_two(path, imgs, **kwargs): ] ) ) - ) + cax.append(ca) + + img = imgs[i] + img = img.detach().numpy() if torch.is_tensor(img) else img - img = imgs[i] - img = img.detach().numpy() if torch.is_tensor(img) else img + im = a.imshow(np.squeeze(img), origin="lower", interpolation="none", extent=kwargs["extent"][i]) + plt.colorbar(im, cax=ca) + + if "title" in kwargs: + a.set_title(kwargs["title"][i]) - im = ax[i].imshow(np.squeeze(img), origin="lower", interpolation="none") - plt.colorbar(im, cax=cax[i]) - - if "titles" in kwargs: - ax[i].set_title(kwargs["titles"][i]) + if sky: + a.set_xlabel(r"$\Delta \alpha\ \cos \delta\;[{}^{\prime\prime}]$") + a.set_ylabel(r"$\Delta \delta\;[{}^{\prime\prime}]$") + else: + if "xlabel" in kwargs: + a.set_xlabel(kwargs["xlabel"][i]) + if "ylabel" in kwargs: + a.set_ylabel(kwargs["ylabel"][i]) - fig.savefig(path) + if suptitle is not None: + fig.suptitle(suptitle) + fig.savefig(path) + plt.close("all")