From 115707419b99d5c9e0d51ae32f45a502fa535d07 Mon Sep 17 00:00:00 2001 From: Ian Czekala Date: Wed, 27 Dec 2023 14:47:21 +0000 Subject: [PATCH] removed from_image_properties and updated tests and docs. --- docs/changelog.md | 2 +- docs/ci-tutorials/fakedata.md | 4 +- docs/ci-tutorials/gridder.md | 16 +-- docs/large-tutorials/HD143006_part_1.md | 7 +- src/mpol/datasets.py | 67 ---------- src/mpol/fourier.py | 66 ---------- src/mpol/gridding.py | 162 ++++++++++++------------ src/mpol/images.py | 89 ++++++------- src/mpol/precomposed.py | 5 - test/gridder_dataset_export_test.py | 6 +- test/gridder_gridding_test.py | 28 ++-- test/gridder_imager_test.py | 6 +- test/gridder_init_test.py | 6 +- test/images_test.py | 32 +---- 14 files changed, 166 insertions(+), 330 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 72cabc6b..f4229f85 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,7 +5,7 @@ ## v0.2.1 - *Placeholder* Planned changes described by Architecture GitHub Project. -- Removed convenience classmethods `from_image_properties` from across the code base. From [#233](https://github.com/MPoL-dev/MPoL/issues/233). The recommended workflow is to create a :class:`mpol.coordinates.GridCoords` object and pass that to instantiate these objects as needed. For nearly all but trivially short workflows, this simplifies the number of variables the user needs to keep track and pass around revealing the central role of the :class:`mpol.coordinates.GridCoords` object and its useful attributes for image extent, visibility extent, etc. Most importantly, this significantly reduces the size of the codebase and the burden to maintain and document multiple entry points to key `nn.modules`. We removed `from_image_properties` from +- Removed convenience classmethods `from_image_properties` from across the code base. From [#233](https://github.com/MPoL-dev/MPoL/issues/233). The recommended workflow is to create a :class:`mpol.coordinates.GridCoords` object and pass that to instantiate these objects as needed, rather than passing `cell_size` and `npix` separately. For nearly all but trivially short workflows, this simplifies the number of variables the user needs to keep track and pass around revealing the central role of the :class:`mpol.coordinates.GridCoords` object and its useful attributes for image extent, visibility extent, etc. Most importantly, this significantly reduces the size of the codebase and the burden to maintain, test, and document multiple entry points to key `nn.modules`. We removed `from_image_properties` from - :class:`mpol.datasets.GriddedDataset` - :class:`mpol.datasets.Dartboard` - :class:`mpol.fourier.NuFFT` diff --git a/docs/ci-tutorials/fakedata.md b/docs/ci-tutorials/fakedata.md index d35eb887..26d60ea7 100644 --- a/docs/ci-tutorials/fakedata.md +++ b/docs/ci-tutorials/fakedata.md @@ -253,7 +253,9 @@ img_tensor_packed = utils.sky_cube_to_packed_cube(img_tensor) ```{code-cell} ipython3 from mpol.images import ImageCube -image = ImageCube.from_image_properties(cell_size=cell_size, npix=npix, nchan=1, cube=img_tensor_packed) +from mpol import coordinates +coords = coordinates.GridCoords(cell_size=cell_size, npix=npix) +image = ImageCube(coords=coords, nchan=1, cube=img_tensor_packed) ``` If you want to double-check that the image was correctly inserted, you can do diff --git a/docs/ci-tutorials/gridder.md b/docs/ci-tutorials/gridder.md index e8624d02..b4ba4353 100644 --- a/docs/ci-tutorials/gridder.md +++ b/docs/ci-tutorials/gridder.md @@ -155,21 +155,7 @@ imager = gridding.DirtyImager( ) ``` -Instantiating the {class}`~mpol.gridding.DirtyImager` object attaches the {class}`~mpol.coordinates.GridCoords` object and the loose visibilities. There is also a convenience method to create the {class}`~mpol.coordinates.GridCoords` and {class}`~mpol.gridding.DirtyImager` object in one shot by - -```{code-cell} -imager = gridding.DirtyImager.from_image_properties( - cell_size=0.005, # [arcsec] - npix=800, - uu=uu, - vv=vv, - weight=weight, - data_re=data_re, - data_im=data_im, -) -``` - -if you don't want to specify your {class}`~mpol.coordinates.GridCoords` object separately. +Instantiating the {class}`~mpol.gridding.DirtyImager` object attaches the {class}`~mpol.coordinates.GridCoords` object and the loose visibilities. As we saw, the raw visibility dataset is a set of complex-valued Fourier samples. Our objective is to make images of the sky-brightness distribution and do astrophysics. We'll cover how to do this with MPoL and RML techniques in later tutorials, but it is possible to get a rough idea of the sky brightness by calculating the inverse Fourier transform of the visibility values. diff --git a/docs/large-tutorials/HD143006_part_1.md b/docs/large-tutorials/HD143006_part_1.md index debf213e..3924a0fd 100644 --- a/docs/large-tutorials/HD143006_part_1.md +++ b/docs/large-tutorials/HD143006_part_1.md @@ -150,11 +150,10 @@ The FITS image was a full 3000x3000 pixels. In general, it is good practice to s Since the DSHARP team has already checked there are no bright sub-mm sources in the FOV, we can save time and just make a smaller image corresponding to the protoplanetary emission. If `cell_size` is 0.003 arcseconds, `npix=512` pixels should be sufficient to make an image approximately 1.5 arcseconds on a side. Now, let's import the relevant MPoL routines and instantiate the Gridder. ```{code-cell} -from mpol import gridding +from mpol import coordinates, gridding -imager = gridding.DirtyImager.from_image_properties( - cell_size=cell_size, - npix=512, +coords = coordinates.GridCoords(cell_size=cell_size, npix=512) +imager = gridding.DirtyImager( uu=uu, vv=vv, weight=weight, diff --git a/src/mpol/datasets.py b/src/mpol/datasets.py index ec0bb623..29c90afb 100644 --- a/src/mpol/datasets.py +++ b/src/mpol/datasets.py @@ -73,46 +73,6 @@ def __init__( self.vis_indexed: torch.Tensor self.weight_indexed: torch.Tensor - @classmethod - def from_image_properties( - cls, - cell_size: float, - npix: int, - *, - vis_gridded: torch.Tensor, - weight_gridded: torch.Tensor, - mask: torch.Tensor, - nchan: int = 1, - ) -> GriddedDataset: - """ - Alternative method to instantiate a GriddedDataset object from cell_size - and npix. - - Parameters - ---------- - cell_size : float - the width of a pixel [arcseconds] - npix : int - the number of pixels per image side - vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128` - the gridded visibility data stored in a "packed" format (pre-shifted for fft) - weight_gridded : :class:`torch.Tensor` of :class:`torch.double` - the weights corresponding to the gridded visibility data, - also in a packed format - mask : :class:`torch.Tensor` of :class:`torch.bool` - a boolean mask to index the non-zero locations of ``vis_gridded`` and - ``weight_gridded`` in their packed format. - nchan : int - the number of channels in the image (default = 1). - """ - return cls( - coords=GridCoords(cell_size, npix), - vis_gridded=vis_gridded, - weight_gridded=weight_gridded, - mask=mask, - nchan=nchan, - ) - def add_mask( self, mask: ArrayLike, @@ -247,33 +207,6 @@ def cartesian_phis(self) -> NDArray[floating[Any]]: def q_max(self) -> float: return self.coords.q_max - @classmethod - def from_image_properties( - cls, - cell_size: float, - npix: int, - q_edges: NDArray[floating[Any]] | None = None, - phi_edges: NDArray[floating[Any]] | None = None, - ) -> Dartboard: - """Alternative method to instantiate a Dartboard object from cell_size - and npix. - - Args: - cell_size (float): the width of a pixel [arcseconds] - npix (int): the number of pixels per image side - q_edges (1D numpy array): an array of radial bin edges to set the - dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults - to 12 log-linearly radial bins stretching from 0 to the - :math:`q_\mathrm{max}` represented by ``coords``. - phi_edges (1D numpy array): an array of azimuthal bin edges to set the - dartboard cells in [radians], over the domain :math:`[0, \pi]`, which - is also implicitly mapped to the domain :math:`[-\pi, \pi]` to preserve - the Hermitian nature of the visibilities. If ``None``, defaults to 8 - equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`. - """ - coords = GridCoords(cell_size, npix) - return cls(coords, q_edges, phi_edges) - def get_polar_histogram( self, qs: NDArray[floating[Any]], phis: NDArray[floating[Any]] ) -> NDArray[floating[Any]]: diff --git a/src/mpol/fourier.py b/src/mpol/fourier.py index ade7e89b..706d3c67 100644 --- a/src/mpol/fourier.py +++ b/src/mpol/fourier.py @@ -49,33 +49,6 @@ def __init__(self, coords: GridCoords, persistent_vis: bool = False): self.register_buffer("vis", None, persistent=persistent_vis) self.vis: torch.Tensor - @classmethod - def from_image_properties( - cls, cell_size: float, npix: int, persistent_vis: bool = False - ) -> FourierCube: - """ - Alternative method for instantiating a FourierCube from ``cell_size`` and - ``npix`` - - Parameters - ---------- - cell_size : float - the width of an image-plane pixel [arcseconds] - npix : int) - the number of pixels per image side - persistent_vis : bool - should the visibility cube be stored as part of - the modules `state_dict`? If `True`, the state of the UV grid will be - stored. It is recommended to use `False` for most applications, since - the visibility cube will rarely be a direct parameter of the model. - - Returns - ------- - :class:`mpol.fourier.FourierCube` - """ - coords = GridCoords(cell_size, npix) - return cls(coords, persistent_vis) - def forward(self, cube: torch.Tensor) -> torch.Tensor: """ Perform the FFT of the image cube on each channel. @@ -314,32 +287,6 @@ def __init__( im_size=(self.coords.npix, self.coords.npix) ) - @classmethod - def from_image_properties( - cls, - cell_size: float, - npix: int, - nchan: int = 1, - ): - """ - Instantiate a :class:`mpol.fourier.NuFFT` object from image properties rather - than a :meth:`mpol.coordinates.GridCoords` instance. - - Args: - cell_size (float): the width of an image-plane pixel [arcseconds] - npix (int): the number of pixels per image side - nchan (int): the number of channels in the :class:`mpol.images.ImageCube`. - Default = 1. - - Returns: - an instance of the :class:`mpol.fourier.NuFFT` - """ - coords = GridCoords(cell_size, npix) - return cls( - coords, - nchan, - ) - def _klambda_to_radpix(self, klambda: torch.Tensor) -> torch.Tensor: """Convert a spatial frequency in units of klambda to 'radians/sky pixel,' using the pixel cell_size provided by ``self.coords.dl``. @@ -752,19 +699,6 @@ def __init__( self.real_interp_mat: torch.Tensor self.imag_interp_mat: torch.Tensor - @classmethod - def from_image_properties( - cls, - cell_size, - npix, - uu, - vv, - nchan=1, - sparse_matrices=True, - ): - coords = GridCoords(cell_size, npix) - return cls(coords, uu, vv, nchan, sparse_matrices) - def forward(self, cube): r""" Perform the FFT of the image cube for each channel and interpolate to the diff --git a/src/mpol/gridding.py b/src/mpol/gridding.py index d99c33f5..ea067788 100644 --- a/src/mpol/gridding.py +++ b/src/mpol/gridding.py @@ -6,7 +6,7 @@ import warnings -from typing import Any +from typing import Any, Sequence import numpy as np import numpy.typing as npt @@ -18,12 +18,12 @@ def _check_data_inputs_2d( - uu=npt.NDArray[np.floating[Any]] | None, - vv=npt.NDArray[np.floating[Any]] | None, - weight=npt.NDArray[np.floating[Any]] | None, - data_re=npt.NDArray[np.floating[Any]] | None, - data_im=npt.NDArray[np.floating[Any]] | None, -): + uu=npt.NDArray[np.floating[Any]], + vv=npt.NDArray[np.floating[Any]], + weight=npt.NDArray[np.floating[Any]], + data_re=npt.NDArray[np.floating[Any]], + data_im=npt.NDArray[np.floating[Any]], +) -> tuple[np.ndarray, ...]: """ Check that all data inputs are the same shape, the weights are positive, and the data_re and data_im are floats. @@ -67,7 +67,13 @@ def _check_data_inputs_2d( return uu, vv, weight, data_re, data_im -def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0): +def verify_no_hermitian_pairs( + uu: npt.NDArray[np.floating[Any]], + vv: npt.NDArray[np.floating[Any]], + data: npt.NDArray[np.complexfloating[Any, Any]], + test_vis: int = 5, + test_channel: int = 0, +) -> bool: r""" Check that the dataset does not contain Hermitian pairs. Because the sky brightness :math:`I_\nu` is real, the visibility function :math:`\mathcal{V}` is Hermitian, @@ -153,27 +159,20 @@ def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0): num_pairs += 1 if num_pairs == 0: - return + return True if num_pairs == test_vis: raise DataError( "Hermitian pairs were found in the data. Please provide data without" " Hermitian pairs." ) + return False else: raise DataError( f"{num_pairs} Hermitian pairs were found out of {test_vis} visibilities" " tested, dataset is inconsistent." ) - - # choose a uu, vv point, then see if the opposite value exists in the dataset - # if it does, then check that its visibility value is the complex conjugate - - # we could have a max threshold, i.e., like at least 5 need to exist to say - # the dataset has pairs - - # Subtract - return False + return False class GridderBase: @@ -218,13 +217,13 @@ class GridderBase: def __init__( self, - coords=None, - uu=None, - vv=None, - weight=None, - data_re=None, - data_im=None, - ): + coords=GridCoords, + uu=npt.NDArray[np.floating[Any]], + vv=npt.NDArray[np.floating[Any]], + weight=npt.NDArray[np.floating[Any]], + data_re=npt.NDArray[np.floating[Any]], + data_im=npt.NDArray[np.floating[Any]], + ) -> None: # check everything should be 2d, expand if not # also checks data does not contain Hermitian pairs uu, vv, weight, data_re, data_im = _check_data_inputs_2d( @@ -249,21 +248,7 @@ def __init__( # and register cell indices against data self._create_cell_indices() - @classmethod - def from_image_properties( - cls, - cell_size, - npix, - uu=None, - vv=None, - weight=None, - data_re=None, - data_im=None, - ) -> GridderBase: - coords = GridCoords(cell_size, npix) - return cls(coords, uu, vv, weight, data_re, data_im) - - def _create_cell_indices(self): + def _create_cell_indices(self) -> None: # figure out which visibility cell each datapoint lands in, so that # we can later assign it the appropriate robust weight for that cell # do this by calculating the nearest cell index [0, N] for all samples @@ -275,7 +260,12 @@ def _create_cell_indices(self): [np.digitize(v_chan, self.coords.v_edges) - 1 for v_chan in self.vv] ) - def _sum_cell_values_channel(self, uu, vv, values=None): + def _sum_cell_values_channel( + self, + uu: npt.NDArray[np.floating[Any]], + vv: npt.NDArray[np.floating[Any]], + values: npt.NDArray[np.floating[Any]] | None = None, + ) -> npt.NDArray[np.floating[Any]]: r""" Given a list of loose visibility points :math:`(u,v)` and their corresponding values :math:`x`, partition the points up into 2D :math:`u-v` cells defined by @@ -308,7 +298,7 @@ def _sum_cell_values_channel(self, uu, vv, values=None): cell quantities. """ - result = fast_hist.histogram2d( + result: npt.NDArray[np.floating[Any]] = fast_hist.histogram2d( vv, uu, bins=self.coords.ncell_u, @@ -322,7 +312,9 @@ def _sum_cell_values_channel(self, uu, vv, values=None): # only return the "H" value return result - def _sum_cell_values_cube(self, values=None): + def _sum_cell_values_cube( + self, values: npt.NDArray[np.floating[Any]] | Sequence[None] | None = None + ) -> npt.NDArray[np.floating[Any]]: r""" Perform the :func:`~mpol.gridding.DataAverager.sum_cell_values_channel` routine over all channels of the input visibilities. @@ -353,7 +345,9 @@ def _sum_cell_values_cube(self, values=None): return cube - def _extract_gridded_values_to_loose(self, gridded_quantity): + def _extract_gridded_values_to_loose( + self, gridded_quantity: npt.NDArray[np.floating[Any]] + ) -> npt.NDArray[np.floating[Any]]: r""" Extract the gridded cell quantity corresponding to each of the loose visibilities. @@ -374,7 +368,45 @@ def _extract_gridded_values_to_loose(self, gridded_quantity): ] ) - def _estimate_cell_standard_deviation(self): + def _grid_visibilities(self) -> None: + r""" + Average the loose data visibilities to the Fourier grid. + """ + + # create the cells as edges around the existing points + # note that at this stage, the UV grid is strictly increasing + # when in fact, later on, we'll need to fftshift for the FFT + cell_weight = self._sum_cell_values_cube(self.weight) + + # boolean index for cells that *contain* visibilities + mask = cell_weight > 0.0 + + # calculate the density weights under "uniform" + # the density weights have the same shape as the re, im samples. + # cell_weight is (nchan, ncell_v, ncell_u) + # self.index_v, self.index_u are (nchan, nvis) + # we want density_weights to be (nchan, nvis) + density_weight = 1 / self._extract_gridded_values_to_loose(cell_weight) + + # grid the reals and imaginaries separately + # outputs from _sum_cell_values_cube are *not* pre-packed + data_re_gridded = self._sum_cell_values_cube( + self.data_re * density_weight * self.weight + ) + + data_im_gridded = self._sum_cell_values_cube( + self.data_im * density_weight * self.weight + ) + + # store the pre-packed FFT products for access by outside routines + self.mask = np.fft.fftshift(mask, axes=(1, 2)) + self.data_re_gridded = np.fft.fftshift(data_re_gridded, axes=(1, 2)) + self.data_im_gridded = np.fft.fftshift(data_im_gridded, axes=(1, 2)) + self.vis_gridded = self.data_re_gridded + self.data_im_gridded * 1.0j + + def _estimate_cell_standard_deviation( + self, + ) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.floating[Any]]]: r""" Estimate the `standard deviation `__ of the real and imaginary @@ -462,7 +494,9 @@ def _estimate_cell_standard_deviation(self): return s_re, s_im - def _check_scatter_error(self, max_scatter=1.2): + def _check_scatter_error( + self, max_scatter: float = 1.2 + ) -> dict[str, bool | np.floating[Any]]: """ Checks/compares visibility scatter to a given threshold value ``max_scatter`` and raises an AssertionError if the median scatter across all cells exceeds @@ -568,42 +602,6 @@ class DataAverager(GridderBase): """ - def _grid_visibilities(self): - r""" - Average the loose data visibilities to the Fourier grid. - """ - - # create the cells as edges around the existing points - # note that at this stage, the UV grid is strictly increasing - # when in fact, later on, we'll need to fftshift for the FFT - cell_weight = self._sum_cell_values_cube(self.weight) - - # boolean index for cells that *contain* visibilities - mask = cell_weight > 0.0 - - # calculate the density weights under "uniform" - # the density weights have the same shape as the re, im samples. - # cell_weight is (nchan, ncell_v, ncell_u) - # self.index_v, self.index_u are (nchan, nvis) - # we want density_weights to be (nchan, nvis) - density_weight = 1 / self._extract_gridded_values_to_loose(cell_weight) - - # grid the reals and imaginaries separately - # outputs from _sum_cell_values_cube are *not* pre-packed - data_re_gridded = self._sum_cell_values_cube( - self.data_re * density_weight * self.weight - ) - - data_im_gridded = self._sum_cell_values_cube( - self.data_im * density_weight * self.weight - ) - - # store the pre-packed FFT products for access by outside routines - self.mask = np.fft.fftshift(mask, axes=(1, 2)) - self.data_re_gridded = np.fft.fftshift(data_re_gridded, axes=(1, 2)) - self.data_im_gridded = np.fft.fftshift(data_im_gridded, axes=(1, 2)) - self.vis_gridded = self.data_re_gridded + self.data_im_gridded * 1.0j - def _grid_weights(self): r""" Average the visibility weights to the Fourier grid contained in ``self.coords``, diff --git a/src/mpol/images.py b/src/mpol/images.py index 1b14b319..18133aa6 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -8,7 +8,7 @@ import torch.fft # to avoid conflicts with old torch.fft *function* from torch import nn -from typing import Callable +from typing import Any, Callable from mpol import utils from mpol.coordinates import GridCoords @@ -52,7 +52,7 @@ def __init__( nchan: int = 1, pixel_mapping: Callable[[torch.Tensor], torch.Tensor] | None = None, base_cube: torch.Tensor | None = None, - ): + ) -> None: super().__init__() self.coords = coords @@ -78,10 +78,10 @@ def __init__( self.base_cube = nn.Parameter(base_cube, requires_grad=True) if pixel_mapping is None: - self.pixel_mapping = torch.nn.Softplus() + self.pixel_mapping: Callable[ + [torch.Tensor], torch.Tensor + ] = torch.nn.Softplus() else: - # TODO assert that this is a PyTorch function (and not a numpy function, - # for example) self.pixel_mapping = pixel_mapping def forward(self) -> torch.Tensor: @@ -127,7 +127,7 @@ class HannConvCube(nn.Module): requires_grad (bool): keep kernel fixed """ - def __init__(self, nchan, requires_grad=False): + def __init__(self, nchan: int, requires_grad: bool = False) -> None: super().__init__() # simple convolutional filter operates on per-channel basis # 3x3 Hann filter @@ -159,7 +159,7 @@ def __init__(self, nchan, requires_grad=False): torch.zeros(nchan, dtype=torch.double), requires_grad=requires_grad ) - def forward(self, cube): + def forward(self, cube: torch.Tensor) -> torch.Tensor: r"""Args: cube (torch.double tensor, of shape ``(nchan, npix, npix)``): a prepacked image cube, for example, from ImageCube.forward() @@ -201,31 +201,33 @@ class ImageCube(nn.Module): since no transformations are applied to the ``cube`` tensor. The main purpose of the ImageCube layer is to provide useful functionality around the ``cube`` tensor, such as returning a sky_cube representation and providing FITS writing - functionility. In the case of ``passthrough==False``, the ImageCube layer also acts + functionality. In the case of ``passthrough==False``, the ImageCube layer also acts as a container for the trainable parameters. - Args: - cell_size (float): the width of a pixel [arcseconds] - npix (int): the number of pixels per image side - coords (GridCoords): an object already instantiated from the GridCoords class. - If providing this, cannot provide ``cell_size`` or ``npix``. - nchan (int): the number of channels in the image - passthrough (bool): if passthrough, assume ImageCube is just a layer as opposed - to parameter base. - cube (torch.double tensor, of shape ``(nchan, npix, npix)``): (optional) a - prepacked image cube to initialize the model with in units of - [:math:`\mathrm{Jy}\,\mathrm{arcsec}^{-2}`]. If None, assumes starting - ``cube`` is ``torch.zeros``. See :ref:`cube-orientation-label` for more - information on the expectations of the orientation of the input image. + Parameters + ---------- + coords : :class:`mpol.coordinates.GridCoords` + an object instantiated from the GridCoords class, containing information about + the image `cell_size` and `npix`. + nchan : int + the number of channels in the base cube. Default = 1. + passthrough : bool + if `True`, assume ImageCube is just a layer as opposed + to parameter base. + cube : :class:torch.Tensor of :class:torch.double, of shape ``(nchan, npix, npix)`` + a prepacked image cube to initialize the model with in units of + [:math:`\mathrm{Jy}\,\mathrm{arcsec}^{-2}`]. If None, assumes starting + ``cube`` is ``torch.zeros``. See :ref:`cube-orientation-label` for more + information on the expectations of the orientation of the input image. """ def __init__( self, - coords=None, - nchan=1, - passthrough=False, - cube=None, - ): + coords: GridCoords, + nchan: int = 1, + passthrough: bool = False, + cube: torch.Tensor | None = None, + ) -> None: super().__init__() self.coords = coords @@ -235,7 +237,7 @@ def __init__( if not self.passthrough: if cube is None: - self.cube = nn.Parameter( + self.cube : torch.nn.Parameter = nn.Parameter( torch.full( (self.nchan, self.coords.npix, self.coords.npix), fill_value=0.0, @@ -257,14 +259,7 @@ def __init__( # an initialization argument self.cube = None - @classmethod - def from_image_properties( - cls, cell_size, npix, nchan=1, passthrough=False, cube=None - ) -> ImageCube: - coords = GridCoords(cell_size, npix) - return cls(coords, nchan, passthrough, cube) - - def forward(self, cube=None): + def forward(self, cube: torch.Tensor | None = None) -> torch.Tensor: r""" If the ImageCube object was initialized with ``passthrough=True``, the ``cube`` argument is required. ``forward`` essentially just passes this on as an identity @@ -294,7 +289,7 @@ def forward(self, cube=None): return self.cube @property - def sky_cube(self): + def sky_cube(self) -> torch.Tensor: """ The image cube arranged as it would appear on the sky. @@ -305,7 +300,7 @@ def sky_cube(self): return utils.packed_cube_to_sky_cube(self.cube) @property - def flux(self): + def flux(self) -> torch.Tensor: """ The spatially-integrated flux of the image. Returns a 'spectrum' with the flux in each channel in units of Jy. @@ -318,7 +313,12 @@ def flux(self): # multiply by arcsec^2/pixel return self.coords.cell_size**2 * torch.sum(self.cube, dim=(1, 2)) - def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None): + def to_FITS( + self, + fname: str = "cube.fits", + overwrite: bool = False, + header_kwargs: dict | None = None, + ) -> None: """ Export the image cube to a FITS file. @@ -330,15 +330,10 @@ def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None): Returns: None """ - - try: - from astropy import wcs - from astropy.io import fits - except ImportError: - print( - "Please install the astropy package to use FITS export functionality." - ) - + + from astropy import wcs + from astropy.io import fits + w = wcs.WCS(naxis=2) w.wcs.crpix = np.array([1, 1]) diff --git a/src/mpol/precomposed.py b/src/mpol/precomposed.py index c7e3b539..35c2f6f6 100644 --- a/src/mpol/precomposed.py +++ b/src/mpol/precomposed.py @@ -57,11 +57,6 @@ def __init__( ) self.fcube = fourier.FourierCube(coords=self.coords) - @classmethod - def from_image_properties(cls, cell_size, npix, nchan, base_cube): - coords = GridCoords(cell_size, npix) - return cls(coords, nchan, base_cube) - def forward(self): r""" Feed forward to calculate the model visibilities. In this step, a diff --git a/test/gridder_dataset_export_test.py b/test/gridder_dataset_export_test.py index 6769b662..d7884d1d 100644 --- a/test/gridder_dataset_export_test.py +++ b/test/gridder_dataset_export_test.py @@ -10,9 +10,9 @@ def averager(mock_visibility_data): uu, vv, weight, data_re, data_im = mock_visibility_data - return gridding.DataAverager.from_image_properties( - cell_size=0.005, - npix=800, + coords = coordinates.GridCoords(cell_size=0.005, npix=800) + return gridding.DataAverager( + coords=coords, uu=uu, vv=vv, weight=weight, diff --git a/test/gridder_gridding_test.py b/test/gridder_gridding_test.py index aa6db553..ce59e009 100644 --- a/test/gridder_gridding_test.py +++ b/test/gridder_gridding_test.py @@ -14,9 +14,13 @@ def test_average_cont(mock_visibility_data_cont): """ uu, vv, weight, data_re, data_im = mock_visibility_data_cont - averager = gridding.DataAverager.from_image_properties( + coords = coordinates.GridCoords( cell_size=0.005, npix=800, + ) + + averager = gridding.DataAverager( + coords=coords, uu=uu, vv=vv, weight=weight, @@ -57,7 +61,10 @@ def test_uniform_ones(mock_visibility_data, tmp_path): averager._grid_visibilities() im = plt.imshow( - averager.ground_cube[4].real, origin="lower", extent=averager.coords.vis_ext, interpolation="none" + averager.ground_cube[4].real, + origin="lower", + extent=averager.coords.vis_ext, + interpolation="none", ) plt.colorbar(im) plt.savefig(tmp_path / "gridded_re.png", dpi=300) @@ -65,20 +72,23 @@ def test_uniform_ones(mock_visibility_data, tmp_path): plt.figure() im2 = plt.imshow( - averager.ground_cube[4].imag, origin="lower", extent=averager.coords.vis_ext, interpolation="none" + averager.ground_cube[4].imag, + origin="lower", + extent=averager.coords.vis_ext, + interpolation="none", ) plt.colorbar(im2) plt.savefig(tmp_path / "gridded_im.png", dpi=300) plt.close("all") - # if the gridding worked, + # if the gridding worked, # cells with no data should be 0 assert averager.data_re_gridded[~averager.mask] == pytest.approx(0) - + # and cells with data should have real values approximately 1 assert averager.data_re_gridded[averager.mask] == pytest.approx(1) - + # and imaginary values approximately 0 everywhere assert averager.data_im_gridded == pytest.approx(0) @@ -91,9 +101,9 @@ def test_weight_gridding(mock_visibility_data): data_re = np.ones_like(uu) data_im = np.ones_like(uu) - averager = gridding.DataAverager.from_image_properties( - cell_size=0.005, - npix=800, + coords = coordinates.GridCoords(cell_size=0.005, npix=800) + averager = gridding.DataAverager( + coords=coords, uu=uu, vv=vv, weight=weight, diff --git a/test/gridder_imager_test.py b/test/gridder_imager_test.py index 1a08098c..96346ed9 100644 --- a/test/gridder_imager_test.py +++ b/test/gridder_imager_test.py @@ -14,9 +14,9 @@ def imager(mock_visibility_data): uu, vv, weight, data_re, data_im = mock_visibility_data - return gridding.DirtyImager.from_image_properties( - cell_size=0.005, - npix=800, + coords = coordinates.GridCoords(cell_size=0.005, npix=800) + return gridding.DirtyImager( + coords=coords, uu=uu, vv=vv, weight=weight, diff --git a/test/gridder_init_test.py b/test/gridder_init_test.py index 7be0f2ac..464e81f7 100644 --- a/test/gridder_init_test.py +++ b/test/gridder_init_test.py @@ -31,9 +31,11 @@ def test_hermitian_pairs(mock_visibility_data): def test_averager_instantiate_cell_npix(mock_visibility_data): uu, vv, weight, data_re, data_im = mock_visibility_data - gridding.DataAverager.from_image_properties( + coords = coordinates.GridCoords( cell_size=0.005, - npix=800, + npix=800 + ) + gridding.DataAverager(coords=coords, uu=uu, vv=vv, weight=weight, diff --git a/test/images_test.py b/test/images_test.py index 882cb257..eab22fca 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -3,37 +3,18 @@ import torch from astropy.io import fits -from mpol import images, utils +from mpol import coordinates, images, utils from mpol.constants import * - -def test_odd_npix(): - expected_error_message = "Image must have an even number of pixels." - - with pytest.raises(ValueError, match=expected_error_message): - images.BaseCube.from_image_properties(npix=853, nchan=30, cell_size=0.015) - - with pytest.raises(ValueError, match=expected_error_message): - images.ImageCube.from_image_properties(npix=853, nchan=30, cell_size=0.015) - - -def test_negative_cell_size(): - expected_error_message = "cell_size must be a positive real number." - - with pytest.raises(ValueError, match=expected_error_message): - images.BaseCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015) - - with pytest.raises(ValueError, match=expected_error_message): - images.ImageCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015) - - def test_single_chan(): - im = images.ImageCube.from_image_properties(cell_size=0.015, npix=800) + coords = coordinates.GridCoords(cell_size=0.015, npix=800) + im = images.ImageCube(coords=coords) assert im.nchan == 1 def test_basecube_grad(): - bcube = images.BaseCube.from_image_properties(npix=800, cell_size=0.015) + coords = coordinates.GridCoords(cell_size=0.015, npix=800) + bcube = images.BaseCube(coords=coords) loss = torch.sum(bcube()) loss.backward() @@ -189,7 +170,8 @@ def test_multi_chan_conv(coords, tmp_path): conv_layer(test_cube) + def test_image_flux(coords): nchan = 20 - im = images.ImageCube(coords=coords, nchan=nchan) + im = images.ImageCube(coords=coords, nchan=nchan) assert im.flux.size()[0] == nchan