diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index 2378c963..dc20e3b2 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -36,7 +36,6 @@ from arpes.provenance import PROVENANCE, provenance, update_provenance from arpes.utilities import normalize_to_spectrum -from arpes.utilities.conversion.calibration import DetectorCalibration from .fast_interp import Interpolator from .grids import ( @@ -51,6 +50,7 @@ from numpy.typing import NDArray from arpes._typing import MOMENTUM + from arpes.utilities.conversion.calibration import DetectorCalibration __all__ = ["convert_to_kspace", "slice_along_path"] @@ -477,15 +477,10 @@ def convert_to_kspace( return result.assign_coords(restore_index_like_coordinates) -class TRANSFORMCOORDS(TypedDict, total=False): - dims: list[str] - transforms: NDArray[np.float_] - - def convert_coordinates( arr: xr.DataArray, target_coordinates: dict[Hashable, NDArray[np.float_]], - coordinate_transform: dict[Hashable, NDArray[np.float_]], + coordinate_transform: dict[Hashable, list[str] | dict[Hashable, NDArray[np.float_]]], *, as_dataset: bool = False, ) -> xr.DataArray | xr.Dataset: @@ -521,8 +516,10 @@ def convert_coordinates( with contextlib.suppress(ValueError): meshed_coordinates = [arr.S.lookup_offset_coord("eV"), *meshed_coordinates] old_coord_names = [dim for dim in arr.dims if dim not in target_coordinates] + assert isinstance(coordinate_transform["transforms"], dict) + transforms: dict[Hashable, NDArray[np.float_]] = coordinate_transform["transforms"] old_coordinate_transforms = [ - coordinate_transform["transforms"][dim] for dim in arr.dims if dim not in target_coordinates + transforms[dim] for dim in arr.dims if dim not in target_coordinates ] output_shape = [len(target_coordinates[d]) for d in coordinate_transform["dims"]]