Skip to content

Commit

Permalink
Merge pull request #252 from MPoL-dev/sgd
Browse files Browse the repository at this point in the history
Stochastic Gradient Descent (SGD) example and functionality
  • Loading branch information
iancze authored Apr 26, 2024
2 parents 84e1b7d + 4c7b453 commit d06fc25
Show file tree
Hide file tree
Showing 41 changed files with 860 additions and 379 deletions.
8 changes: 2 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,9 @@ jobs:
- name: Install test dependencies
run: |
pip install .[test]
- name: Lint with flake8
- name: Lint with ruff
run: |
pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
ruff check .
- name: Check types with MyPy
run: |
mypy src/mpol --pretty
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,6 @@ plotsdir
runs

# hatch-generated version file
src/mpol/mpol_version.py
src/mpol/mpol_version.py

.ruff_cache
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Contributors
* Hannah Grzybowski, `@hgrzy`
* Mary Ogborn
* Tyler Quinn, `@trq5014`
* Kristin Hopley
20 changes: 10 additions & 10 deletions docs/_static/baselines/src/print_conversions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import csv

import numpy as np
from mpol.constants import c_ms

import argparse

parser = argparse.ArgumentParser(
Expand All @@ -6,11 +11,6 @@
parser.add_argument("outfile", help="Destination to save CSV table.")
args = parser.parse_args()

import csv

import numpy as np

from mpol.constants import c_ms

header = ["baseline", "100 GHz (Band 3)", "230 GHz (Band 6)", "340 GHz (Band 7)"]

Expand All @@ -20,18 +20,18 @@

def format_baseline(baseline_m):
if baseline_m < 1e3:
return "{:.0f} m".format(baseline_m)
return f"{baseline_m:.0f} m"
elif baseline_m < 1e6:
return "{:.0f} km".format(baseline_m * 1e-3)
return f"{baseline_m * 1e-3:.0f} km"


def format_lambda(lam):
if lam < 1e3:
return "{:.0f}".format(lam) + " :math:`\lambda`"
return f"{lam:.0f}" + r" :math:`\lambda`"
elif lam < 1e6:
return "{:.0f}".format(lam * 1e-3) + " :math:`\mathrm{k}\lambda`"
return f"{lam * 1e-3:.0f}" + r" :math:`\mathrm{k}\lambda`"
else:
return "{:.0f}".format(lam * 1e-6) + " :math:`\mathrm{M}\lambda`"
return f"{lam * 1e-6:.0f}" + r" :math:`\mathrm{M}\lambda`"


data = []
Expand Down
14 changes: 7 additions & 7 deletions docs/_static/fftshift/src/plot.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import argparse

parser = argparse.ArgumentParser(description="Create the fftshift plot")
parser.add_argument("outfile", help="Destination to save plot.")
args = parser.parse_args()

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.utils.data import download_file
from matplotlib import patches
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec

from mpol import coordinates

import argparse

parser = argparse.ArgumentParser(description="Create the fftshift plot")
parser.add_argument("outfile", help="Destination to save plot.")
args = parser.parse_args()


fname = download_file(
"https://zenodo.org/record/4711811/files/logo_cont.fits",
cache=True,
Expand Down
6 changes: 5 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
# Changelog

## v0.3.0

- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254)
- added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle.
- added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane.
- added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$.
- Updated [MPoL-dev/examples](https://github.com/MPoL-dev/examples) with Stochastic Gradient Descent Example.
- Standardized nomenclature of {class}`mpol.coordinates.GridCoords` and {class}`mpol.fourier.FourierCube` to use `sky_cube` for a normal image and `ground_cube` for a normal visibility cube (rather than `sky_` for visibility quantities). Routines use `packed_cube` instead of `cube` internally to be clear when packed format is preferred.
- Modified {class}`mpol.coordinates.GridCoords` object to use cached properties [#187](https://github.com/MPoL-dev/MPoL/pull/187).
- Changed the base spatial frequency unit from k$\lambda$ to $\lambda$, addressing [#223](https://github.com/MPoL-dev/MPoL/issues/223). This will affect most users data-reading routines!
Expand Down
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os

# -- Project information -----------------------------------------------------
from pkg_resources import DistributionNotFound, get_distribution
Expand Down Expand Up @@ -46,7 +45,7 @@
autodoc_mock_imports = ["torch", "torchvision"]
autodoc_member_order = "bysource"
# https://github.com/sphinx-doc/sphinx/issues/9709
# bug that if we set this here, we can't list individual members in the
# bug that if we set this here, we can't list individual members in the
# actual API doc
# autodoc_default_options = {"members": None}

Expand Down
21 changes: 18 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dev = [
"mypy",
"frank>=1.2.1",
"sphinx>=7.2.0",
"sphinx-autodoc2",
"jupytext",
"ipython!=8.7.0", # broken version for syntax higlight https://github.com/spatialaudio/nbsphinx/issues/687
"nbsphinx",
Expand All @@ -51,7 +50,8 @@ dev = [
"asdf",
"pyro-ppl",
"arviz[all]",
"visread>=0.0.4"
"visread>=0.0.4",
"ruff"
]
test = [
"pytest",
Expand All @@ -62,6 +62,7 @@ test = [
"mypy",
"visread>=0.0.4",
"frank>=1.2.1",
"ruff"
]

[project.urls]
Expand Down Expand Up @@ -105,4 +106,18 @@ module = [
"MPoL.precomposed",
"MPoL.utils"
]
disallow_untyped_defs = true
disallow_untyped_defs = true

[tool.ruff]
target-version = "py310"
line-length = 88
# will enable after sorting module locations
# select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"]
lint.ignore = [
"E741", # Allow ambiguous variable names
"PLR0911", # Allow many return statements
"PLR0913", # Allow many arguments to functions
"PLR0915", # Allow many statements
"PLR2004", # Allow magic numbers in comparisons
]
exclude = []
1 change: 0 additions & 1 deletion src/mpol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from mpol.mpol_version import __version__
zenodo_record = 10064221
25 changes: 14 additions & 11 deletions src/mpol/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from functools import cached_property

from functools import cached_property
from typing import Any

import numpy as np
Expand All @@ -10,7 +10,7 @@

import mpol.constants as const
from mpol.exceptions import CellSizeError
from mpol.utils import get_max_spatial_freq, get_maximum_cell_size
from mpol.utils import get_maximum_cell_size


class GridCoords:
Expand Down Expand Up @@ -79,6 +79,7 @@ class GridCoords:
:ivar vis_ext: length-4 list of (left, right, bottom, top) expected by routines
like ``matplotlib.pyplot.imshow`` in the ``extent`` parameter assuming
``origin='lower'``. Units of [:math:`\lambda`]
:ivar vis_ext_Mlam: like vis_ext, but in units of [:math:`\mathrm{M}\lambda`].
"""

def __init__(self, cell_size: float, npix: int):
Expand Down Expand Up @@ -205,16 +206,18 @@ def vis_ext(self) -> list[float]:
self.u_bin_max,
self.v_bin_min,
self.v_bin_max,
] # [kλ]
] # [λ]

@property
def vis_ext_Mlam(self) -> list[float]:
return [1e-6 * edge for edge in self.vis_ext]

# --------------------------------------------------------------------------
# Non-identical u & v properties
# --------------------------------------------------------------------------
@cached_property
def ground_u_centers_2D(self) -> npt.NDArray[np.floating[Any]]:
# only useful for plotting a sky_vis
# uu increasing, no fftshift
# tile replicates the 1D u_centers array to a 2D array the size of the full UV grid
# tile replicates the 1D u_centers array to a 2D array the size of the full
# UV grid
return np.tile(self.u_centers, (self.npix_u, 1))

@cached_property
Expand Down Expand Up @@ -304,10 +307,10 @@ def check_data_fit(
Parameters
----------
uu : :class:`torch.Tensor` of `torch.double`
uu : :class:`torch.Tensor`
u spatial frequency coordinates.
Units of [:math:`\lambda`]
vv : :class:`torch.Tensor` of `torch.double`
vv : :class:`torch.Tensor`
v spatial frequency coordinates.
Units of [:math:`\lambda`]
Expand Down Expand Up @@ -354,6 +357,6 @@ def __eq__(self, other: Any) -> bool:
# don't attempt to compare against different types
return NotImplemented

# GridCoords objects are considered equal if they have the same cell_size and npix, since
# all other attributes are derived from these two core properties.
# GridCoords objects are considered equal if they have the same cell_size and
# npix, since all other attributes are derived from these two core properties.
return bool(self.cell_size == other.cell_size and self.npix == other.npix)
8 changes: 3 additions & 5 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import logging
from collections import defaultdict
from typing import Any

import numpy as np
Expand All @@ -11,11 +10,9 @@
from numpy.typing import NDArray

from mpol.datasets import Dartboard, GriddedDataset
from mpol.precomposed import GriddedNet

# from mpol.training import TrainTest, train_to_dirty_image
# from mpol.training import TrainTest, train_to_dirty_image
from mpol.plot import split_diagnostics_fig
from mpol.utils import loglinspace


# class CrossValidate:
Expand Down Expand Up @@ -59,7 +56,8 @@
# Number of k-folds to use in cross-validation
# split_method : str, default='dartboard'
# Method to split full dataset into train/test subsets
# dartboard_q_edges, dartboard_phi_edges : list of float, default=None, unit=[klambda]
# dartboard_q_edges, dartboard_phi_edges : list of float, default=None,
# unit=[klambda]
# Radial and azimuthal bin edges of the cells used to split the dataset
# if `split_method`==`dartboard` (see `datasets.Dartboard`)
# split_diag_fig : bool, default=False
Expand Down
5 changes: 2 additions & 3 deletions src/mpol/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from numpy import floating, integer
from numpy.typing import ArrayLike, NDArray

from mpol.coordinates import GridCoords

from mpol import utils
from mpol.coordinates import GridCoords


class GriddedDataset(torch.nn.Module):
Expand All @@ -20,7 +19,7 @@ class GriddedDataset(torch.nn.Module):
If providing this, cannot provide ``cell_size`` or ``npix``.
vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128`
the gridded visibility data stored in a "packed" format (pre-shifted for fft)
weight_gridded : :class:`torch.Tensor` of :class:`torch.double`
weight_gridded : :class:`torch.Tensor`
the weights corresponding to the gridded visibility data,
also in a packed format
mask : :class:`torch.Tensor` of :class:`torch.bool`
Expand Down
Loading

0 comments on commit d06fc25

Please sign in to comment.