From d9c9af7f8fd1539c705329e4883d3755f389fe31 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 9 Feb 2024 20:38:10 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/utilities/conversion/base.py | 14 +++++++++++++- arpes/utilities/conversion/bounds_calculations.py | 6 +++--- arpes/utilities/conversion/core.py | 14 +++++++++++--- arpes/xarray_extensions.py | 12 ++++++------ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/arpes/utilities/conversion/base.py b/arpes/utilities/conversion/base.py index b43de710..706775ab 100644 --- a/arpes/utilities/conversion/base.py +++ b/arpes/utilities/conversion/base.py @@ -34,7 +34,19 @@ K_SPACE_BORDER = 0.02 -MOMENTUM_BREAKPOINTS = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1] +MOMENTUM_BREAKPOINTS: list[float] = [ + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1, +] class CoordinateConverter: diff --git a/arpes/utilities/conversion/bounds_calculations.py b/arpes/utilities/conversion/bounds_calculations.py index 06815550..f987a5b4 100644 --- a/arpes/utilities/conversion/bounds_calculations.py +++ b/arpes/utilities/conversion/bounds_calculations.py @@ -253,7 +253,7 @@ def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]: def calculate_kx_ky_bounds( arr: xr.DataArray, -) -> tuple[tuple[np.float_, np.float_], tuple[np.float_, np.float_]]: +) -> tuple[tuple[float, float], tuple[float, float]]: """Calculates the kx and ky range for a dataset with a fixed photon energy. This is used to infer the gridding that should be used for a k-space conversion. @@ -316,6 +316,6 @@ def calculate_kx_ky_bounds( * np.sin(sampled_beta_values) ) return ( - (round(np.min(kxs), 2), round(np.max(kxs), 2)), - (round(np.min(kys), 2), round(np.max(kys), 2)), + (round(np.min(kxs), 2).astype(float), round(np.max(kxs), 2).astype(float)), + (round(np.min(kys), 2).astype(float), round(np.max(kys), 2).astype(float)), ) diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index a44477c0..9e22612e 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Literal import numpy as np +from numpy.typing import ArrayLike import xarray as xr from scipy.interpolate import RegularGridInterpolator @@ -114,7 +115,7 @@ def slice_along_path( # noqa: PLR0913 *, extend_to_edge: bool = False, shift_gamma: bool = True, -) -> xr.DataArray: +) -> xr.Dataset: """Gets a cut along a path specified by waypoints in an array. TODO: There might be a little bug here where the last coordinate has a value of 0, @@ -215,7 +216,10 @@ def slice_along_path( # noqa: PLR0913 path_segments = list(pairwise(parsed_interpolation_points)) - def required_sampling_density(waypoint_a: Mapping, waypoint_b: Mapping) -> float: + def required_sampling_density( + waypoint_a: Mapping[Hashable, float], + waypoint_b: Mapping[Hashable, float], + ) -> float: ks = waypoint_a.keys() dist = _element_distance(waypoint_a, waypoint_b) delta = np.array([waypoint_a[k] - waypoint_b[k] for k in ks]) @@ -295,6 +299,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ }, as_dataset=True, ) + assert isinstance(converted_ds, xr.Dataset) if ( axis_name in arr.dims and len(parsed_interpolation_points) == 2 # noqa: PLR2004 @@ -697,6 +702,9 @@ def _extract_symmetry_point( return dict(zip([d for d in arr.dims if d in raw_point], S, strict=False)) -def _element_distance(waypoint_a: Mapping, waypoint_b: Mapping) -> np.float_: +def _element_distance( + waypoint_a: Mapping[Hashable, float], + waypoint_b: Mapping[Hashable, float], +) -> np.float_: delta = np.array([waypoint_a[k] - waypoint_b[k] for k in waypoint_a]) return np.linalg.norm(delta) diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index e61cdb53..badc07b8 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -82,7 +82,7 @@ from .utilities.xarray import unwrap_xarray_item if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Callable, Generator, Iterator import lmfit from _typeshed import Incomplete @@ -145,7 +145,7 @@ def _iter_groups( grouped: dict[str, Sequence[float] | float], -) -> Generator[tuple[str, float], None, None]: +) -> Iterator[tuple[str, float]]: """Iterates through a flattened sequence. Sequentially yields keys and values from each sequence associated with a key. @@ -723,17 +723,17 @@ def symmetry_points( return self._calculate_symmetry_points(symmetry_points, **kwargs) @property - def iter_own_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_own_symmetry_points(self) -> Iterator[tuple[str, float]]: sym_points, _ = self.symmetry_points() return _iter_groups(sym_points) @property - def iter_projected_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_projected_symmetry_points(self) -> Iterator[tuple[str, float]]: _, sym_points = self.symmetry_points() return _iter_groups(sym_points) @property - def iter_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_symmetry_points(self) -> Iterator[tuple[str, float]]: yield from self.iter_own_symmetry_points yield from self.iter_projected_symmetry_points @@ -2759,7 +2759,7 @@ def enumerate_iter_coords( def iter_coords( self, dim_names: tuple[str | Hashable, ...] = (), - ) -> Generator[dict[Hashable, float], None, None]: + ) -> Iterator[dict[Hashable, float]]: """[TODO:summary]. Args: