diff --git a/docs/changelog.md b/docs/changelog.md index cd420c79..970443b7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -3,7 +3,7 @@ # Changelog ## v0.3.0 - +- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254) - added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle. - added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane. - added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$. diff --git a/pyproject.toml b/pyproject.toml index 2a2cf976..3f04b169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,7 @@ target-version = "py310" line-length = 88 # will enable after sorting module locations # select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"] -ignore = [ +lint.ignore = [ "E741", # Allow ambiguous variable names "PLR0911", # Allow many return statements "PLR0913", # Allow many arguments to functions diff --git a/src/mpol/coordinates.py b/src/mpol/coordinates.py index d4017a81..04b3af01 100644 --- a/src/mpol/coordinates.py +++ b/src/mpol/coordinates.py @@ -307,10 +307,10 @@ def check_data_fit( Parameters ---------- - uu : :class:`torch.Tensor` of `torch.double` + uu : :class:`torch.Tensor` u spatial frequency coordinates. Units of [:math:`\lambda`] - vv : :class:`torch.Tensor` of `torch.double` + vv : :class:`torch.Tensor` v spatial frequency coordinates. Units of [:math:`\lambda`] diff --git a/src/mpol/datasets.py b/src/mpol/datasets.py index 96f44d9c..b9864864 100644 --- a/src/mpol/datasets.py +++ b/src/mpol/datasets.py @@ -19,7 +19,7 @@ class GriddedDataset(torch.nn.Module): If providing this, cannot provide ``cell_size`` or ``npix``. vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128` the gridded visibility data stored in a "packed" format (pre-shifted for fft) - weight_gridded : :class:`torch.Tensor` of :class:`torch.double` + weight_gridded : :class:`torch.Tensor` the weights corresponding to the gridded visibility data, also in a packed format mask : :class:`torch.Tensor` of :class:`torch.bool` diff --git a/src/mpol/fourier.py b/src/mpol/fourier.py index e583e6ae..82e63795 100644 --- a/src/mpol/fourier.py +++ b/src/mpol/fourier.py @@ -46,13 +46,13 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + cube : :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` A 'packed' tensor. For example, an image cube from :meth:`mpol.images.ImageCube.forward` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``. + :class:`torch.Tensor` of shape ``(nchan, npix, npix)``. The FFT of the image cube, in packed format. """ @@ -89,7 +89,7 @@ def ground_amp(self) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` amplitude cube in 'ground' format. """ return torch.abs(self.ground_vis) @@ -102,7 +102,7 @@ def ground_phase(self) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` phase cube in 'ground' format (:math:`[-\pi,\pi)`). """ return torch.angle(self.ground_vis) @@ -275,14 +275,14 @@ def forward( Parameters ---------- - packed_cube : :class:`torch.Tensor` of :class:`torch.double` + packed_cube : :class:`torch.Tensor` shape ``(nchan, npix, npix)``). The cube should be a "prepacked" image cube, for example, from :meth:`mpol.images.ImageCube.forward` - uu : :class:`torch.Tensor` of :class:`torch.double` + uu : :class:`torch.Tensor` 2D array of the u (East-West) spatial frequency coordinate [:math:`\lambda`] of shape ``(nchan, npix)`` - vv : :class:`torch.Tensor` of :class:`torch.double` + vv : :class:`torch.Tensor` 2D array of the v (North-South) spatial frequency coordinate [:math:`\lambda`] (must be the same shape as uu) sparse_matrices : bool @@ -368,7 +368,7 @@ def forward( shifted = torch.fft.fftshift(packed_cube, dim=(1, 2)) # convert the cube to a complex type, since this is required by TorchKbNufft - complexed = shifted.type(torch.complex128) + complexed = shifted + 0j k_traj = self._assemble_ktraj(uu, vv) @@ -498,7 +498,7 @@ def __init__( self.real_interp_mat: torch.Tensor self.imag_interp_mat: torch.Tensor - def forward(self, cube): + def forward(self, packed_cube): r""" Perform the FFT of the image cube for each channel and interpolate to the ``uu`` and ``vv`` points set at layer initialization. This call should @@ -506,9 +506,10 @@ def forward(self, cube): ``uu`` and ``vv`` points. Args: - cube (torch.double tensor): of shape ``(nchan, npix, npix)``). The cube - should be a "prepacked" image cube, for example, from - :meth:`mpol.images.ImageCube.forward` + packed_cube : :class:`torch.Tensor` + shape ``(nchan, npix, npix)``). The cube + should be a "prepacked" image cube, for example, + from :meth:`mpol.images.ImageCube.forward` Returns: torch.complex tensor: of shape ``(nchan, nvis)``, Fourier samples evaluated @@ -517,17 +518,18 @@ def forward(self, cube): # make sure that the nchan assumptions for the ImageCube and the NuFFT # setup are the same - if cube.size(0) != self.nchan: + if packed_cube.size(0) != self.nchan: raise DimensionMismatchError( - f"nchan of ImageCube ({cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})" + f"nchan of ImageCube ({packed_cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})" ) # "unpack" the cube, but leave it flipped # NuFFT routine expects a "normal" cube, not an fftshifted one - shifted = torch.fft.fftshift(cube, dim=(1, 2)) + shifted = torch.fft.fftshift(packed_cube, dim=(1, 2)) # convert the cube to a complex type, since this is required by TorchKbNufft - complexed = shifted.type(torch.complex128) + complexed = shifted + 0j + # Consider how the similarity of the spatial frequency samples should be # treated. We already took care of this on the k_traj side, since we set diff --git a/src/mpol/geometry.py b/src/mpol/geometry.py index e4be2244..8158cd28 100644 --- a/src/mpol/geometry.py +++ b/src/mpol/geometry.py @@ -30,9 +30,9 @@ def flat_to_observer( Parameters ---------- - x : :class:`torch.Tensor` of :class:`torch.double` + x : :class:`torch.Tensor` A tensor representing the x coordinate in the plane of the orbit. - y : :class:`torch.Tensor` of :class:`torch.double` + y : :class:`torch.Tensor` A tensor representing the y coordinate in the plane of the orbit. omega : float Argument of periastron [radians]. Default 0.0. @@ -43,7 +43,7 @@ def flat_to_observer( Returns ------- - 2-tuple of :class:`torch.Tensor` of :class:`torch.double` + 2-tuple of :class:`torch.Tensor` Two tensors representing ``(X, Y)`` in the observer frame. """ # Rotation matrices result in a *clockwise* rotation of the axes, @@ -100,9 +100,9 @@ def observer_to_flat( Parameters ---------- - X : :class:`torch.Tensor` of :class:`torch.double` + X : :class:`torch.Tensor` A tensor representing the x coordinate in the plane of the sky. - Y : :class:`torch.Tensor` of :class:`torch.double` + Y : :class:`torch.Tensor` A tensor representing the y coordinate in the plane of the sky. omega : float A tensor representing an argument of periastron [radians] Default 0.0. @@ -114,7 +114,7 @@ def observer_to_flat( Returns ------- - 2-tuple of :class:`torch.Tensor` of :class:`torch.double` + 2-tuple of :class:`torch.Tensor` Two tensors representing ``(x, y)`` in the flat frame. """ # Rotation matrices result in a *clockwise* rotation of the axes, diff --git a/src/mpol/images.py b/src/mpol/images.py index 9c91660e..e49fbff2 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -67,7 +67,6 @@ def __init__( * torch.ones( (self.nchan, self.coords.npix, self.coords.npix), requires_grad=True, - dtype=torch.double, ) ) @@ -146,7 +145,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None: # bias has shape (nchan) # build out the discretely-sampled Hann kernel - spec = torch.tensor([0.25, 0.5, 0.25], dtype=torch.double) + spec = torch.tensor([0.25, 0.5, 0.25]) nugget = torch.outer(spec, spec) # shape (3,3) 2D Hann kernel exp = torch.unsqueeze(torch.unsqueeze(nugget, 0), 0) # shape (1, 1, 3, 3) weight = exp.repeat(nchan, 1, 1, 1) # shape (nchan, 1, 3, 3) @@ -158,7 +157,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None: # set the bias to zero self.m.bias = nn.Parameter( - torch.zeros(nchan, dtype=torch.double), requires_grad=requires_grad + torch.zeros(nchan), requires_grad=requires_grad ) def forward(self, cube: torch.Tensor) -> torch.Tensor: @@ -239,7 +238,7 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` type + :class:`torch.Tensor` type tensor of shape ``(nchan, npix, npix)``), same as `cube` """ self.packed_cube = packed_cube @@ -337,7 +336,7 @@ def uv_gaussian_taper( Returns ------- - :class:`torch.Tensor` of :class:`torch.double`, shape ``(npix, npix)`` + :class:`torch.Tensor` , shape ``(npix, npix)`` """ # convert FWHM to sigma and to radians @@ -379,7 +378,7 @@ def convolve_packed_cube( Parameters ---------- - packed_cube : :class:`torch.Tensor` of :class:`torch.double` type + packed_cube : :class:`torch.Tensor` type shape ``(nchan, npix, npix)`` image cube in packed format. coords: :class:`mpol.coordinates.GridCoords` object indicating image and Fourier grid specifications. @@ -392,7 +391,7 @@ def convolve_packed_cube( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` """ nchan, npix_m, npix_l = packed_cube.size() assert (npix_m == coords.npix) and ( @@ -414,7 +413,7 @@ def convolve_packed_cube( convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2)) # assert imaginaries are effectively zero, otherwise something went wrong - thresh = 1e-10 + thresh = 1e-7 assert ( torch.max(convolved_packed_cube.imag) < thresh ), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format( diff --git a/src/mpol/losses.py b/src/mpol/losses.py index e38d67ee..00e4c1e6 100644 --- a/src/mpol/losses.py +++ b/src/mpol/losses.py @@ -30,12 +30,12 @@ def _chi_squared( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2` likelihood, summed over all dimensions of input array. """ @@ -79,12 +79,12 @@ def r_chi_squared( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2_\mathrm{R}`, summed over all dimensions of input array. """ @@ -116,7 +116,7 @@ def r_chi_squared_gridded( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2_\mathrm{R}` value summed over all input dimensions """ model_vis = griddedDataset(modelVisibilityCube) @@ -166,12 +166,12 @@ def log_likelihood( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex128` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\ln\mathcal{L}` log likelihood, summed over all dimensions of input array. """ @@ -208,7 +208,7 @@ def log_likelihood_gridded( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\ln\mathcal{L}` value, summed over all dimensions of input data. """ @@ -247,12 +247,12 @@ def neg_log_likelihood_avg( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the average of the negative log likelihood, summed over all dimensions of input array. """ @@ -275,9 +275,9 @@ def entropy( Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` + cube : :class:`torch.Tensor` pixel values must be positive :math:`I_i > 0` for all :math:`i` - prior_intensity : :class:`torch.Tensor` of :class:`torch.double` + prior_intensity : :class:`torch.Tensor` the prior value :math:`p` to calculate entropy against. Tensors of any shape are allowed so long as they will broadcast to the shape of the cube under division (`/`). @@ -287,7 +287,7 @@ def entropy( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` entropy loss """ # check to make sure image is positive, otherwise raise an error @@ -313,7 +313,7 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Parameters ---------- - sky_cube: 3D :class:`torch.Tensor` of :class:`torch.double` + sky_cube: 3D :class:`torch.Tensor` the image cube array :math:`I_{lmv}`, where :math:`l` is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and :math:`v` is the channel (velocity or frequency) dimension in @@ -325,7 +325,7 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total variation loss """ @@ -348,7 +348,7 @@ def TV_channel(cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Parameters ---------- - cube: :class:`torch.Tensor` of :class:`torch.double` + cube: :class:`torch.Tensor` the image cube array :math:`I_{lmv}` epsilon: float a softening parameter in units of [:math:`\mathrm{Jy}/\mathrm{arcsec}^2`]. @@ -357,7 +357,7 @@ def TV_channel(cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total variation loss """ # calculate the difference between the n+1 cube and the n cube @@ -383,7 +383,7 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor: Parameters ---------- - sky_cube :class:`torch.Tensor` of :class:`torch.double` + sky_cube :class:`torch.Tensor` the image cube array :math:`I_{lmv}`, where :math:`l` is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and :math:`v` is the channel (velocity or frequency) dimension in @@ -391,7 +391,7 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total square variation loss """ @@ -418,7 +418,7 @@ def sparsity(cube: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.T Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` + cube : :class:`torch.Tensor` the image cube array :math:`I_{lmv}` mask : :class:`torch.Tensor` of :class:`torch.bool` tensor array the same shape as ``cube``. The sparsity prior @@ -427,7 +427,7 @@ def sparsity(cube: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.T Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` sparsity loss calculated where ``mask == True`` """ @@ -458,7 +458,7 @@ def UV_sparsity( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` UV sparsity loss above :math:`q_\mathrm{max}` """ @@ -498,16 +498,16 @@ def PSD(qs: torch.Tensor, psd: torch.Tensor, l: torch.Tensor) -> torch.Tensor: Parameters ---------- - qs : :class:`torch.Tensor` of :class:`torch.double` + qs : :class:`torch.Tensor` the radial UV coordinate (in :math:`\lambda`) - psd : :class:`torch.Tensor` of :class:`torch.double` + psd : :class:`torch.Tensor` the power spectral density cube - l : :class:`torch.Tensor` of :class:`torch.double` + l : :class:`torch.Tensor` the correlation length in the image plane (in arcsec) Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the loss calculated using the power spectral density """ @@ -531,12 +531,12 @@ def edge_clamp(cube: torch.Tensor) -> torch.Tensor: Parameters ---------- - cube: :class:`torch.Tensor` of :class:`torch.double` + cube: :class:`torch.Tensor` the image cube array :math:`I_{lmv}` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` edge loss """ diff --git a/src/mpol/utils.py b/src/mpol/utils.py index a9487172..bbd6c285 100644 --- a/src/mpol/utils.py +++ b/src/mpol/utils.py @@ -291,7 +291,7 @@ def get_optimal_image_properties( image_width : float, unit = arcsec Desired width of the image (for a square image of size `image_width` :math:`\times` `image_width`). - u, v : :class:`torch.Tensor` of :class:`torch.double`, unit = :math:`\lambda` + u, v : :class:`torch.Tensor` , unit = :math:`\lambda` `u` and `v` baselines. Returns diff --git a/test/conftest.py b/test/conftest.py index 6886d72f..d57329a0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -21,7 +21,7 @@ def img2D_butterfly(): """Return the 2D source image of the butterfly, for use as a test image cube.""" archive = np.load(_npz_path) - img = np.float64(archive["img"]) + img = archive["img"] # assuming we're going to go with _cell_size, set the total flux of this image # total flux should be 0.253 Jy from MPoL-examples. @@ -43,7 +43,7 @@ def packed_cube(img2D_butterfly): def baselines_m(): "Return the mock baselines (in meters) produced from the IM Lup DSHARP dataset." archive = np.load(_npz_path) - return np.float64(archive["uu"]), np.float64(archive["vv"]) + return archive["uu"], archive["vv"] @pytest.fixture(scope="session") diff --git a/test/datasets_test.py b/test/datasets_test.py index 9071ff67..34923ed5 100644 --- a/test/datasets_test.py +++ b/test/datasets_test.py @@ -15,11 +15,9 @@ def test_index(coords, dataset): # create a mock cube that includes negative values nchan = dataset.nchan mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor base_cube = torch.normal(mean=mean, std=std) diff --git a/test/fourier_test.py b/test/fourier_test.py index a9215e4e..23ff3edc 100644 --- a/test/fourier_test.py +++ b/test/fourier_test.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import numpy as np import torch -from mpol import fourier, images, utils +from mpol import fourier, utils from pytest import approx @@ -181,44 +181,45 @@ def test_predict_vis_nufft_cached(coords, baselines_1D): # if the image cube was filled with zeros, then we should make sure this is true assert output.detach().numpy() == approx( - np.zeros((nchan, len(uu)), dtype=np.complex128) + np.zeros((nchan, len(uu))) ) def test_nufft_cached_predict_GPU(coords, baselines_1D): - if not torch.cuda.is_available(): - pass + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") else: - device = torch.device("cuda:0") - - # just see that we can load the layer and get something through without error - # for a very simple blank function + device = torch.device("cpu") + return - # load some data - uu, vv = baselines_1D + # just see that we can load the layer and get something through without error + # for a very simple blank function - nchan = 10 + # load some data + uu, vv = baselines_1D - # instantiate an ImageCube layer filled with zeros and send to GPU - imagecube = images.ImageCube(coords=coords, nchan=nchan).to(device=device) + nchan = 10 - # we have a multi-channel cube, but only sent single-channel uu and vv - # coordinates. The expectation is that TorchKbNufft will parallelize these + # we have a multi-channel cube, but only sent single-channel uu and vv + # coordinates. The expectation is that TorchKbNufft will parallelize these - layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv).to( - device=device - ) + layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv).to( + device=device + ) - # predict the values of the cube at the u,v locations - output = layer(imagecube()) + # predict the values of the cube at the u,v locations + blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix)).to(device=device) + output = layer(blank_packed_img) - # make sure we got back the number of visibilities we expected - assert output.shape == (nchan, len(uu)) + # make sure we got back the number of visibilities we expected + assert output.shape == (nchan, len(uu)) - # if the image cube was filled with zeros, then we should make sure this is true - assert output.cpu().detach().numpy() == approx( - np.zeros((nchan, len(uu)), dtype=np.complex128) - ) + # if the image cube was filled with zeros, then we should make sure this is true + assert output.cpu().detach().numpy() == approx( + np.zeros((nchan, len(uu)), dtype=np.complex128) + ) def test_nufft_accuracy_single_chan(coords, baselines_1D, tmp_path): @@ -316,7 +317,7 @@ def test_nufft_cached_accuracy_single_chan(coords, baselines_1D, tmp_path): img_packed = utils.sky_gaussian_arcsec( coords.packed_x_centers_2D, coords.packed_y_centers_2D, **kw ) - img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True) + img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True, dtype=torch.float32) # use the NuFFT to predict the values of the cube at the u,v locations num_output = layer(img_packed_tensor)[0] # take the channel dim out @@ -392,7 +393,7 @@ def test_nufft_cached_accuracy_coil_broadcast(coords, baselines_1D): # broadcast to 5 channels -- the image will be the same for each img_packed_tensor = torch.tensor( img_packed[np.newaxis, :, :] * np.ones((nchan, coords.npix, coords.npix)), - requires_grad=True, + requires_grad=True, dtype=torch.float32 ) # use the NuFFT to predict the values of the cube at the u,v locations diff --git a/test/images_test.py b/test/images_test.py index 920e2df5..cc5eb6a4 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -56,11 +56,9 @@ def test_basecube_imagecube(coords, tmp_path): # create a mock cube that includes negative values nchan = 1 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor base_cube = torch.normal(mean=mean, std=std) @@ -111,11 +109,9 @@ def test_base_cube_conv_cube(coords, tmp_path): # create a mock cube that includes negative values nchan = 1 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # The HannConvCube expects to function on a pre-packed ImageCube, # so in order to get the plots looking correct on this test image, @@ -156,11 +152,9 @@ def test_multi_chan_conv(coords, tmp_path): nchan = 10 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor test_cube = torch.normal(mean=mean, std=std)