diff --git a/src/arpes/utilities/conversion/base.py b/src/arpes/utilities/conversion/base.py index 706775ab..07605342 100644 --- a/src/arpes/utilities/conversion/base.py +++ b/src/arpes/utilities/conversion/base.py @@ -14,7 +14,7 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import MOMENTUM + from arpes._typing import MOMENTUM, KspaceCoords from .calibration import DetectorCalibration @@ -150,7 +150,7 @@ def get_coordinates( self, resolution: dict[MOMENTUM, float] | None = None, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, - ) -> dict[str, NDArray[np.float_]]: + ) -> KspaceCoords: """Calculates the coordinates which should be used in momentum space. Args: @@ -159,12 +159,10 @@ def get_coordinates( bounds(dict, optional): bounds of the momentum coordinates Returns: - dict[str, NDArray]: the key represents the axis name suchas "kp", "kx", and "eV". + KspaceCoords: the key represents the axis name suchas "kp", "kx", and "eV". """ - if resolution is None: - resolution = {} - if bounds is None: - bounds = {} - coordinates: dict[str, NDArray[np.float_]] = {} + resolution = resolution if resolution is not None else {} + bounds = bounds if bounds is not None else {} + coordinates: KspaceCoords = {} coordinates["eV"] = self.arr.coords["eV"].values return coordinates diff --git a/src/arpes/utilities/conversion/core.py b/src/arpes/utilities/conversion/core.py index ac96bd32..830463e3 100644 --- a/src/arpes/utilities/conversion/core.py +++ b/src/arpes/utilities/conversion/core.py @@ -28,7 +28,7 @@ from collections.abc import Hashable from itertools import pairwise from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import numpy as np import xarray as xr @@ -52,7 +52,7 @@ from numpy.typing import NDArray - from arpes._typing import MOMENTUM, XrTypes + from arpes._typing import MOMENTUM, KspaceCoords, XrTypes from arpes.utilities.conversion.calibration import DetectorCalibration __all__ = ["convert_to_kspace", "slice_along_path"] @@ -284,9 +284,9 @@ def convert_to_kspace( # noqa: PLR0913 bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict[MOMENTUM, float] | None = None, calibration: DetectorCalibration | None = None, - coords: dict[MOMENTUM, NDArray[np.float_]] | None = None, + coords: KspaceCoords | None = None, allow_chunks: bool = False, - **kwargs: NDArray[np.float_], + **kwargs: Unpack[KspaceCoords], ) -> xr.DataArray: """Converts volumetric the data to momentum space ("backwards"). Typically what you want. @@ -342,7 +342,7 @@ def convert_to_kspace( # noqa: PLR0913 xr.DataArray: [description] """ coords = coords if coords else {} - coords.update(**kwargs) + coords.update(kwargs) assert isinstance(coords, dict) bounds = bounds if bounds else {} arr = arr if isinstance(arr, xr.DataArray) else normalize_to_spectrum(arr) @@ -434,7 +434,7 @@ class CoordinateTransform(TypedDict, total=True): def convert_coordinates( arr: xr.DataArray, - target_coordinates: dict[str, NDArray[np.float_]], + target_coordinates: KspaceCoords, coordinate_transform: CoordinateTransform, *, as_dataset: bool = False, @@ -558,8 +558,8 @@ def _chunk_convert( bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict[MOMENTUM, float] | None = None, calibration: DetectorCalibration | None = None, - coords: dict[MOMENTUM, NDArray[np.float_]] | None = None, - **kwargs: NDArray[np.float_], + coords: KspaceCoords | None = None, + **kwargs: Unpack[KspaceCoords], ) -> xr.DataArray: DESIRED_CHUNK_SIZE = 1000 * 1000 * 20 TOO_LARGE_CHUNK_SIZE = 100 diff --git a/src/arpes/utilities/conversion/forward.py b/src/arpes/utilities/conversion/forward.py index 5ad06c41..c7a83bac 100644 --- a/src/arpes/utilities/conversion/forward.py +++ b/src/arpes/utilities/conversion/forward.py @@ -14,7 +14,7 @@ import warnings from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Hashable, TypeVar, Unpack +from typing import TYPE_CHECKING, TypeVar, Unpack import numpy as np import xarray as xr diff --git a/src/arpes/utilities/conversion/kx_ky_conversion.py b/src/arpes/utilities/conversion/kx_ky_conversion.py index 209bf169..565c6e94 100644 --- a/src/arpes/utilities/conversion/kx_ky_conversion.py +++ b/src/arpes/utilities/conversion/kx_ky_conversion.py @@ -25,7 +25,7 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import MOMENTUM + from arpes._typing import MOMENTUM, KspaceCoords __all__ = ["ConvertKp", "ConvertKxKy"] @@ -150,7 +150,7 @@ def get_coordinates( self, resolution: dict[MOMENTUM, float] | None = None, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, - ) -> dict[str, NDArray[np.float_]]: + ) -> KspaceCoords: """Calculates appropriate coordinate bounds. Args: @@ -309,12 +309,10 @@ def get_coordinates( self, resolution: dict[MOMENTUM, float] | None = None, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, - ) -> dict[str, NDArray[np.float_]]: + ) -> KspaceCoords: """Calculates appropriate coordinate bounds.""" - if resolution is None: - resolution = {} - if bounds is None: - bounds = {} + resolution = resolution if resolution is not None else {} + bounds = bounds if bounds is not None else {} coordinates = super().get_coordinates(resolution, bounds=bounds) ((kx_low, kx_high), (ky_low, ky_high)) = calculate_kx_ky_bounds(self.arr) if "kx" in bounds: @@ -353,8 +351,8 @@ def get_coordinates( ky_high + K_SPACE_BORDER, resolution.get("ky", inferred_ky_res), ) - base_coords = { - str(k): v # should v.values? + base_coords: KspaceCoords = { + str(k): v # should v.values?base for k, v in self.arr.coords.items() if k not in ["eV", "phi", "psi", "theta", "beta", "alpha", "chi"] } diff --git a/src/arpes/utilities/conversion/kz_conversion.py b/src/arpes/utilities/conversion/kz_conversion.py index c41f4af4..a34d6320 100644 --- a/src/arpes/utilities/conversion/kz_conversion.py +++ b/src/arpes/utilities/conversion/kz_conversion.py @@ -20,7 +20,7 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import MOMENTUM + from arpes._typing import MOMENTUM, KspaceCoords __all__ = ["ConvertKpKzV0", "ConvertKxKyKz", "ConvertKpKz"] @@ -90,10 +90,8 @@ def get_coordinates( bounds: dict[MOMENTUM, tuple[float, float]] | None = None, ) -> dict[str, NDArray[np.float_]]: """Calculates appropriate coordinate bounds.""" - if resolution is None: - resolution = {} - if bounds is None: - bounds = {} + resolution = resolution if resolution is not None else {} + bounds = bounds if bounds is not None else {} coordinates = super().get_coordinates(resolution=resolution, bounds=bounds) ((kp_low, kp_high), (kz_low, kz_high)) = calculate_kp_kz_bounds(self.arr) if "kp" in bounds: