Skip to content

Commit

Permalink
fleshed out imshow_two, parameterized GaussConv tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Dec 4, 2024
1 parent 315a4fb commit a1b0535
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 170 deletions.
2 changes: 1 addition & 1 deletion src/mpol/tests.mplstyle
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image.cmap: inferno
figure.figsize: 7.1, 5.0
figure.autolayout: True
savefig.dpi: 200
savefig.dpi: 300
221 changes: 83 additions & 138 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,146 +6,97 @@
from mpol import coordinates, images, plot, utils
from plot_utils import imshow_two

def test_instantiate_single_chan():
coords = coordinates.GridCoords(cell_size=0.015, npix=800)
im = images.ImageCube(coords=coords)
assert im.nchan == 1

def test_BaseCube_map(coords, tmp_path):
# create a mock cube that includes negative values
nchan = 1
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

bcube = torch.normal(mean=mean, std=std)
blayer = images.BaseCube(coords=coords, nchan=nchan, base_cube=bcube)

# the default softplus function should map everything to positive values
blayer_output = blayer()

imshow_two(
tmp_path / "BaseCube_mapped.png",
[bcube, blayer_output],
title=["BaseCube input", "BaseCube output"],
xlabel=["pixel"],
ylabel=["pixel"],
)

assert torch.all(blayer_output >= 0)

def test_basecube_apply_grad():

def test_instantiate_ImageCube():
coords = coordinates.GridCoords(cell_size=0.015, npix=800)
bcube = images.BaseCube(coords=coords)
loss = torch.sum(bcube())
loss.backward()
im = images.ImageCube(coords=coords)
assert im.nchan == 1


def test_imagecube_apply_grad(coords):
def test_ImageCube_apply_grad(coords):
bcube = images.BaseCube(coords=coords)
# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords)

# send things through this layer
loss = torch.sum(imagecube(bcube()))

loss.backward()


# test for proper fits scale
def test_imagecube_tofits(coords, tmp_path):
# creating base cube
def test_to_FITS_pixel_scale(coords, tmp_path):
"""Test whether the FITS scale was written correctly."""
bcube = images.BaseCube(coords=coords)

# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords)

# sending the basecube through the imagecube
imagecube(bcube())

# creating output fits file with name 'test_cube_fits_file39.fits'
# file will be deleted after testing
# write FITS to file
imagecube.to_FITS(fname=tmp_path / "test_cube_fits_file39.fits", overwrite=True)

# inputting the header from the previously created fits file
# read file and check pixel scale is correct
fits_header = fits.open(tmp_path / "test_cube_fits_file39.fits")[0].header
assert (fits_header["CDELT1"] and fits_header["CDELT2"]) == pytest.approx(
coords.cell_size / 3600
)


def test_basecube_imagecube(coords, tmp_path):
# create a mock cube that includes negative values
nchan = 1
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# tensor
base_cube = torch.normal(mean=mean, std=std)

# layer
basecube = images.BaseCube(coords=coords, nchan=nchan, base_cube=base_cube)

# the default softplus function should map everything to positive values
output = basecube()

imshow_two(tmp_path / "basecube_mapped.png", [base_cube, output], titles=["input", "mapped"])

# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords, nchan=nchan)

# send things through this layer
imagecube(basecube())

fig, ax = plt.subplots(ncols=1)
ax.imshow(
np.squeeze(imagecube.sky_cube.detach().numpy()),
extent=imagecube.coords.img_ext,
origin="lower",
interpolation="none",
)
fig.savefig(tmp_path / "imagecube.png")

plt.close("all")


def test_base_cube_conv_cube(coords, tmp_path):
# test whether the HannConvCube functions appropriately

def test_HannConvCube(coords, tmp_path):
# create a mock cube that includes negative values
nchan = 1
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# The HannConvCube expects to function on a pre-packed ImageCube,
# so in order to get the plots looking correct on this test image,
# we need to faff around with packing

# tensor
test_cube = torch.normal(mean=mean, std=std)
test_cube_packed = utils.sky_cube_to_packed_cube(test_cube)

# layer
conv_layer = images.HannConvCube(nchan=nchan)

conv_output_packed = conv_layer(test_cube_packed)
conv_output = utils.packed_cube_to_sky_cube(conv_output_packed)

fig, ax = plt.subplots(ncols=2, nrows=1)

im = ax[0].imshow(
np.squeeze(test_cube.detach().numpy()), origin="lower", interpolation="none"
imshow_two(
tmp_path / "convcube.png",
[test_cube, conv_output],
title=["input", "convolved"],
xlabel=["pixel"],
ylabel=["pixel"],
)
plt.colorbar(im, ax=ax[0])
ax[0].set_title("input")

im = ax[1].imshow(
np.squeeze(conv_output.detach().numpy()), origin="lower", interpolation="none"
)
plt.colorbar(im, ax=ax[1])
ax[1].set_title("convolved")

fig.savefig(tmp_path / "convcube.png")

plt.close("all")


def test_multi_chan_conv(coords, tmp_path):
# create a mock channel cube that includes negative values
# and make sure that the HannConvCube works across channels

def test_HannConvCube_multi_chan(coords):
"""Make sure HannConvCube functions with multi-channeled input"""
nchan = 10
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# tensor
test_cube = torch.normal(mean=mean, std=std)

# layer
conv_layer = images.HannConvCube(nchan=nchan)

conv_layer(test_cube)


def test_image_flux(coords):
def test_flux(coords):
"""Make sure we can read the flux attribute."""
nchan = 20
bcube = images.BaseCube(coords=coords, nchan=nchan)
im = images.ImageCube(coords=coords, nchan=nchan)
Expand All @@ -167,7 +118,7 @@ def test_plot_test_img(packed_cube, coords, tmp_path):
plt.close("all")


def test_taper(coords, tmp_path):
def test_uv_gaussian_taper(coords, tmp_path):
for r in np.arange(0.0, 0.2, step=0.04):
fig, ax = plt.subplots(ncols=1)

Expand All @@ -189,7 +140,7 @@ def test_taper(coords, tmp_path):
plt.close("all")


def test_gaussian_kernel(coords, tmp_path):
def test_GaussConvImage_kernel(coords, tmp_path):
rs = np.array([0.02, 0.06, 0.10])
nchan = 3
fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10, 10))
Expand All @@ -204,7 +155,7 @@ def test_gaussian_kernel(coords, tmp_path):
plt.close("all")


def test_gaussian_kernel_rotate(coords, tmp_path):
def test_GaussConvImage_kernel_rotate(coords, tmp_path):
r = 0.04
Omegas = [0, 20, 40] # degrees
nchan = 3
Expand All @@ -221,62 +172,56 @@ def test_gaussian_kernel_rotate(coords, tmp_path):
fig.savefig(tmp_path / "filter.png")
plt.close("all")


def test_GaussConvImage(sky_cube, coords, tmp_path):
# show only the first channel
@pytest.mark.parametrize("r", [0.02, 0.06, 0.1])
def test_GaussConvImage(sky_cube, coords, tmp_path, r):
chan = 0
nchan = sky_cube.size()[0]

for r in np.arange(0.02, 0.11, step=0.04):
layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r)

print("Kernel size", layer.m.weight.size())

fig, ax = plt.subplots(ncols=2)
layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r)
c_sky = layer(sky_cube)

im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])

c_sky = layer(sky_cube)
im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / f"convolved_{r:.2f}.png")

plt.close("all")
imgs = [sky_cube[chan], c_sky[chan]]
fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs]
title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes]

imshow_two(
tmp_path / f"convolved_{r:.2f}.png",
imgs,
sky=True,
suptitle=f"Image Plane Gauss Convolution FWHM={r}",
title=title,
extent=[coords.img_ext]
)

def test_GaussConvImage_rotate(sky_cube, coords, tmp_path):
# show only the first channel
assert pytest.approx(fluxes[0]) == fluxes[1]

@pytest.mark.parametrize("Omega", [0, 30])
def test_GaussConvImage_rotate(sky_cube, coords, tmp_path, Omega):
chan = 0
nchan = sky_cube.size()[0]

for Omega in [0, 30]:
layer = images.GaussConvImage(
coords, nchan=nchan, FWHM_maj=0.10, FWHM_min=0.05, Omega=Omega
)

fig, ax = plt.subplots(ncols=2)

im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])
FWHM_maj = 0.10
FWHM_min = 0.05

c_sky = layer(sky_cube)
im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / f"convolved_{Omega:.0f}_deg.png")

plt.close("all")
layer = images.GaussConvImage(
coords, nchan=nchan, FWHM_maj=FWHM_maj, FWHM_min=FWHM_min, Omega=Omega
)
c_sky = layer(sky_cube)

imgs = [sky_cube[chan], c_sky[chan]]
fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs]
title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes]

imshow_two(
tmp_path / f"convolved_{Omega:.0f}_deg.png",
imgs,
sky=True,
suptitle=r'Image Plane Gauss Convolution: $\Omega$=' + f'{Omega}, {FWHM_maj}", {FWHM_min}"',
title=title,
extent=[coords.img_ext],
)

assert pytest.approx(fluxes[0], abs=4e-7) == fluxes[1]

def test_GaussFourier(packed_cube, coords, tmp_path):
chan = 0
Expand Down
Loading

0 comments on commit a1b0535

Please sign in to comment.