Skip to content

Commit

Permalink
added plot_utils and imshow_two routine.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Dec 2, 2024
1 parent 8bf8a86 commit 4041caa
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 19 deletions.
24 changes: 5 additions & 19 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
import torch
from astropy.io import fits
from mpol import coordinates, images, plot, utils
from plot_utils import imshow_two


def test_single_chan():
def test_instantiate_single_chan():
coords = coordinates.GridCoords(cell_size=0.015, npix=800)
im = images.ImageCube(coords=coords)
assert im.nchan == 1


def test_basecube_grad():
def test_basecube_apply_grad():
coords = coordinates.GridCoords(cell_size=0.015, npix=800)
bcube = images.BaseCube(coords=coords)
loss = torch.sum(bcube())
loss.backward()


def test_imagecube_grad(coords):
def test_imagecube_apply_grad(coords):
bcube = images.BaseCube(coords=coords)
# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords)
Expand Down Expand Up @@ -67,21 +67,7 @@ def test_basecube_imagecube(coords, tmp_path):
# the default softplus function should map everything to positive values
output = basecube()

fig, ax = plt.subplots(ncols=2, nrows=1)

im = ax[0].imshow(
np.squeeze(base_cube.detach().numpy()), origin="lower", interpolation="none"
)
plt.colorbar(im, ax=ax[0])
ax[0].set_title("input")

im = ax[1].imshow(
np.squeeze(output.detach().numpy()), origin="lower", interpolation="none"
)
plt.colorbar(im, ax=ax[1])
ax[1].set_title("mapped")

fig.savefig(tmp_path / "basecube_mapped.png")
imshow_two(tmp_path / "basecube_mapped.png", [base_cube, output], titles=["input", "mapped"])

# try passing through ImageLayer
imagecube = images.ImageCube(coords=coords, nchan=nchan)
Expand Down
76 changes: 76 additions & 0 deletions test/plot_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import matplotlib.pyplot as plt
import torch
import numpy as np


def imshow_two(path, imgs, **kwargs):
"""Plot two images side by side, with scalebars.
imgs is a list
Parameters
----------
path : string
path and filename to save figure
imgs : list
length-2 list of images to plot. Arguments are designed to be very permissive. If the image is a PyTorch tensor, the routine converts it to numpy, and then numpy.squeeze is called.
titles: list
if provided, list of strings corresponding to title for each subplot.
Returns
-------
None
"""

xx = 7.1 # in
rmargin = 0.8
lmargin = 0.8
tmargin = 0.3
bmargin = 0.5
middle_sep = 1.2
ax_width = (xx - rmargin - lmargin - middle_sep) / 2
ax_height = ax_width
cax_width = 0.1
cax_sep = 0.15
cax_height = ax_height
yy = bmargin + ax_height + tmargin

fig = plt.figure(figsize=(xx, yy))

ax = []
cax = []
for i in [0, 1]:
ax.append(
fig.add_axes(
[
(lmargin + i * (ax_width + middle_sep)) / xx,
bmargin / yy,
ax_width / xx,
ax_height / yy,
]
)
)
cax.append(
fig.add_axes(
(
[
(lmargin + (i + 1) * ax_width + i * middle_sep + cax_sep) / xx,
bmargin / yy,
cax_width / xx,
cax_height / yy,
]
)
)
)

img = imgs[i]
img = img.detach().numpy() if torch.is_tensor(img) else img

im = ax[i].imshow(np.squeeze(img), origin="lower", interpolation="none")
plt.colorbar(im, cax=cax[i])

if "titles" in kwargs:
ax[i].set_title(kwargs["titles"][i])


fig.savefig(path)

0 comments on commit 4041caa

Please sign in to comment.