Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switched to float32 default throughout codebase. #255

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Changelog

## v0.3.0

- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254)
- added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle.
- added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane.
- added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ target-version = "py310"
line-length = 88
# will enable after sorting module locations
# select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"]
ignore = [
lint.ignore = [
"E741", # Allow ambiguous variable names
"PLR0911", # Allow many return statements
"PLR0913", # Allow many arguments to functions
Expand Down
4 changes: 2 additions & 2 deletions src/mpol/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ def check_data_fit(

Parameters
----------
uu : :class:`torch.Tensor` of `torch.double`
uu : :class:`torch.Tensor`
u spatial frequency coordinates.
Units of [:math:`\lambda`]
vv : :class:`torch.Tensor` of `torch.double`
vv : :class:`torch.Tensor`
v spatial frequency coordinates.
Units of [:math:`\lambda`]

Expand Down
2 changes: 1 addition & 1 deletion src/mpol/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class GriddedDataset(torch.nn.Module):
If providing this, cannot provide ``cell_size`` or ``npix``.
vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128`
the gridded visibility data stored in a "packed" format (pre-shifted for fft)
weight_gridded : :class:`torch.Tensor` of :class:`torch.double`
weight_gridded : :class:`torch.Tensor`
the weights corresponding to the gridded visibility data,
also in a packed format
mask : :class:`torch.Tensor` of :class:`torch.bool`
Expand Down
34 changes: 18 additions & 16 deletions src/mpol/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:

Parameters
----------
cube : :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
cube : :class:`torch.Tensor` of shape ``(nchan, npix, npix)``
A 'packed' tensor. For example, an image cube from
:meth:`mpol.images.ImageCube.forward`

Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``.
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``.
The FFT of the image cube, in packed format.
"""

Expand Down Expand Up @@ -89,7 +89,7 @@ def ground_amp(self) -> torch.Tensor:

Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``
amplitude cube in 'ground' format.
"""
return torch.abs(self.ground_vis)
Expand All @@ -102,7 +102,7 @@ def ground_phase(self) -> torch.Tensor:

Returns
-------
:class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``
:class:`torch.Tensor` of shape ``(nchan, npix, npix)``
phase cube in 'ground' format (:math:`[-\pi,\pi)`).
"""
return torch.angle(self.ground_vis)
Expand Down Expand Up @@ -275,14 +275,14 @@ def forward(

Parameters
----------
packed_cube : :class:`torch.Tensor` of :class:`torch.double`
packed_cube : :class:`torch.Tensor`
shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example,
from :meth:`mpol.images.ImageCube.forward`
uu : :class:`torch.Tensor` of :class:`torch.double`
uu : :class:`torch.Tensor`
2D array of the u (East-West) spatial frequency coordinate
[:math:`\lambda`] of shape ``(nchan, npix)``
vv : :class:`torch.Tensor` of :class:`torch.double`
vv : :class:`torch.Tensor`
2D array of the v (North-South) spatial frequency coordinate
[:math:`\lambda`] (must be the same shape as uu)
sparse_matrices : bool
Expand Down Expand Up @@ -368,7 +368,7 @@ def forward(
shifted = torch.fft.fftshift(packed_cube, dim=(1, 2))

# convert the cube to a complex type, since this is required by TorchKbNufft
complexed = shifted.type(torch.complex128)
complexed = shifted + 0j

k_traj = self._assemble_ktraj(uu, vv)

Expand Down Expand Up @@ -498,17 +498,18 @@ def __init__(
self.real_interp_mat: torch.Tensor
self.imag_interp_mat: torch.Tensor

def forward(self, cube):
def forward(self, packed_cube):
r"""
Perform the FFT of the image cube for each channel and interpolate to the
``uu`` and ``vv`` points set at layer initialization. This call should
automatically take the best parallelization option as set by the shape of the
``uu`` and ``vv`` points.

Args:
cube (torch.double tensor): of shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example, from
:meth:`mpol.images.ImageCube.forward`
packed_cube : :class:`torch.Tensor`
shape ``(nchan, npix, npix)``). The cube
should be a "prepacked" image cube, for example,
from :meth:`mpol.images.ImageCube.forward`

Returns:
torch.complex tensor: of shape ``(nchan, nvis)``, Fourier samples evaluated
Expand All @@ -517,17 +518,18 @@ def forward(self, cube):

# make sure that the nchan assumptions for the ImageCube and the NuFFT
# setup are the same
if cube.size(0) != self.nchan:
if packed_cube.size(0) != self.nchan:
raise DimensionMismatchError(
f"nchan of ImageCube ({cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})"
f"nchan of ImageCube ({packed_cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})"
)

# "unpack" the cube, but leave it flipped
# NuFFT routine expects a "normal" cube, not an fftshifted one
shifted = torch.fft.fftshift(cube, dim=(1, 2))
shifted = torch.fft.fftshift(packed_cube, dim=(1, 2))

# convert the cube to a complex type, since this is required by TorchKbNufft
complexed = shifted.type(torch.complex128)
complexed = shifted + 0j


# Consider how the similarity of the spatial frequency samples should be
# treated. We already took care of this on the k_traj side, since we set
Expand Down
12 changes: 6 additions & 6 deletions src/mpol/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def flat_to_observer(

Parameters
----------
x : :class:`torch.Tensor` of :class:`torch.double`
x : :class:`torch.Tensor`
A tensor representing the x coordinate in the plane of the orbit.
y : :class:`torch.Tensor` of :class:`torch.double`
y : :class:`torch.Tensor`
A tensor representing the y coordinate in the plane of the orbit.
omega : float
Argument of periastron [radians]. Default 0.0.
Expand All @@ -43,7 +43,7 @@ def flat_to_observer(

Returns
-------
2-tuple of :class:`torch.Tensor` of :class:`torch.double`
2-tuple of :class:`torch.Tensor`
Two tensors representing ``(X, Y)`` in the observer frame.
"""
# Rotation matrices result in a *clockwise* rotation of the axes,
Expand Down Expand Up @@ -100,9 +100,9 @@ def observer_to_flat(

Parameters
----------
X : :class:`torch.Tensor` of :class:`torch.double`
X : :class:`torch.Tensor`
A tensor representing the x coordinate in the plane of the sky.
Y : :class:`torch.Tensor` of :class:`torch.double`
Y : :class:`torch.Tensor`
A tensor representing the y coordinate in the plane of the sky.
omega : float
A tensor representing an argument of periastron [radians] Default 0.0.
Expand All @@ -114,7 +114,7 @@ def observer_to_flat(

Returns
-------
2-tuple of :class:`torch.Tensor` of :class:`torch.double`
2-tuple of :class:`torch.Tensor`
Two tensors representing ``(x, y)`` in the flat frame.
"""
# Rotation matrices result in a *clockwise* rotation of the axes,
Expand Down
15 changes: 7 additions & 8 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
* torch.ones(
(self.nchan, self.coords.npix, self.coords.npix),
requires_grad=True,
dtype=torch.double,
)
)

Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None:
# bias has shape (nchan)

# build out the discretely-sampled Hann kernel
spec = torch.tensor([0.25, 0.5, 0.25], dtype=torch.double)
spec = torch.tensor([0.25, 0.5, 0.25])
nugget = torch.outer(spec, spec) # shape (3,3) 2D Hann kernel
exp = torch.unsqueeze(torch.unsqueeze(nugget, 0), 0) # shape (1, 1, 3, 3)
weight = exp.repeat(nchan, 1, 1, 1) # shape (nchan, 1, 3, 3)
Expand All @@ -158,7 +157,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None:

# set the bias to zero
self.m.bias = nn.Parameter(
torch.zeros(nchan, dtype=torch.double), requires_grad=requires_grad
torch.zeros(nchan), requires_grad=requires_grad
)

def forward(self, cube: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -239,7 +238,7 @@ def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:

Returns
-------
:class:`torch.Tensor` of :class:`torch.double` type
:class:`torch.Tensor` type
tensor of shape ``(nchan, npix, npix)``), same as `cube`
"""
self.packed_cube = packed_cube
Expand Down Expand Up @@ -337,7 +336,7 @@ def uv_gaussian_taper(

Returns
-------
:class:`torch.Tensor` of :class:`torch.double`, shape ``(npix, npix)``
:class:`torch.Tensor` , shape ``(npix, npix)``
"""

# convert FWHM to sigma and to radians
Expand Down Expand Up @@ -379,7 +378,7 @@ def convolve_packed_cube(

Parameters
----------
packed_cube : :class:`torch.Tensor` of :class:`torch.double` type
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
coords: :class:`mpol.coordinates.GridCoords`
object indicating image and Fourier grid specifications.
Expand All @@ -392,7 +391,7 @@ def convolve_packed_cube(

Returns
-------
:class:`torch.Tensor` of :class:`torch.double`
:class:`torch.Tensor`
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (npix_m == coords.npix) and (
Expand All @@ -414,7 +413,7 @@ def convolve_packed_cube(
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-10
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
Expand Down
Loading
Loading