From 3b53fcc4165994d75fe92a5be797187ac9e9c3da Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Wed, 7 Feb 2024 10:28:00 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20UPdate=20type=20hints=20and?= =?UTF-8?q?=20be=20slim.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/_typing.py | 8 + arpes/utilities/conversion/core.py | 194 +++++++++++------------- arpes/utilities/conversion/forward.py | 44 ++---- arpes/utilities/conversion/trapezoid.py | 8 +- arpes/xarray_extensions.py | 2 +- 5 files changed, 115 insertions(+), 141 deletions(-) diff --git a/arpes/_typing.py b/arpes/_typing.py index a51a6b0d..1b23db75 100644 --- a/arpes/_typing.py +++ b/arpes/_typing.py @@ -80,6 +80,14 @@ ANGLE = Literal["alpha", "beta", "chi", "theta"] | EMISSION_ANGLE +class KspaceCoords(TypedDict, total=False): + eV: NDArray[np.float_] + kp: NDArray[np.float_] + kx: NDArray[np.float_] + ky: NDArray[np.float_] + kz: NDArray[np.float_] + + class ConfigSettings(TypedDict, total=False): """TypedDict for arpes.config.SETTINGS.""" diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index 429be92c..45513005 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -33,10 +33,12 @@ import numpy as np import xarray as xr from scipy.interpolate import RegularGridInterpolator +from xarray.coding.cftime_offsets import MonthEnd from arpes.provenance import PROVENANCE, provenance, update_provenance -from arpes.trace import Trace, traceable from arpes.utilities import normalize_to_spectrum +from arpes.utilities.conversion.base import CoordinateConverter +from arpes.utilities.conversion.calibration import DetectorCalibration from .fast_interp import Interpolator from .grids import ( @@ -69,14 +71,12 @@ logger.propagate = False -@traceable def grid_interpolator_from_dataarray( arr: xr.DataArray, fill_value: float = 0.0, method: Literal["linear", "nearest", "slinear", "cubic", 
"quintic", "pchip"] = "linear", *, bounds_error: bool = False, - trace: Trace | None = None, ) -> RegularGridInterpolator | Interpolator: """Translates an xarray.DataArray contents into a scipy.interpolate.RegularGridInterpolator. @@ -89,7 +89,6 @@ def grid_interpolator_from_dataarray( if len(c) > 1 and c[1] - c[0] < 0: flip_axes.add(str(d)) values: NDArray[np.float_] = arr.values - trace("Flipping axes") if trace else None for dim in flip_axes: values = np.flip(values, arr.dims.index(dim)) interp_points = [ @@ -98,15 +97,8 @@ def grid_interpolator_from_dataarray( trace_size = [len(pts) for pts in interp_points] if method == "linear": - trace(f"Using fast_interp.Interpolator: size {trace_size}") if trace else None + logger.debug(f"Using fast_interp.Interpolator: size {trace_size}") return Interpolator.from_arrays(interp_points, values) - ( - trace( - f"Calling scipy.interpolate.RegularGridInterpolator: size {trace_size}", - ) - if trace - else None - ) return RegularGridInterpolator( points=interp_points, values=values, @@ -328,26 +320,24 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ @update_provenance("Automatically k-space converted") -@traceable def convert_to_kspace( arr: xr.DataArray, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, - resolution: dict | None = None, - calibration: Incomplete | None = None, - coords: dict[Hashable, NDArray[np.float_]] | None = None, + resolution: dict[MOMENTUM, float] | None = None, + calibration: DetectorCalibration | None = None, + coords: dict[MOMENTUM, NDArray[np.float_]] | None = None, *, allow_chunks: bool = False, - trace: Trace | None = None, **kwargs: NDArray[np.float_], -) -> xr.DataArray | xr.Dataset | None: +) -> xr.DataArray: """Converts volumetric the data to momentum space ("backwards"). Typically what you want. Works in general by regridding the data into the new coordinate space and then interpolating back into the original data. 
For forward conversion, see sibling methods. Forward conversion works by - converting the coordinates, rather than by interpolating the data. As a result, the data will be totally unchanged by the conversion (if we do not apply a Jacobian correction), but the + converting the coordinates, rather than by interpolating the data. As a result, the data will be totally unchanged by the conversion (if we do not apply a Jacobian correction), but the coordinates will no longer have equal spacing. This is only really useful for zero and one dimensional data because for two dimensional data, @@ -377,15 +367,13 @@ def convert_to_kspace( Args: arr (xr.DataArray): ARPES data - bounds (dict[MOMENTUM, tuple[float, float]] | None): The key is the axis name. - The value is the bounds. - Defaults to {}. - resolution ([type]): [description]. Defaults to None. - calibration ([type], optional): [description]. Defaults to None. - coords (dict[str, Iterable[float], optional): Coordinate of k-space. Defaults to {}. + bounds (dict[MOMENTUM, tuple[float, float]], optional): + The key is the axis name. The value is the bounds. Defaults to {}. + If this argument is not set, set coords instead. + resolution (dict[Momentum, float], optional): dict for the energy/angular resolution. + calibration (DetectorCalibration, optional): DetectorCalibration object. Defaults to None. + coords (dict[Momentum, Iterable[float]], optional): Coordinate of k-space. Defaults to {}. allow_chunks (bool): [description]. Defaults to False. - trace (Callable, optional): Controls whether to use execution tracing. Defaults to None. - Pass `True` to enable. **kwargs: treated as coords. Raises: @@ -400,8 +388,7 @@ def convert_to_kspace( coords = {} if bounds is None: bounds = {} - coords.update(kwargs) - trace("Normalizing to spectrum") if trace else None + coords.update(**kwargs) if isinstance(arr, xr.Dataset): msg = "Remember to use a DataArray not a Dataset, " msg += "attempting to extract spectrum and copy attributes." 
@@ -425,10 +412,8 @@ def convert_to_kspace( resolution=resolution, calibration=calibration, coords=coords, - trace=trace, **kwargs, ) - trace("Determining dimensions and resolution") if trace else None momentum_incompatibles: list[str] = [ str(d) for d in arr.dims if not is_dimension_convertible_to_mementum(str(d)) ] @@ -442,7 +427,6 @@ def convert_to_kspace( momentum_compatibles.sort() - trace("Replacing dummy coordinates with index-like ones.") if trace else None # temporarily reassign coordinates for dimensions we will not # convert to "index-like" dimensions restore_index_like_coordinates: dict[str, NDArray[np.float_]] = { @@ -469,37 +453,31 @@ def convert_to_kspace( # ('chi', 'phi',): ConvertKxKy, ("hv", "phi"): ConvertKpKz, }.get(tuple(momentum_compatibles)) - if convert_cls: - converter = convert_cls(arr, converted_dims, calibration=calibration) - trace("Converting coordinates") if trace else None - converted_coordinates: dict[Hashable, NDArray[np.float_]] = converter.get_coordinates( - resolution=resolution, - bounds=bounds, - ) - if not set(coords.keys()).issubset(converted_coordinates.keys()): - extra = set(coords.keys()).difference(converted_coordinates.keys()) - msg = f"Unexpected passed coordinates: {extra}" - raise ValueError(msg) - converted_coordinates.update(coords) - trace("Calling convert_coordinates") if trace else None - trace(f"converted_dims{converted_dims}") if trace else None - result = convert_coordinates( - arr, - converted_coordinates, - { - "dims": converted_dims, - "transforms": dict( - zip(arr.dims, [converter.conversion_for(dim) for dim in arr.dims], strict=True), - ), - }, - trace=trace, - ) - trace("Reassigning index-like coordinates.") if trace else None - result = result.assign_coords(restore_index_like_coordinates) - trace("Finished.") if trace else None - return result - msg = "Cannot select convert class" - raise RuntimeError(msg) + assert convert_cls is not None, "Cannot select convert class" + + converter = convert_cls(arr, 
converted_dims, calibration=calibration) + + converted_coordinates: dict[Hashable, NDArray[np.float_]] = converter.get_coordinates( + resolution=resolution, + bounds=bounds, + ) + if not set(coords.keys()).issubset(converted_coordinates.keys()): + extra = set(coords.keys()).difference(converted_coordinates.keys()) + msg = f"Unexpected passed coordinates: {extra}" + raise ValueError(msg) + converted_coordinates.update(coords) + result = convert_coordinates( + arr, + converted_coordinates, + { + "dims": converted_dims, + "transforms": dict( + zip(arr.dims, [converter.conversion_for(dim) for dim in arr.dims], strict=True), + ), + }, + ) + assert isinstance(result, xr.DataArray) + return result.assign_coords(restore_index_like_coordinates) class TRANSFORMCOORDS(TypedDict, total=False): @@ -507,58 +485,49 @@ class TRANSFORMCOORDS(TypedDict, total=False): transforms: NDArray[np.float_] -@traceable def convert_coordinates( arr: xr.DataArray, - target_coordinates: dict[Hashable, NDArray[np.float_] | xr.DataArray], + target_coordinates: dict[Hashable, NDArray[np.float_]], coordinate_transform: dict[Hashable, NDArray[np.float_]], *, as_dataset: bool = False, - trace: Trace | None = None, ) -> xr.DataArray | xr.Dataset: """Return Band structure data (converted to k-space). Args: arr(xr.DataArray): ARPES data - target_coordinates:(dict[Hashable, NDArray[np.float_] | xr.DataArray]): coorrdinate for ... + target_coordinates:(dict[Hashable, NDArray[np.float_]]): coordinate for ... coordinate_transform(dict[Hashable, list[str] | Callable]): coordinat for ... as_dataset(bool): if True, return the data as the dataSet - trace(Callable): if True, trace command is activated. 
Returns: xr.DataArray | xr.Dataset """ assert isinstance(arr, xr.DataArray) ordered_source_dimensions = arr.dims - trace("Instantiating grid interpolator.") if trace else None + grid_interpolator = grid_interpolator_from_dataarray( arr.transpose(*ordered_source_dimensions), fill_value=float("nan"), - trace=trace, ) - trace("Finished instantiating grid interpolator.") if trace else None # Skip the Jacobian correction for now # Convert the raw coordinate axes to a set of gridded points - if trace: - trace(f"meshgrid: {[len(target_coordinates[_]) for _ in coordinate_transform['dims']]}") + logger.debug(f"meshgrid: {[len(target_coordinates[_]) for _ in coordinate_transform['dims']]}") meshed_coordinates = np.meshgrid( *[target_coordinates[dim] for dim in coordinate_transform["dims"]], indexing="ij", ) - trace("Raveling coordinates") if trace else None meshed_coordinates = [meshed_coord.ravel() for meshed_coord in meshed_coordinates] if "eV" not in arr.dims: with contextlib.suppress(ValueError): meshed_coordinates = [arr.S.lookup_offset_coord("eV"), *meshed_coordinates] - old_coord_names = [dim for dim in arr.dims if dim not in target_coordinates] old_coordinate_transforms = [ coordinate_transform["transforms"][dim] for dim in arr.dims if dim not in target_coordinates ] - trace("Calling coordinate transforms") if trace else None output_shape = [len(target_coordinates[d]) for d in coordinate_transform["dims"]] def compute_coordinate(transform: Callable[..., NDArray[np.float_]]) -> NDArray[np.float_]: @@ -568,23 +537,11 @@ def compute_coordinate(transform: Callable[..., NDArray[np.float_]]) -> NDArray[ order="C", ) - old_dimensions = [] - for tr in old_coordinate_transforms: - trace(f"Running transform {tr}") if trace else None - old_dimensions.append(compute_coordinate(tr)) - - trace("Done running transforms.") if trace else None + old_dimensions = [compute_coordinate(tr) for tr in old_coordinate_transforms] ordered_transformations = 
[coordinate_transform["transforms"][dim] for dim in arr.dims] - trace("Calling grid interpolator") if trace else None + transformed_coordinates = [tr(*meshed_coordinates) for tr in ordered_transformations] - trace("Pulling back coordinates") if trace else None - transformed_coordinates = [] - for tr in ordered_transformations: - trace(f"Running transform {tr}") if trace else None - transformed_coordinates.append(tr(*meshed_coordinates)) - - trace("Calling grid interpolator") if trace else None if not isinstance(grid_interpolator, Interpolator): converted_volume = grid_interpolator(np.array(transformed_coordinates).T) else: @@ -592,15 +549,23 @@ def compute_coordinate(transform: Callable[..., NDArray[np.float_]]) -> NDArray[ # Wrap it all up def acceptable_coordinate(c: NDArray[np.float_] | xr.DataArray) -> bool: - # Currently we do this to filter out coordinates - # that are functions of the old angular dimensions, - # we could forward convert these, but right now we do not + """Return whether a coordinate can be kept on the converted data. 
+ + Currently we do this to filter out coordinates + that are functions of the old angular dimensions, + we could forward convert these, but right now we do not + + Args: + c: Candidate coordinate, either a plain array or an xr.DataArray. + + Returns: + True if all dims of c are in coordinate_transform["dims"], or if c has no dims attribute. + """ try: return bool(set(c.dims).issubset(coordinate_transform["dims"])) except AttributeError: return True - trace("Bundling into DataArray") if trace else None target_coordinates = {k: v for k, v in target_coordinates.items() if acceptable_coordinate(v)} data = xr.DataArray( np.reshape( @@ -618,35 +583,45 @@ def acceptable_coordinate(c: NDArray[np.float_] | xr.DataArray) -> bool: ] if as_dataset: variables = {"data": data} - variables.update(dict(zip(old_coord_names, old_mapped_coords, strict=True))) + variables.update( + dict( + zip( + old_coord_names, + old_mapped_coords, + strict=True, + ), + ), + ) return xr.Dataset(variables, attrs=arr.attrs) - trace("Finished: convert_coordinates") if trace else None return data -@traceable def _chunk_convert( arr: xr.DataArray, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, - resolution: dict | None = None, - calibration: Incomplete | None = None, - coords: dict[Hashable, NDArray[np.float_]] | None = None, - *, - trace: Trace | None, + resolution: dict[MOMENTUM, float] | None = None, + calibration: DetectorCalibration | None = None, + coords: dict[MOMENTUM, NDArray[np.float_]] | None = None, **kwargs: NDArray[np.float_], ) -> xr.DataArray: DESIRED_CHUNK_SIZE = 1000 * 1000 * 20 TOO_LARGE_CHUNK_SIZE = 100 n_chunks: np.int_ = np.prod(arr.shape) // DESIRED_CHUNK_SIZE if n_chunks == 0: - warnings.warn("Data size is sufficiently small, set allow_chunks=False", stacklevel=2) + warnings.warn( + "Data size is sufficiently small, set allow_chunks=False", + stacklevel=2, + ) n_chunks += 1 if n_chunks > TOO_LARGE_CHUNK_SIZE: - warnings.warn("Input array is very large. Please consider resampling.", stacklevel=2) + warnings.warn( + "Input array is very large. 
Please consider resampling.", + stacklevel=2, + ) chunk_thickness = np.max(len(arr.eV) // n_chunks, 1) - trace(f"Chunking along energy: {n_chunks}, thickness {chunk_thickness}") if trace else None + logger.debug(f"Chunking along energy: {n_chunks}, thickness {chunk_thickness}") finished = [] low_idx = 0 high_idx = chunk_thickness @@ -661,18 +636,23 @@ def _chunk_convert( calibration=calibration, coords=coords, allow_chunks=False, - trace=trace, **kwargs, ) if "eV" not in kchunk.dims: kchunk = kchunk.expand_dims("eV") + assert isinstance(kchunk, xr.DataArray) finished.append(kchunk) low_idx = high_idx high_idx = min(len(arr.eV), high_idx + chunk_thickness) return xr.concat(finished, dim="eV") -def _extract_symmetry_point(name: str, arr: xr.DataArray, *, extend_to_edge: bool = False) -> dict: +def _extract_symmetry_point( + name: str, + arr: xr.DataArray, + *, + extend_to_edge: bool = False, +) -> dict: """[TODO:summary]. Args: diff --git a/arpes/utilities/conversion/forward.py b/arpes/utilities/conversion/forward.py index 1e245285..4b5111c6 100644 --- a/arpes/utilities/conversion/forward.py +++ b/arpes/utilities/conversion/forward.py @@ -14,7 +14,7 @@ import warnings from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Unpack import numpy as np import xarray as xr @@ -37,7 +37,7 @@ from numpy.typing import NDArray - from arpes._typing import DataType + from arpes._typing import DataType, KspaceCoords __all__ = ( "convert_coordinates_to_kspace_forward", @@ -61,13 +61,10 @@ logger.propagate = False -@traceable def convert_coordinate_forward( data: DataType, coords: dict[str, float], - *, - trace: Trace | None = None, - **k_coords: NDArray[np.float_], + **k_coords: Unpack[KspaceCoords], ) -> dict[str, float]: """Inverse/forward transform for the small angle volumetric k-conversion code. 
@@ -100,7 +97,6 @@ def convert_coordinate_forward( Args: data (DataType): The data defining the coordinate offsets and experiment geometry. coords (dict[str, float]): The coordinates of a *point* in angle-space to be converted. - trace: Used for performance tracing and debugging. k_coords: Coordinate for k-axis Returns: @@ -125,29 +121,19 @@ def convert_coordinate_forward( "ky": np.linspace(-4, 4, 300), } # Copying after taking a constant energy plane is much much cheaper - trace("Copying") if trace else None data_arr = data_arr.copy(deep=True) data_arr.loc[data_arr.G.round_coordinates(coords)] = data_arr.values.max() * 100000 - trace("Filtering") if trace else None data_arr = gaussian_filter_arr(data_arr, default_size=3) - - trace("Converting once") if trace else None - kdata = convert_to_kspace(data_arr, trace=trace, **k_coords) - - trace("argmax") if trace else None + kdata = convert_to_kspace(data_arr, **k_coords) near_target = kdata.G.argmax_coords() - - trace("Converting twice") if trace else None kdata_close = convert_to_kspace( data_arr, - trace=trace, **{k: np.linspace(v - 0.08, v + 0.08, 100) for k, v in near_target.items()}, ) # inconsistently, the energy coordinate is sometimes returned here # so we remove it just in case - trace("argmax") if trace else None coords = kdata_close.G.argmax_coords() if "eV" in coords: del coords["eV"] @@ -202,8 +188,8 @@ def convert_through_angular_pair( # noqa: PLR0913 Returns: The momentum cut passing first through `first_point` and then through `second_point`. 
""" - k_first_point = convert_coordinate_forward(data, first_point, trace=trace, **k_coords) - k_second_point = convert_coordinate_forward(data, second_point, trace=trace, **k_coords) + k_first_point = convert_coordinate_forward(data, first_point, **k_coords) + k_second_point = convert_coordinate_forward(data, second_point, **k_coords) k_dims = set(k_first_point.keys()) if k_dims != {"kx", "ky"}: @@ -218,13 +204,13 @@ def convert_through_angular_pair( # noqa: PLR0913 k_second_point["ky"] - k_first_point["ky"], k_second_point["kx"] - k_first_point["kx"], ) - trace(f"Determined offset angle {-offset_ang}") if trace else None + logger.debug(f"Determined offset angle {-offset_ang}") with data.S.with_rotation_offset(-offset_ang): - trace("Finding first momentum coordinate.") if trace else None - k_first_point = convert_coordinate_forward(data, first_point, trace=trace, **k_coords) - trace("Finding second momentum coordinate.") if trace else None - k_second_point = convert_coordinate_forward(data, second_point, trace=trace, **k_coords) + logger.debug("Finding first momentum coordinate.") + k_first_point = convert_coordinate_forward(data, first_point, **k_coords) + logger.debug("Finding second momentum coordinate.") + k_second_point = convert_coordinate_forward(data, second_point, **k_coords) # adjust output coordinate ranges transverse_specification = { @@ -249,7 +235,6 @@ def convert_through_angular_pair( # noqa: PLR0913 data, **transverse_specification, kx=parallel_axis, - trace=trace, ).mean(list(transverse_specification.keys())) trace("Annotating the requested point momentum values.") if trace else None @@ -295,7 +280,11 @@ def convert_through_angular_point( # noqa: PLR0913 Returns: A momentum cut passing through the point `coords`. 
""" - k_coords = convert_coordinate_forward(data, coords, trace=trace, **k_coords) + k_coords = convert_coordinate_forward( + data, + coords, + **k_coords, + ) all_momentum_dims = set(k_coords.keys()) assert all_momentum_dims == set(cut_specification.keys()).union(transverse_specification.keys()) @@ -309,7 +298,6 @@ def convert_through_angular_point( # noqa: PLR0913 data, **transverse_specification, **cut_specification, - trace=trace, ).mean(list(transverse_specification.keys()), keep_attrs=True) for k, v in k_coords.items(): diff --git a/arpes/utilities/conversion/trapezoid.py b/arpes/utilities/conversion/trapezoid.py index d8e1e97e..cc0026ed 100644 --- a/arpes/utilities/conversion/trapezoid.py +++ b/arpes/utilities/conversion/trapezoid.py @@ -127,7 +127,7 @@ def __init__( right_phi_one_volt, ) - def get_coordinates(self, *args: Incomplete, **kwargs: Incomplete) -> Indexes: + def get_coordinates(self, *args: Incomplete, **kwargs: Incomplete) -> Indexes: # TODO: rename ! if args: logger.debug("ConvertTrapezoidalCorrection.get_coordinates: args is not used but set.") if kwargs: @@ -138,7 +138,7 @@ def get_coordinates(self, *args: Incomplete, **kwargs: Incomplete) -> Indexes: return self.arr.indexes - def conversion_for(self, dim: str) -> Callable: + def conversion_for(self, dim: str) -> Callable[..., NDArray[np.float_]]: def with_identity(*args: Incomplete) -> NDArray[np.float_]: return self.identity_transform(dim, *args) @@ -207,8 +207,6 @@ def apply_trapezoidal_correction( Returns: The corrected data. 
""" - trace("Normalizing to spectrum") if trace else None - if isinstance(data, dict): warnings.warn( "Treating dict-like data as an attempt to forward convert a single coordinate.", @@ -261,7 +259,7 @@ def apply_trapezoidal_correction( }, trace=trace, ) - + assert isinstance(result, xr.DataArray) trace("Reassigning index-like coordinates.") if trace else None result = result.assign_coords(restore_index_like_coordinates) result = result.assign_coords( diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 843cb8de..a8809ea1 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -2294,7 +2294,7 @@ def round_coordinates( return rounded - def argmax_coords(self) -> dict: + def argmax_coords(self) -> dict[Hashable, float]: """Return dict representing the position for maximum value.""" assert isinstance(self._obj, xr.DataArray) data: xr.DataArray = self._obj