Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Apr 19, 2024
1 parent 731a280 commit 760852f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 31 deletions.
14 changes: 6 additions & 8 deletions src/arpes/utilities/conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import MOMENTUM
from arpes._typing import MOMENTUM, KspaceCoords

from .calibration import DetectorCalibration

Expand Down Expand Up @@ -150,7 +150,7 @@ def get_coordinates(
self,
resolution: dict[MOMENTUM, float] | None = None,
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
) -> dict[str, NDArray[np.float_]]:
) -> KspaceCoords:
"""Calculates the coordinates which should be used in momentum space.
Args:
Expand All @@ -159,12 +159,10 @@ def get_coordinates(
bounds(dict, optional): bounds of the momentum coordinates
Returns:
dict[str, NDArray]: the key represents the axis name suchas "kp", "kx", and "eV".
KspaceCoords: the key represents the axis name suchas "kp", "kx", and "eV".
"""
if resolution is None:
resolution = {}
if bounds is None:
bounds = {}
coordinates: dict[str, NDArray[np.float_]] = {}
resolution = resolution if resolution is not None else {}
bounds = bounds if bounds is not None else {}
coordinates: KspaceCoords = {}
coordinates["eV"] = self.arr.coords["eV"].values
return coordinates
16 changes: 8 additions & 8 deletions src/arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from collections.abc import Hashable
from itertools import pairwise
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack

import numpy as np
import xarray as xr
Expand All @@ -52,7 +52,7 @@

from numpy.typing import NDArray

from arpes._typing import MOMENTUM, XrTypes
from arpes._typing import MOMENTUM, KspaceCoords, XrTypes
from arpes.utilities.conversion.calibration import DetectorCalibration

__all__ = ["convert_to_kspace", "slice_along_path"]
Expand Down Expand Up @@ -284,9 +284,9 @@ def convert_to_kspace( # noqa: PLR0913
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
resolution: dict[MOMENTUM, float] | None = None,
calibration: DetectorCalibration | None = None,
coords: dict[MOMENTUM, NDArray[np.float_]] | None = None,
coords: KspaceCoords | None = None,
allow_chunks: bool = False,
**kwargs: NDArray[np.float_],
**kwargs: Unpack[KspaceCoords],
) -> xr.DataArray:
"""Converts volumetric the data to momentum space ("backwards"). Typically what you want.
Expand Down Expand Up @@ -342,7 +342,7 @@ def convert_to_kspace( # noqa: PLR0913
xr.DataArray: [description]
"""
coords = coords if coords else {}
coords.update(**kwargs)
coords.update(kwargs)
assert isinstance(coords, dict)
bounds = bounds if bounds else {}
arr = arr if isinstance(arr, xr.DataArray) else normalize_to_spectrum(arr)
Expand Down Expand Up @@ -434,7 +434,7 @@ class CoordinateTransform(TypedDict, total=True):

def convert_coordinates(
arr: xr.DataArray,
target_coordinates: dict[str, NDArray[np.float_]],
target_coordinates: KspaceCoords,
coordinate_transform: CoordinateTransform,
*,
as_dataset: bool = False,
Expand Down Expand Up @@ -558,8 +558,8 @@ def _chunk_convert(
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
resolution: dict[MOMENTUM, float] | None = None,
calibration: DetectorCalibration | None = None,
coords: dict[MOMENTUM, NDArray[np.float_]] | None = None,
**kwargs: NDArray[np.float_],
coords: KspaceCoords | None = None,
**kwargs: Unpack[KspaceCoords],
) -> xr.DataArray:
DESIRED_CHUNK_SIZE = 1000 * 1000 * 20
TOO_LARGE_CHUNK_SIZE = 100
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/utilities/conversion/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import warnings
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Hashable, TypeVar, Unpack
from typing import TYPE_CHECKING, TypeVar, Unpack

import numpy as np
import xarray as xr
Expand Down
16 changes: 7 additions & 9 deletions src/arpes/utilities/conversion/kx_ky_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import MOMENTUM
from arpes._typing import MOMENTUM, KspaceCoords

__all__ = ["ConvertKp", "ConvertKxKy"]

Expand Down Expand Up @@ -150,7 +150,7 @@ def get_coordinates(
self,
resolution: dict[MOMENTUM, float] | None = None,
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
) -> dict[str, NDArray[np.float_]]:
) -> KspaceCoords:
"""Calculates appropriate coordinate bounds.
Args:
Expand Down Expand Up @@ -309,12 +309,10 @@ def get_coordinates(
self,
resolution: dict[MOMENTUM, float] | None = None,
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
) -> dict[str, NDArray[np.float_]]:
) -> KspaceCoords:
"""Calculates appropriate coordinate bounds."""
if resolution is None:
resolution = {}
if bounds is None:
bounds = {}
resolution = resolution if resolution is not None else {}
bounds = bounds if bounds is not None else {}
coordinates = super().get_coordinates(resolution, bounds=bounds)
((kx_low, kx_high), (ky_low, ky_high)) = calculate_kx_ky_bounds(self.arr)
if "kx" in bounds:
Expand Down Expand Up @@ -353,8 +351,8 @@ def get_coordinates(
ky_high + K_SPACE_BORDER,
resolution.get("ky", inferred_ky_res),
)
base_coords = {
str(k): v # should v.values?
base_coords: KspaceCoords = {
str(k): v # should v.values?base
for k, v in self.arr.coords.items()
if k not in ["eV", "phi", "psi", "theta", "beta", "alpha", "chi"]
}
Expand Down
8 changes: 3 additions & 5 deletions src/arpes/utilities/conversion/kz_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import MOMENTUM
from arpes._typing import MOMENTUM, KspaceCoords

__all__ = ["ConvertKpKzV0", "ConvertKxKyKz", "ConvertKpKz"]

Expand Down Expand Up @@ -90,10 +90,8 @@ def get_coordinates(
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
) -> dict[str, NDArray[np.float_]]:
"""Calculates appropriate coordinate bounds."""
if resolution is None:
resolution = {}
if bounds is None:
bounds = {}
resolution = resolution if resolution is not None else {}
bounds = bounds if bounds is not None else {}
coordinates = super().get_coordinates(resolution=resolution, bounds=bounds)
((kp_low, kp_high), (kz_low, kz_high)) = calculate_kp_kz_bounds(self.arr)
if "kp" in bounds:
Expand Down

0 comments on commit 760852f

Please sign in to comment.