Skip to content

Commit

Permalink
added tests stylesheet, deleted some hardcoded settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Dec 2, 2024
1 parent 1790d86 commit 8bf8a86
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 66 deletions.
4 changes: 4 additions & 0 deletions src/mpol/tests.mplstyle
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
image.cmap: inferno
figure.figsize: 7.1, 5.0
figure.autolayout: True
savefig.dpi: 200
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from mpol import coordinates, fourier, gridding, images, utils
from mpol.__init__ import zenodo_record

import matplotlib.pyplot as plt
plt.style.use("mpol.tests")

# private variables to this module
_npz_path = files("mpol.data").joinpath("mock_data.npz")
_nchan = 4
Expand Down
117 changes: 51 additions & 66 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ def test_imagecube_tofits(coords, tmp_path):
def test_basecube_imagecube(coords, tmp_path):
# create a mock cube that includes negative values
nchan = 1
mean = torch.full(
(nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full(
(nchan, coords.npix, coords.npix), fill_value=0.5)
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# tensor
base_cube = torch.normal(mean=mean, std=std)
Expand All @@ -83,7 +81,7 @@ def test_basecube_imagecube(coords, tmp_path):
plt.colorbar(im, ax=ax[1])
ax[1].set_title("mapped")

fig.savefig(tmp_path / "basecube_mapped.png", dpi=300)
fig.savefig(tmp_path / "basecube_mapped.png")

# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords, nchan=nchan)
Expand All @@ -98,7 +96,7 @@ def test_basecube_imagecube(coords, tmp_path):
origin="lower",
interpolation="none",
)
fig.savefig(tmp_path / "imagecube.png", dpi=300)
fig.savefig(tmp_path / "imagecube.png")

plt.close("all")

Expand All @@ -108,10 +106,8 @@ def test_base_cube_conv_cube(coords, tmp_path):

# create a mock cube that includes negative values
nchan = 1
mean = torch.full(
(nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full(
(nchan, coords.npix, coords.npix), fill_value=0.5)
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# The HannConvCube expects to function on a pre-packed ImageCube,
# so in order to get the plots looking correct on this test image,
Expand Down Expand Up @@ -141,7 +137,7 @@ def test_base_cube_conv_cube(coords, tmp_path):
plt.colorbar(im, ax=ax[1])
ax[1].set_title("convolved")

fig.savefig(tmp_path / "convcube.png", dpi=300)
fig.savefig(tmp_path / "convcube.png")

plt.close("all")

Expand All @@ -151,10 +147,8 @@ def test_multi_chan_conv(coords, tmp_path):
# and make sure that the HannConvCube works across channels

nchan = 10
mean = torch.full(
(nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full(
(nchan, coords.npix, coords.npix), fill_value=0.5)
mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)

# tensor
test_cube = torch.normal(mean=mean, std=std)
Expand All @@ -180,16 +174,15 @@ def test_plot_test_img(packed_cube, coords, tmp_path):

# put back to sky
sky_cube = utils.packed_cube_to_sky_cube(packed_cube)
im = ax.imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax.imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
plt.colorbar(im)
fig.savefig(tmp_path / "sky_cube.png", dpi=300)
fig.savefig(tmp_path / "sky_cube.png")

plt.close("all")


def test_taper(coords, tmp_path):
for r in np.arange(0.0, 0.2, step=0.02):
for r in np.arange(0.0, 0.2, step=0.04):
fig, ax = plt.subplots(ncols=1)

taper_2D = images.uv_gaussian_taper(coords, r, r, 0.0)
Expand All @@ -205,37 +198,41 @@ def test_taper(coords, tmp_path):
)
plt.colorbar(im, ax=ax)

fig.savefig(tmp_path / f"taper{r:.2f}.png", dpi=300)
fig.savefig(tmp_path / f"taper{r:.2f}.png")

plt.close("all")


def test_gaussian_kernel(coords, tmp_path):
rs = np.array([0.02, 0.06, 0.10])
nchan = 3
fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10,10))
for i,r in enumerate(rs):
fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10, 10))
for i, r in enumerate(rs):
layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r)
weight = layer.m.weight.detach().numpy()
for j in range(nchan):
im = ax[i,j].imshow(weight[j,0], interpolation="none", origin="lower")
plt.colorbar(im, ax=ax[i,j])
im = ax[i, j].imshow(weight[j, 0], interpolation="none", origin="lower")
plt.colorbar(im, ax=ax[i, j])

fig.savefig(tmp_path / "filter.png", dpi=300)
fig.savefig(tmp_path / "filter.png")
plt.close("all")


def test_gaussian_kernel_rotate(coords, tmp_path):
r = 0.04
Omegas = [0, 20, 40] # degrees
Omegas = [0, 20, 40] # degrees
nchan = 3
fig, ax = plt.subplots(nrows=len(Omegas), ncols=nchan, figsize=(10, 10))
for i, Omega in enumerate(Omegas):
layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r, Omega=Omega)
layer = images.GaussConvImage(
coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r, Omega=Omega
)
weight = layer.m.weight.detach().numpy()
for j in range(nchan):
im = ax[i, j].imshow(weight[j, 0], interpolation="none",origin="lower")
im = ax[i, j].imshow(weight[j, 0], interpolation="none", origin="lower")
plt.colorbar(im, ax=ax[i, j])

fig.savefig(tmp_path / "filter.png", dpi=300)
fig.savefig(tmp_path / "filter.png")
plt.close("all")


Expand All @@ -245,71 +242,64 @@ def test_GaussConvImage(sky_cube, coords, tmp_path):
nchan = sky_cube.size()[0]

for r in np.arange(0.02, 0.11, step=0.04):

layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r)

print("Kernel size", layer.m.weight.size())

fig, ax = plt.subplots(ncols=2)

im = ax[0].imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)

im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])

c_sky = layer(sky_cube)
im = ax[1].imshow(
c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / f"convolved_{r:.2f}.png", dpi=300)
fig.savefig(tmp_path / f"convolved_{r:.2f}.png")

plt.close("all")


def test_GaussConvImage_rotate(sky_cube, coords, tmp_path):
# show only the first channel
chan = 0
nchan = sky_cube.size()[0]

for Omega in [0, 30]:
layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=0.10, FWHM_min=0.05, Omega=Omega)
layer = images.GaussConvImage(
coords, nchan=nchan, FWHM_maj=0.10, FWHM_min=0.05, Omega=Omega
)

fig, ax = plt.subplots(ncols=2)

im = ax[0].imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])

c_sky = layer(sky_cube)
im = ax[1].imshow(
c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / f"convolved_{Omega:.0f}_deg.png", dpi=300)
fig.savefig(tmp_path / f"convolved_{Omega:.0f}_deg.png")

plt.close("all")


def test_GaussFourier(packed_cube, coords, tmp_path):
chan = 0

for FWHM in np.linspace(0.02, 0.5, num=10):
fig, ax = plt.subplots(ncols=2)
# put back to sky
sky_cube = utils.packed_cube_to_sky_cube(packed_cube)
im = ax[0].imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])
Expand All @@ -329,38 +319,35 @@ def test_GaussFourier(packed_cube, coords, tmp_path):
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / "convolved_FWHM_{:.2f}.png".format(FWHM), dpi=300)
fig.savefig(tmp_path / "convolved_FWHM_{:.2f}.png".format(FWHM))

plt.close("all")


def test_GaussFourier_rotate(packed_cube, coords, tmp_path):
chan = 0

sky_cube = utils.packed_cube_to_sky_cube(packed_cube)

for Omega in [0, 20, 40]:
for Omega in [0, 30]:
layer = images.GaussConvFourier(
coords, FWHM_maj=0.16, FWHM_min=0.06, Omega=Omega
coords, FWHM_maj=0.10, FWHM_min=0.05, Omega=Omega
)

fig, ax = plt.subplots(ncols=2)

im = ax[0].imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[0].imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])

c_sky = layer(sky_cube)
im = ax[1].imshow(
c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[1].imshow(c_sky[chan], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / f"convolved_{Omega:.2f}.png", dpi=300)
fig.savefig(tmp_path / f"convolved_{Omega:.0f}_deg.png")

plt.close("all")

Expand All @@ -370,14 +357,12 @@ def test_GaussFourier_point(coords, tmp_path):

# create an image with a point source in the center
sky_cube = torch.zeros((1, coords.npix, coords.npix))
cpix = coords.npix//2
sky_cube[0,cpix,cpix] = 1.0
cpix = coords.npix // 2
sky_cube[0, cpix, cpix] = 1.0

fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True)
# put back to sky
im = ax[0].imshow(
sky_cube[0], extent=coords.img_ext, origin="lower", cmap="inferno"
)
im = ax[0].imshow(sky_cube[0], extent=coords.img_ext, origin="lower")
flux = coords.cell_size**2 * torch.sum(sky_cube[0])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])
Expand All @@ -401,6 +386,6 @@ def test_GaussFourier_point(coords, tmp_path):
ax[1].set_ylim(-r, r)

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / "point_source_FWHM_{:.2f}.png".format(FWHM), dpi=300)
fig.savefig(tmp_path / "point_source_FWHM_{:.2f}.png".format(FWHM))

plt.close("all")

0 comments on commit 8bf8a86

Please sign in to comment.