Skip to content

Commit

Permalink
consolidated Gauss routines.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Mar 6, 2024
1 parent 1b3271a commit 4c7b453
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 295 deletions.
334 changes: 97 additions & 237 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from collections.abc import Callable

import numpy as np
from typing import Any
import numpy.typing as npt

import torch
import torch.fft # to avoid conflicts with old torch.fft *function*
from torch import nn
Expand Down Expand Up @@ -188,176 +191,13 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor:
return utils.sky_cube_to_packed_cube(conv_sky_cube)


class GaussBaseBeam(nn.Module):
class GaussConvImage(nn.Module):
r"""
This layer will convolve the base cube with a Gaussian beam of variable resolution.
The FWHM of the beam (in arcsec) is a trainable parameter of the layer.
Parameters
----------
coords : :class:`mpol.coordinates.GridCoords`
an object instantiated from the GridCoords class, containing information about
the image `cell_size` and `npix`.
FWHM: float, units of arcsec
the FWHH of the Gaussian
"""

def __init__(self, coords: GridCoords, FWHM: float) -> None:
super().__init__()

self.coords = coords
self.FWHM = FWHM

# convert FWHM to sigma and to radians
FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
sigma = self.FWHM * FWHM2sigma * constants.arcsec # radians

# calculate the UV taper from the FWHM size.
u = self.coords.packed_u_centers_2D
v = self.coords.packed_v_centers_2D

taper_2D = np.exp(-2 * np.pi**2 * (sigma**2 * u**2 + sigma**2 * v**2))

# store taper to register so it transfers to GPU
self.register_buffer("taper_2D", torch.tensor(taper_2D, dtype=torch.float32))

def forward(self, packed_cube):
r"""
Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried out
in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
Returns
-------
:class:`torch.Tensor`
The convolved cube in packed format.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (
(npix_m == self.coords.npix) and (npix_l == self.coords.npix)
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), self.coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

# apply taper to packed image
tapered_vis = vis_like * torch.broadcast_to(self.taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
torch.max(convolved_packed_cube.imag), thresh
)

r_cube: torch.Tensor = convolved_packed_cube.real
return r_cube


class GaussBaseBeamTunable(nn.Module):
r"""
This layer will convolve the base cube with a Gaussian beam of variable resolution.
The FWHM of the beam (in arcsec) is a trainable parameter of the layer.
Parameters
----------
coords : :class:`mpol.coordinates.GridCoords`
an object instantiated from the GridCoords class, containing information about
the image `cell_size` and `npix`.
"""

def __init__(self, coords: GridCoords) -> None:
super().__init__()

self.coords = coords

self._FWHM_base = nn.Parameter(torch.tensor([-3.0]))
self.softplus = nn.Softplus()
# -3.0 corresponds to about 0.05 arcsec

# store coordinates to register so they transfer to GPU
self.register_buffer(
"u", torch.tensor(self.coords.packed_u_centers_2D, dtype=torch.float32)
)
self.register_buffer(
"v", torch.tensor(self.coords.packed_v_centers_2D, dtype=torch.float32)
)

@property
def FWHM(self):
r"""Map from base parameter to actual FWHM."""
return self.softplus(self._FWHM_base) # ensures always positive

def forward(self, packed_cube):
r"""
Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried out
in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
Returns
-------
:class:`torch.Tensor`
The convolved cube in packed format.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (
(npix_m == self.coords.npix) and (npix_l == self.coords.npix)
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), self.coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

# convert FWHM to sigma and to radians
FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
sigma = self.FWHM * FWHM2sigma * constants.arcsec # radians

# calculate the UV taper from the FWHM size.
taper_2D = torch.exp(
-2 * np.pi**2 * (sigma**2 * self.u**2 + sigma**2 * self.v**2)
)

# apply taper to packed image
tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
torch.max(convolved_packed_cube.imag), thresh
)

r_cube: torch.Tensor = convolved_packed_cube.real
return r_cube


class GaussConvCube(nn.Module):
r"""
Once instantiated, this convolutional layer is used to convolve the input cube with
a 2D Gaussian filter. The filter is the same for all channels in the input cube.
This convolutional layer will convolve the input cube with a 2D Gaussian kernel.
The filter is the same for all channels in the input cube.
Because the operation is carried out in the image domain, note that it may become
computationally prohibitive for large kernel sizes. In that case,
:class:`mpol.images.GaussConvFourier` may be preferred.
Parameters
----------
Expand All @@ -373,7 +213,8 @@ class GaussConvCube(nn.Module):
Omega: float, degrees
the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the North-South direction.
requires_grad : bool
keep kernel fixed
Should the kernel parameters be trainable? Most applications will want to use
`False`.
"""

def __init__(
Expand Down Expand Up @@ -488,6 +329,87 @@ def forward(self, sky_cube: torch.Tensor) -> torch.Tensor:
convolved_sky = self.m(sky_cube)
return convolved_sky

class GaussConvFourier(nn.Module):
r"""
This layer will convolve the input cube with a (potentially non-circular) Gaussian
beam, using a Fourier strategy.
The size of the beam is set upon initialization of the layer.
Parameters
----------
coords : :class:`mpol.coordinates.GridCoords`
an object instantiated from the GridCoords class, containing information about
the image `cell_size` and `npix`.
FWHM_maj: float, units of arcsec
the FWHH of the Gaussian along the major axis
FWHM_min: float, units of arcsec
the FWHM of the Gaussian along the minor axis
Omega: float, degrees
the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the North-South direction.
"""

def __init__(
self,
coords: GridCoords,
FWHM_maj: float,
FWHM_min: float,
Omega: float = 0) -> None:
super().__init__()

self.coords = coords
self.FWHM_maj = FWHM_maj
self.FWHM_min = FWHM_min
self.Omega = Omega

taper_2D = uv_gaussian_taper(self.coords, self.FWHM_maj, self.FWHM_min, self.Omega)

# store taper to register so it transfers to GPU
self.register_buffer("taper_2D", torch.tensor(taper_2D, dtype=torch.float32))

def forward(self, packed_cube):
r"""
Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried out
in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
Returns
-------
:class:`torch.Tensor`
The convolved cube in packed format.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (
(npix_m == self.coords.npix) and (npix_l == self.coords.npix)
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), self.coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

# apply taper to packed image
tapered_vis = vis_like * torch.broadcast_to(self.taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
torch.max(convolved_packed_cube.imag), thresh
)

r_cube: torch.Tensor = convolved_packed_cube.real
return r_cube


class ImageCube(nn.Module):
r"""
Expand Down Expand Up @@ -616,7 +538,7 @@ def to_FITS(

def uv_gaussian_taper(
coords: GridCoords, FWHM_maj: float, FWHM_min: float, Omega: float
) -> torch.Tensor:
) -> npt.NDArray[np.floating[Any]]:
r"""
Compute a packed Gaussian taper in the Fourier domain, to multiply against a packed
visibility cube. While similar to :meth:`mpol.utils.fourier_gaussian_lambda_arcsec`,
Expand All @@ -636,7 +558,7 @@ def uv_gaussian_taper(
Returns
-------
:class:`torch.Tensor` , shape ``(npix, npix)``
:class:`np.ndarray` , shape ``(npix, npix)``
The Gaussian taper in packed format.
"""

Expand All @@ -654,71 +576,9 @@ def uv_gaussian_taper(
vp = u * np.sin(Omega_d) + v * np.cos(Omega_d)

# calculate the Fourier Gaussian
taper_2D: npt.NDArray[np.floating[Any]]
taper_2D = np.exp(-2 * np.pi**2 * (sigma_l**2 * up**2 + sigma_m**2 * vp**2))

# # the fourier_gaussian_lambda_arcsec routine assumes the amplitude
# # is 1.0 *in the image plane*. This is not the same as having an
# # amplitude 1.0 in the visibility plane, which is a requirement of a
# # flux-conserving taper. So we renormalize.
# taper_2D /= np.max(np.abs(taper_2D))

return torch.from_numpy(taper_2D)


def convolve_packed_cube(
packed_cube: torch.Tensor,
coords: GridCoords,
FWHM_maj: float,
FWHM_min: float,
Omega: float = 0,
) -> torch.Tensor:
r"""
Convolve an image cube with a 2D Gaussian PSF. Operation is carried out in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
coords: :class:`mpol.coordinates.GridCoords`
object indicating image and Fourier grid specifications.
FWHM_maj: float, units of arcsec
the FWHH of the Gaussian along the major axis
FWHM_min: float, units of arcsec
the FWHM of the Gaussian along the minor axis
Omega: float, degrees
the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the East-West direction.
# a flux-conserving taper must have an amplitude of 1 at the origin.

Returns
-------
:class:`torch.Tensor`
The convolved cube in packed format.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (
(npix_m == coords.npix) and (npix_l == coords.npix)
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

taper_2D = uv_gaussian_taper(coords, FWHM_maj, FWHM_min, Omega)
# calculate taper on packed image
tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
torch.max(convolved_packed_cube.imag), thresh
)

r_cube: torch.Tensor = convolved_packed_cube.real
return r_cube
return taper_2D
Loading

0 comments on commit 4c7b453

Please sign in to comment.