Skip to content

Commit

Permalink
Merge pull request #255 from MPoL-dev/float32
Browse files Browse the repository at this point in the history
switched to float32 default throughout codebase.
  • Loading branch information
iancze authored Feb 29, 2024
2 parents 0e0a576 + 13a3a06 commit 3c5da5b
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 110 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Changelog

## v0.3.0

- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254)
- added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle.
- added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane.
- added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ target-version = "py310"
line-length = 88
# will enable after sorting module locations
# select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"]
ignore = [
lint.ignore = [
"E741", # Allow ambiguous variable names
"PLR0911", # Allow many return statements
"PLR0913", # Allow many arguments to functions
Expand Down
4 changes: 2 additions & 2 deletions src/mpol/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ def check_data_fit(
Parameters
----------
uu : :class:`torch.Tensor` of `torch.double`
uu : :class:`torch.Tensor`
u spatial frequency coordinates.
Units of [:math:`\lambda`]
vv : :class:`torch.Tensor` of `torch.double`
vv : :class:`torch.Tensor`
v spatial frequency coordinates.
Units of [:math:`\lambda`]
Expand Down
2 changes: 1 addition & 1 deletion src/mpol/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class GriddedDataset(torch.nn.Module):
If providing this, cannot provide ``cell_size`` or ``npix``.
vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128`
the gridded visibility data stored in a "packed" format (pre-shifted for fft)
weight_gridded : :class:`torch.Tensor` of :class:`torch.double`
weight_gridded : :class:`torch.Tensor`
the weights corresponding to the gridded visibility data,
also in a packed format
mask : :class:`torch.Tensor` of :class:`torch.bool`
Expand Down
34 changes: 18 additions & 16 deletions src/mpol/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
Parameters
----------
cube : :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
cube : :class:`torch.Tensor` of shape ``(nchan, npix, npix)``
A 'packed' tensor. For example, an image cube from
:meth:`mpol.images.ImageCube.forward`
Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``.
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``.
The FFT of the image cube, in packed format.
"""

Expand Down Expand Up @@ -89,7 +89,7 @@ def ground_amp(self) -> torch.Tensor:
Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``
amplitude cube in 'ground' format.
"""
return torch.abs(self.ground_vis)
Expand All @@ -102,7 +102,7 @@ def ground_phase(self) -> torch.Tensor:
Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``
phase cube in 'ground' format (:math:`[-\pi,\pi)`).
"""
return torch.angle(self.ground_vis)
Expand Down Expand Up @@ -275,14 +275,14 @@ def forward(
Parameters
----------
packed_cube : :class:`torch.Tensor` of :class:`torch.double`
packed_cube : :class:`torch.Tensor`
shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example,
from :meth:`mpol.images.ImageCube.forward`
uu : :class:`torch.Tensor` of :class:`torch.double`
uu : :class:`torch.Tensor`
2D array of the u (East-West) spatial frequency coordinate
[:math:`\lambda`] of shape ``(nchan, npix)``
vv : :class:`torch.Tensor` of :class:`torch.double`
vv : :class:`torch.Tensor`
2D array of the v (North-South) spatial frequency coordinate
[:math:`\lambda`] (must be the same shape as uu)
sparse_matrices : bool
Expand Down Expand Up @@ -368,7 +368,7 @@ def forward(
shifted = torch.fft.fftshift(packed_cube, dim=(1, 2))

# convert the cube to a complex type, since this is required by TorchKbNufft
complexed = shifted.type(torch.complex128)
complexed = shifted + 0j

k_traj = self._assemble_ktraj(uu, vv)

Expand Down Expand Up @@ -498,17 +498,18 @@ def __init__(
self.real_interp_mat: torch.Tensor
self.imag_interp_mat: torch.Tensor

def forward(self, cube):
def forward(self, packed_cube):
r"""
Perform the FFT of the image cube for each channel and interpolate to the
``uu`` and ``vv`` points set at layer initialization. This call should
automatically take the best parallelization option as set by the shape of the
``uu`` and ``vv`` points.
Args:
cube (torch.double tensor): of shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example, from
:meth:`mpol.images.ImageCube.forward`
packed_cube : :class:`torch.Tensor`
shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example,
from :meth:`mpol.images.ImageCube.forward`
Returns:
torch.complex tensor: of shape ``(nchan, nvis)``, Fourier samples evaluated
Expand All @@ -517,17 +518,18 @@ def forward(self, cube):

# make sure that the nchan assumptions for the ImageCube and the NuFFT
# setup are the same
if cube.size(0) != self.nchan:
if packed_cube.size(0) != self.nchan:
raise DimensionMismatchError(
f"nchan of ImageCube ({cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})"
f"nchan of ImageCube ({packed_cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})"
)

# "unpack" the cube, but leave it flipped
# NuFFT routine expects a "normal" cube, not an fftshifted one
shifted = torch.fft.fftshift(cube, dim=(1, 2))
shifted = torch.fft.fftshift(packed_cube, dim=(1, 2))

# convert the cube to a complex type, since this is required by TorchKbNufft
complexed = shifted.type(torch.complex128)
complexed = shifted + 0j


# Consider how the similarity of the spatial frequency samples should be
# treated. We already took care of this on the k_traj side, since we set
Expand Down
12 changes: 6 additions & 6 deletions src/mpol/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def flat_to_observer(
Parameters
----------
x : :class:`torch.Tensor` of :class:`torch.double`
x : :class:`torch.Tensor`
A tensor representing the x coordinate in the plane of the orbit.
y : :class:`torch.Tensor` of :class:`torch.double`
y : :class:`torch.Tensor`
A tensor representing the y coordinate in the plane of the orbit.
omega : float
Argument of periastron [radians]. Default 0.0.
Expand All @@ -43,7 +43,7 @@ def flat_to_observer(
Returns
-------
2-tuple of :class:`torch.Tensor` of :class:`torch.double`
2-tuple of :class:`torch.Tensor`
Two tensors representing ``(X, Y)`` in the observer frame.
"""
# Rotation matrices result in a *clockwise* rotation of the axes,
Expand Down Expand Up @@ -100,9 +100,9 @@ def observer_to_flat(
Parameters
----------
X : :class:`torch.Tensor` of :class:`torch.double`
X : :class:`torch.Tensor`
A tensor representing the x coordinate in the plane of the sky.
Y : :class:`torch.Tensor` of :class:`torch.double`
Y : :class:`torch.Tensor`
A tensor representing the y coordinate in the plane of the sky.
omega : float
A tensor representing an argument of periastron [radians] Default 0.0.
Expand All @@ -114,7 +114,7 @@ def observer_to_flat(
Returns
-------
2-tuple of :class:`torch.Tensor` of :class:`torch.double`
2-tuple of :class:`torch.Tensor`
Two tensors representing ``(x, y)`` in the flat frame.
"""
# Rotation matrices result in a *clockwise* rotation of the axes,
Expand Down
15 changes: 7 additions & 8 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
* torch.ones(
(self.nchan, self.coords.npix, self.coords.npix),
requires_grad=True,
dtype=torch.double,
)
)

Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None:
# bias has shape (nchan)

# build out the discretely-sampled Hann kernel
spec = torch.tensor([0.25, 0.5, 0.25], dtype=torch.double)
spec = torch.tensor([0.25, 0.5, 0.25])
nugget = torch.outer(spec, spec) # shape (3,3) 2D Hann kernel
exp = torch.unsqueeze(torch.unsqueeze(nugget, 0), 0) # shape (1, 1, 3, 3)
weight = exp.repeat(nchan, 1, 1, 1) # shape (nchan, 1, 3, 3)
Expand All @@ -158,7 +157,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None:

# set the bias to zero
self.m.bias = nn.Parameter(
torch.zeros(nchan, dtype=torch.double), requires_grad=requires_grad
torch.zeros(nchan), requires_grad=requires_grad
)

def forward(self, cube: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -239,7 +238,7 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
Returns
-------
:class:`torch.Tensor` of :class:`torch.double` type
:class:`torch.Tensor` type
tensor of shape ``(nchan, npix, npix)``), same as `cube`
"""
self.packed_cube = packed_cube
Expand Down Expand Up @@ -337,7 +336,7 @@ def uv_gaussian_taper(
Returns
-------
:class:`torch.Tensor` of :class:`torch.double`, shape ``(npix, npix)``
:class:`torch.Tensor` , shape ``(npix, npix)``
"""

# convert FWHM to sigma and to radians
Expand Down Expand Up @@ -379,7 +378,7 @@ def convolve_packed_cube(
Parameters
----------
packed_cube : :class:`torch.Tensor` of :class:`torch.double` type
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
coords: :class:`mpol.coordinates.GridCoords`
object indicating image and Fourier grid specifications.
Expand All @@ -392,7 +391,7 @@ def convolve_packed_cube(
Returns
-------
:class:`torch.Tensor` of :class:`torch.double`
:class:`torch.Tensor`
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (npix_m == coords.npix) and (
Expand All @@ -414,7 +413,7 @@ def convolve_packed_cube(
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-10
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
Expand Down
Loading

0 comments on commit 3c5da5b

Please sign in to comment.