Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Apr 19, 2024
1 parent 5d9f289 commit 223f7a8
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 26 deletions.
5 changes: 2 additions & 3 deletions src/arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ def fit_for_effective_mass(
We should probably include uncertainties here.
Args:
data (DataType): ARPES data
data (xr.DataArray): ARPES data
fit_kwargs: Passthrough for arguments to `broadcast_model`, used internally to
obtain the Lorentzian peak locations
Returns:
The effective mass in units of the bare mass.
"""
if fit_kwargs is None:
fit_kwargs = {}
fit_kwargs = fit_kwargs if fit_kwargs is not None else {}

mom_dim = next(
dim for dim in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if dim in data.dims
Expand Down
37 changes: 27 additions & 10 deletions src/arpes/utilities/conversion/bounds_calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,14 @@ def calculate_kp_kz_bounds(arr: xr.DataArray) -> tuple[tuple[float, float], tupl


def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]:
"""Calculates kp bounds for a single ARPES cut."""
"""Calculates kp bounds for a single ARPES cut.
Args:
arr (xr.DataArray): ARPES 'cut'-type (the number of the anglar axis is 1 ("phi")) data
Returns (tuple[float, float]):
Minimum and maximum value of K region from the ARPES data
"""
phi_coords = arr.coords["phi"].values - arr.S.phi_offset
beta = float(arr.coords["beta"]) - arr.S.beta_offset

Expand All @@ -233,18 +240,18 @@ def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]:

if arr.S.energy_notation == "Binding":
max_kinetic_energy = max(
arr.coords["eV"].values.max(),
arr.coords["eV"].max().item(),
arr.S.hv - arr.S.analyzer_work_function,
)
elif arr.S.energy_notation == "Kinetic":
max_kinetic_energy = max(arr.coords["eV"].values.max(), 0 - arr.S.analyzer_work_function)
max_kinetic_energy = arr.coords["eV"].max().item()
else:
warnings.warn(
"Energyi notation is not specified. Assume the Binding energy notatation",
stacklevel=2,
)
max_kinetic_energy = max(
arr.coords["eV"].values.max(),
arr.coords["eV"].max().item(),
arr.S.hv - arr.S.analyzer_work_function,
)
kps = K_INV_ANGSTROM * np.sqrt(max_kinetic_energy) * np.sin(sampled_phi_values) * np.cos(beta)
Expand Down Expand Up @@ -272,12 +279,22 @@ def calculate_kx_ky_bounds(
)
# Sample hopefully representatively along the edges
phi_low, phi_high = np.min(phi_coords), np.max(phi_coords)
beta_low, beta_high = np.min(beta_coords), np.max(beta_coords)
phi_mid = (phi_high + phi_low) / 2
beta_mid = (beta_high + beta_low) / 2
sampled_phi_values = np.array(
[phi_high, phi_high, phi_mid, phi_low, phi_low, phi_low, phi_mid, phi_high, phi_high],
[
phi_high,
phi_high,
phi_mid,
phi_low,
phi_low,
phi_low,
phi_mid,
phi_high,
phi_high,
],
)
beta_low, beta_high = np.min(beta_coords), np.max(beta_coords)
beta_mid = (beta_high + beta_low) / 2
sampled_beta_values = np.array(
[
beta_mid,
Expand All @@ -293,18 +310,18 @@ def calculate_kx_ky_bounds(
)
if arr.S.energy_notation == "Binding":
kinetic_energy = max(
arr.coords["eV"].values.max(),
arr.coords["eV"].max().item(),
arr.S.hv - arr.S.analyzer_work_function,
)
elif arr.S.energy_notation == "Kinetic":
kinetic_energy = max(arr.coords["eV"].values.max(), -arr.S.analyzer_work_function)
kinetic_energy = arr.coords["eV"].max().item()
else:
warnings.warn(
"Energy notation is not specified. Assume the Binding energy notation",
stacklevel=2,
)
kinetic_energy = max(
arr.coords["eV"].values.max(),
arr.coords["eV"].max().item(),
arr.S.hv - arr.S.analyzer_work_function,
)
# note that the type of the kinetic_energy is float in below.
Expand Down
9 changes: 7 additions & 2 deletions src/arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def convert_coordinates(
ordered_source_dimensions = arr.dims

grid_interpolator = grid_interpolator_from_dataarray(
arr.transpose(*ordered_source_dimensions),
arr.transpose(*ordered_source_dimensions), # TODO(RA): No need? -- perhaps no.
fill_value=np.nan,
)

Expand Down Expand Up @@ -532,7 +532,12 @@ def acceptable_coordinate(c: NDArray[np.float_] | xr.DataArray) -> bool:
attrs=arr.attrs,
)
old_mapped_coords = [
xr.DataArray(values, target_coordinates, coordinate_transform["dims"], attrs=arr.attrs)
xr.DataArray(
values,
target_coordinates,
coordinate_transform["dims"],
attrs=arr.attrs,
)
for values in old_dimensions
]
if as_dataset:
Expand Down
10 changes: 5 additions & 5 deletions src/arpes/utilities/conversion/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,9 @@ def convert_coordinates_to_kspace_forward(arr: XrTypes) -> xr.Dataset:
arr = arr.copy(deep=True)

skip = {"eV", "cycle", "delay", "T"}
keep = {
"eV",
}
keep = {"eV"}
all_indexes = {k: v for k, v in arr.indexes.items() if k not in skip}
kept = {k: v for k, v in arr.indexes.items() if k in keep}
kept = {k: v for k, v in arr.indexes.items() if k in keep} # TODO (RA): v has not been used.
momentum_compatibles: list[str] = list(all_indexes.keys())
momentum_compatibles.sort()
if not momentum_compatibles:
Expand All @@ -425,7 +423,9 @@ def convert_coordinates_to_kspace_forward(arr: XrTypes) -> xr.Dataset:
("hv", "phi", "psi"): ["kx", "ky", "kz"],
("chi", "hv", "phi"): ["kx", "ky", "kz"],
}.get(tuple(momentum_compatibles), [])
full_old_dims: list[str] = momentum_compatibles + list(kept.keys())
full_old_dims: list[str] = momentum_compatibles + list(
kept.keys(),
) # TODO (RA): list(kept.keys()) can (should) be replaced with ["eV"]
projection_vectors: NDArray[np.float_] = np.ndarray(
shape=tuple(len(arr.coords[d]) for d in full_old_dims),
dtype=object,
Expand Down
16 changes: 10 additions & 6 deletions src/arpes/utilities/conversion/kx_ky_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _compute_ktot(
k_tot: NDArray[np.float_],
) -> None:
for i in numba.prange(len(binding_energy)):
k_tot[i] = K_INV_ANGSTROM * np.sqrt(hv - work_function + binding_energy[i])
k_tot[i] = K_INV_ANGSTROM * np.sqrt(
hv - work_function + binding_energy[i],
)


def _safe_compute_k_tot(
Expand Down Expand Up @@ -160,10 +162,9 @@ def get_coordinates(
Returns:
dict[str, NDArray]: the key represents the axis name suchas "kp", "kx", and "eV".
"""
if resolution is None:
resolution = {}
if bounds is None:
bounds = {}
resolution = resolution if resolution is not None else {}
bounds = bounds if bounds is not None else {}

coordinates = super().get_coordinates(resolution, bounds=bounds)
(kp_low, kp_high) = calculate_kp_bounds(self.arr)
if "kp" in bounds:
Expand Down Expand Up @@ -394,7 +395,10 @@ def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]:
"theta": self.kspace_to_perp_angle,
"psi": self.kspace_to_perp_angle,
"beta": self.kspace_to_perp_angle,
}.get(dim, _with_identity)
}.get(
dim,
_with_identity,
)

@property
def needs_rotation(self) -> bool:
Expand Down

0 comments on commit 223f7a8

Please sign in to comment.