diff --git a/src/mpol/gridding.py b/src/mpol/gridding.py index ae75fb15..facb234b 100644 --- a/src/mpol/gridding.py +++ b/src/mpol/gridding.py @@ -14,7 +14,7 @@ from .datasets import GriddedDataset -def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=None): +def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=None, freq=None): """ Check that all data inputs are the same shape, the weights are positive, and the data_re and data_im are floats. @@ -34,6 +34,9 @@ def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=N "All dataset inputs must be the same input shape and size." ) + if freq is not None: # TODO:change to wrongdimensionerror + assert len(uu) == len(freq), "uu must have same number of channels as freq array." + if np.any(weight <= 0.0): raise ValueError("Not all thermal weights are positive, check inputs.") @@ -53,8 +56,7 @@ def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=N # check to see that uu, vv and data do not contain Hermitian pairs verify_no_hermitian_pairs(uu, vv, data_re + 1.0j * data_im) - return uu, vv, weight, data_re, data_im - + return uu, vv, weight, data_re, data_im, freq def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0): r""" @@ -144,6 +146,33 @@ def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0): return False +def _check_freq_1d(freq=None): + """ + Check that the frequency input array contains only positive floats. + + If the user supplied a float, convert to a 1D array. If no frequency array + was supplied, simply skip. + + """ + if freq is None: + return freq + + assert ( + np.isscalar(freq) or freq.ndim == 1 + ), "Input data vectors should be either None, scalar, or 1D array." + + assert np.all(freq > 0.0), "Not all frequencies are positive, check inputs." + + if np.isscalar(freq): + freq = np.atleast_1d(freq) + + assert (freq.dtype == np.single) or ( + freq.dtype == np.double + ), "freq should be type single or double" + + return freq + + class GridderBase: r""" This class is not designed to be used directly, but rather to be subclassed. @@ -172,13 +201,18 @@ def __init__( weight=None, data_re=None, data_im=None, + chan_freq=None, ): + + # check frequency array is 1d or None, expand if not + chan_freq = _check_freq_1d(chan_freq) + # check everything should be 2d, expand if not # also checks data does not contain Hermitian pairs - uu, vv, weight, data_re, data_im = _check_data_inputs_2d( - uu, vv, weight, data_re, data_im + uu, vv, weight, data_re, data_im, chan_freq = _check_data_inputs_2d( + uu, vv, weight, data_re, data_im, chan_freq ) - + # setup the coordinates object self.coords = coords self.nchan = len(uu) @@ -193,6 +227,7 @@ def __init__( self.weight = weight self.data_re = data_re self.data_im = data_im + self.chan_freq = chan_freq # and register cell indices against data self._create_cell_indices() @@ -211,6 +246,7 @@ def from_image_properties( coords = GridCoords(cell_size, npix) return cls(coords, uu, vv, weight, data_re, data_im) + def _create_cell_indices(self): # figure out which visibility cell each datapoint lands in, so that # we can later assign it the appropriate robust weight for that cell @@ -586,7 +622,7 @@ def __init__( ): # check everything should be 2d, expand if not # also checks data does not contain Hermitian pairs - uu, vv, weight, data_re, data_im = _check_data_inputs_2d( + uu, vv, weight, data_re, data_im, freq = _check_data_inputs_2d( uu, vv, weight, data_re, data_im ) diff --git a/src/mpol/images.py b/src/mpol/images.py index 0185766e..509c518b 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np +from scipy.special import j1 import torch import torch.fft # to avoid conflicts with old torch.fft *function* from torch import nn @@ -11,6 +12,7 @@ from .coordinates import GridCoords + class BaseCube(nn.Module): r""" A base cube of the same dimensions as the image cube. Designed to use a pixel mapping function :math:`f_\mathrm{map}` from the base cube values to the ImageCube domain. diff --git a/src/mpol/precomposed.py b/src/mpol/precomposed.py index 149fae73..f07af11e 100644 --- a/src/mpol/precomposed.py +++ b/src/mpol/precomposed.py @@ -2,7 +2,7 @@ from mpol.coordinates import GridCoords -from . import fourier, images +from . import fourier, images, primary_beam class SimpleNet(torch.nn.Module): @@ -35,7 +35,12 @@ def __init__( coords=None, nchan=1, base_cube=None, + chan_freqs=None, + dish_type=None, + dish_radius=None, + **dish_kwargs, ): + super().__init__() self.coords = coords @@ -50,12 +55,23 @@ def __init__( self.icube = images.ImageCube( coords=self.coords, nchan=self.nchan, passthrough=True ) + + self.pbcube = primary_beam.PrimaryBeamCube( + coords = self.coords, + nchan=self.nchan, + chan_freqs=chan_freqs, + dish_type=dish_type, + dish_radius=dish_radius, + **dish_kwargs + ) self.fcube = fourier.FourierCube(coords=self.coords) + @classmethod def from_image_properties(cls, cell_size, npix, nchan, base_cube): coords = GridCoords(cell_size, npix) return cls(coords, nchan, base_cube) + def forward(self): r""" @@ -66,5 +82,7 @@ def forward(self): x = self.bcube() x = self.conv_layer(x) x = self.icube(x) + x = self.pbcube(x) vis = self.fcube(x) + return vis diff --git a/src/mpol/primary_beam.py b/src/mpol/primary_beam.py new file mode 100644 index 00000000..255c2916 --- /dev/null +++ b/src/mpol/primary_beam.py @@ -0,0 +1,184 @@ +r"""The ``primary_beam`` module provides the core functionality of MPoL via :class:`mpol.fourier.PrimaryBeamCube`.""" + +from __future__ import annotations + +import numpy as np +import torch +import torch.fft # to avoid conflicts with old torch.fft *function* +import torchkbnufft +from torch import nn + +from . import utils +from .coordinates import GridCoords + +from .gridding import _check_freq_1d + +class PrimaryBeamCube(nn.Module): + r""" + A ImageCube representing the primary beam of a described dish type. Currently can correct for a + uniform or center-obscured dish. The forward() method multiplies an image cube by this primary beam mask. + + Args: + cell_size (float): the width of a pixel [arcseconds] + npix (int): the number of pixels per image side + coords (GridCoords): an object already instantiated from the GridCoords class. If providing this, cannot provide ``cell_size`` or ``npix``. + nchan (int): the number of channels in the image + dish_type (string): the type of dish to correct for. Either 'uniform' or 'obscured'. + dish_radius (float): the radius of the dish (in meters) + dish_kwargs (dict): any additional arguments needed for special dish types. Currently only uses: + dish_obscured_radius (float): the radius of the obscured portion of the dish + """ + def __init__( + self, + coords, + nchan=1, + chan_freqs=None, + dish_type=None, + dish_radius=None, + **dish_kwargs, + ): + super().__init__() + + #_setup_coords(self, cell_size, npix, coords, nchan) TODO: update this + + _check_freq_1d(chan_freqs) + assert (chan_freqs is None) or (len(chan_freqs) == nchan), "Length of chan_freqs must be equal to nchan" + + assert (dish_type is None) or (dish_type in ["uniform", "obscured"]), "Provided dish_type must be 'uniform' or 'obscured'" + + self.coords = coords + self.nchan = nchan + + self.default_mask = nn.Parameter( + torch.full( + (self.nchan, self.coords.npix, self.coords.npix), + fill_value=1.0, + requires_grad=False, + dtype=torch.double, + ) + ) + + if dish_type is None: + self.pb_mask = self.default_mask + elif dish_type == "uniform": + self.pb_mask = self.uniform_mask(chan_freqs, dish_radius) + elif dish_type == "obscured": + self.pb_mask = self.obscured_mask(chan_freqs, dish_radius, **dish_kwargs) + + @classmethod + def from_image_properties( + cls, cell_size, npix, nchan=1, + chan_freqs=None, dish_type=None, + dish_radius=None, **dish_kwargs + ) -> ImageCube: + coords = GridCoords(cell_size, npix) + return cls(coords, nchan, chan_freqs, dish_type, dish_radius, **dish_kwargs) + + def forward(self, cube): + r"""Args: + cube (torch.double tensor, of shape ``(nchan, npix, npix)``): a prepacked image cube, for example, from ImageCube.forward() + + Returns: + (torch.complex tensor, of shape ``(nchan, npix, npix)``): the FFT of the image cube, in packed format. + """ + return torch.mul(self.pb_mask, cube) + + + def uniform_mask(self, chan_freqs, dish_radius): + r""" + Generates airy disk primary beam correction mask. + """ + assert dish_radius > 0., "Dish radius must be positive" + ratio = 2. * dish_radius * np.array([[chan_freqs]]).T / 2.998e8 + + ratio_cube = np.tile(ratio,(1,self.coords.npix,self.coords.npix)) + r_2D = np.sqrt(self.coords.packed_x_centers_2D**2 + self.coords.packed_y_centers_2D**2) # arcsec + r_2D_rads = r_2D * np.pi / 180. / 60. / 60. # radians + r_cube = np.tile(r_2D_rads,(self.nchan,1,1)) + + r_normed_cube = np.pi * r_cube * ratio_cube + + mask = np.where(r_normed_cube > 0., + (2. * j1(r_normed_cube) / r_normed_cube)**2, + 1.) + return torch.tensor(mask) + + + def obscured_mask(self, chan_freqs, dish_radius, dish_obscured_radius=None, **extra_kwargs): + r""" + Generates airy disk primary beam correction mask. + """ + assert dish_obscured_radius is not None, "Obscured dish requires kwarg 'dish_obscured_radius'" + assert dish_radius > 0., "Dish radius must be positive" + assert dish_obscured_radius > 0., "Obscured dish radius must be positive" + assert dish_radius > dish_obscured_radius, "Primary dish radius must be greater than obscured radius" + + ratio = 2. * dish_radius * np.array([[chan_freqs]]).T / 2.998e8 + ratio_cube = np.tile(ratio,(1,self.coords.npix,self.coords.npix)) + r_2D = np.sqrt(self.coords.packed_x_centers_2D**2 + self.coords.packed_y_centers_2D**2) # arcsec + r_2D_rads = r_2D * np.pi / 180. / 60. / 60. # radians + r_cube = np.tile(r_2D_rads,(self.nchan,1,1)) + + eps = dish_obscured_radius / dish_radius + r_normed_cube = np.pi * r_cube * ratio_cube + + norm_factor = (1.-eps**2)**2 + mask = np.where(r_normed_cube > 0., + (j1(r_normed_cube) / r_normed_cube + - eps*j1(eps*r_normed_cube) / r_normed_cube)**2 / norm_factor, + 1.) + return torch.tensor(mask) + + @property + def sky_cube(self): + """ + The primary beam mask arranged as it would appear on the sky. + + Returns: + torch.double : 3D image cube of shape ``(nchan, npix, npix)`` + + """ + return utils.packed_cube_to_sky_cube(self.pb_mask) + + def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None): + """ + Export the primary beam cube to a FITS file. + + Args: + fname (str): the name of the FITS file to export to. + overwrite (bool): if the file already exists, overwrite? + header_kwargs (dict): Extra keyword arguments to write to the FITS header. + + Returns: + None + """ + + try: + from astropy import wcs + from astropy.io import fits + except ImportError: + print( + "Please install the astropy package to use FITS export functionality." + ) + + w = wcs.WCS(naxis=2) + + w.wcs.crpix = np.array([1, 1]) + w.wcs.cdelt = ( + np.array([self.coords.cell_size, self.coords.cell_size]) / 3600 + ) # decimal degrees + w.wcs.ctype = ["RA---TAN", "DEC--TAN"] + + header = w.to_header() + + # add in the kwargs to the header + if header_kwargs is not None: + for k, v in header_kwargs.items(): + header[k] = v + + hdu = fits.PrimaryHDU(self.pb_mask.detach().cpu().numpy(), header=header) + + hdul = fits.HDUList([hdu]) + hdul.writeto(fname, overwrite=overwrite) + + hdul.close() \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 2686b89c..51823962 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pytest from astropy.utils.data import download_file +import torch from mpol import coordinates, gridding @@ -53,6 +54,14 @@ def coords(): return coordinates.GridCoords(cell_size=0.005, npix=800) +@pytest.fixture +def unit_cube(coords): + nchan = 4 + input_cube = torch.full( + (nchan, coords.npix, coords.npix), fill_value=1.0, dtype=torch.double + ) + return input_cube + @pytest.fixture def averager(mock_visibility_data, coords): uu, vv, weight, data_re, data_im = mock_visibility_data diff --git a/test/primary_beam_test.py b/test/primary_beam_test.py new file mode 100644 index 00000000..dcea5a22 --- /dev/null +++ b/test/primary_beam_test.py @@ -0,0 +1,16 @@ +import matplotlib.pyplot as plt +import torch +from pytest import approx + +from mpol import primary_beam, images, utils +from mpol.constants import * + +def test_no_dish_correction(coords, unit_cube): + # Tests layer when no PB correction is applied (passthrough layer) + pblayer = primary_beam.PrimaryBeamCube(coords=coords) + out_cube = pblayer(unit_cube) + + assert torch.equal(unit_cube, out_cube) + + + \ No newline at end of file