Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 9, 2024
1 parent 31d0cb9 commit d9c9af7
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
14 changes: 13 additions & 1 deletion arpes/utilities/conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,19 @@


K_SPACE_BORDER = 0.02
MOMENTUM_BREAKPOINTS = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]
MOMENTUM_BREAKPOINTS: list[float] = [
0.0005,
0.001,
0.002,
0.005,
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1,
]


class CoordinateConverter:
Expand Down
6 changes: 3 additions & 3 deletions arpes/utilities/conversion/bounds_calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]:

def calculate_kx_ky_bounds(
arr: xr.DataArray,
) -> tuple[tuple[np.float_, np.float_], tuple[np.float_, np.float_]]:
) -> tuple[tuple[float, float], tuple[float, float]]:
"""Calculates the kx and ky range for a dataset with a fixed photon energy.
This is used to infer the gridding that should be used for a k-space conversion.
Expand Down Expand Up @@ -316,6 +316,6 @@ def calculate_kx_ky_bounds(
* np.sin(sampled_beta_values)
)
return (
(round(np.min(kxs), 2), round(np.max(kxs), 2)),
(round(np.min(kys), 2), round(np.max(kys), 2)),
(round(np.min(kxs), 2).astype(float), round(np.max(kxs), 2).astype(float)),
(round(np.min(kys), 2).astype(float), round(np.max(kys), 2).astype(float)),
)
14 changes: 11 additions & 3 deletions arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
from numpy.typing import ArrayLike
import xarray as xr
from scipy.interpolate import RegularGridInterpolator

Expand Down Expand Up @@ -114,7 +115,7 @@ def slice_along_path( # noqa: PLR0913
*,
extend_to_edge: bool = False,
shift_gamma: bool = True,
) -> xr.DataArray:
) -> xr.Dataset:
"""Gets a cut along a path specified by waypoints in an array.
TODO: There might be a little bug here where the last coordinate has a value of 0,
Expand Down Expand Up @@ -215,7 +216,10 @@ def slice_along_path( # noqa: PLR0913

path_segments = list(pairwise(parsed_interpolation_points))

def required_sampling_density(waypoint_a: Mapping, waypoint_b: Mapping) -> float:
def required_sampling_density(
waypoint_a: Mapping[Hashable, float],
waypoint_b: Mapping[Hashable, float],
) -> float:
ks = waypoint_a.keys()
dist = _element_distance(waypoint_a, waypoint_b)
delta = np.array([waypoint_a[k] - waypoint_b[k] for k in ks])
Expand Down Expand Up @@ -295,6 +299,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[
},
as_dataset=True,
)
assert isinstance(converted_ds, xr.Dataset)

if (
axis_name in arr.dims and len(parsed_interpolation_points) == 2 # noqa: PLR2004
Expand Down Expand Up @@ -697,6 +702,9 @@ def _extract_symmetry_point(
return dict(zip([d for d in arr.dims if d in raw_point], S, strict=False))


def _element_distance(waypoint_a: Mapping, waypoint_b: Mapping) -> np.float_:
def _element_distance(
waypoint_a: Mapping[Hashable, float],
waypoint_b: Mapping[Hashable, float],
) -> np.float_:
delta = np.array([waypoint_a[k] - waypoint_b[k] for k in waypoint_a])
return np.linalg.norm(delta)
12 changes: 6 additions & 6 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
from .utilities.xarray import unwrap_xarray_item

if TYPE_CHECKING:
from collections.abc import Callable, Generator
from collections.abc import Callable, Generator, Iterator

import lmfit
from _typeshed import Incomplete
Expand Down Expand Up @@ -145,7 +145,7 @@

def _iter_groups(
grouped: dict[str, Sequence[float] | float],
) -> Generator[tuple[str, float], None, None]:
) -> Iterator[tuple[str, float]]:
"""Iterates through a flattened sequence.
Sequentially yields keys and values from each sequence associated with a key.
Expand Down Expand Up @@ -723,17 +723,17 @@ def symmetry_points(
return self._calculate_symmetry_points(symmetry_points, **kwargs)

@property
def iter_own_symmetry_points(self) -> Generator[tuple[str, float], None, None]:
def iter_own_symmetry_points(self) -> Iterator[tuple[str, float]]:
sym_points, _ = self.symmetry_points()
return _iter_groups(sym_points)

@property
def iter_projected_symmetry_points(self) -> Generator[tuple[str, float], None, None]:
def iter_projected_symmetry_points(self) -> Iterator[tuple[str, float]]:
_, sym_points = self.symmetry_points()
return _iter_groups(sym_points)

@property
def iter_symmetry_points(self) -> Generator[tuple[str, float], None, None]:
def iter_symmetry_points(self) -> Iterator[tuple[str, float]]:
yield from self.iter_own_symmetry_points
yield from self.iter_projected_symmetry_points

Expand Down Expand Up @@ -2759,7 +2759,7 @@ def enumerate_iter_coords(
def iter_coords(
self,
dim_names: tuple[str | Hashable, ...] = (),
) -> Generator[dict[Hashable, float], None, None]:
) -> Iterator[dict[Hashable, float]]:
"""[TODO:summary].
Args:
Expand Down

0 comments on commit d9c9af7

Please sign in to comment.