💬 Update type hints
arafune committed Feb 9, 2024
1 parent 8ebd89d commit 3d459d6
Showing 6 changed files with 27 additions and 20 deletions.
10 changes: 6 additions & 4 deletions arpes/analysis/self_energy.py

@@ -27,7 +27,10 @@
 BareBandType: TypeAlias = xr.DataArray | str | lf.model.ModelResult
 
 
-def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray:
+def get_peak_parameter(
+    data: xr.DataArray,  # values is used
+    parameter_name: str,
+) -> xr.DataArray:
     """Extracts a parameter from a potentially prefixed peak-like component.
 
     Works so long as there is only a single peak defined in the model.
@@ -66,8 +69,7 @@ def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray:
 def local_fermi_velocity(bare_band: xr.DataArray) -> float:
     """Calculates the band velocity under assumptions of a linear bare band."""
     fitted_model = LinearModel().guess_fit(bare_band)
-    raw_velocity = fitted_model.params["slope"].value
-
+    raw_velocity: float = fitted_model.params["slope"].value
     if "eV" in bare_band.dims:
         # the "y" values are in `bare_band` are momenta and the "x" values are energy, therefore
         # the slope is dy/dx = dk/dE
@@ -173,7 +175,7 @@ def quasiparticle_mean_free_path(
 
 
 def to_self_energy(
-    dispersion: xr.DataArray,
+    dispersion: xr.Dataset,
     bare_band: BareBandType | None = None,
     fermi_velocity: float = 0,
     *,
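An aside on the `local_fermi_velocity` hunk above: because the fit there treats energy as the x axis and momentum as the y axis, the fitted slope is dk/dE, and the band velocity is its reciprocal. A minimal sketch of that inversion on synthetic data, ignoring the unit conversions the real function performs:

```python
import numpy as np
from lmfit.models import LinearModel

# Synthetic linear bare band: momentum k as a function of energy E.
energy = np.linspace(-0.5, 0.0, 50)   # eV
momentum = 0.80 + 0.25 * energy       # slope dk/dE = 0.25

fit = LinearModel().fit(momentum, x=energy)
dk_dE = fit.params["slope"].value
band_velocity = 1 / dk_dE  # dE/dk; recovers 4.0 for this noiseless example
```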
4 changes: 3 additions & 1 deletion arpes/corrections/fermi_edge_corrections.py

@@ -18,6 +18,7 @@
     from _typeshed import Incomplete
 
 
+
 def _exclude_from_set(excluded):
     def exclude(_):
         return list(set(_).difference(excluded))
@@ -96,6 +97,7 @@ def apply_direct_fermi_edge_correction(
     if correction is None:
         correction = build_direct_fermi_edge_correction(arr, *args, **kwargs)
 
+    assert isinstance(correction, xr.Dataset)
     shift_amount = (
         -correction / arr.G.stride(generic_dim_names=False)["eV"]
     )  # pylint: disable=invalid-unary-operand-type
@@ -242,7 +244,7 @@ def apply_photon_energy_fermi_edge_correction(
     """
     if correction is None:
         correction = build_photon_energy_fermi_edge_correction(arr, **kwargs)
-
+    assert isinstance(correction, xr.Dataset)
     correction_values = correction.G.map(lambda x: x.params["center"].value)
     if "corrections" not in arr.attrs:
         arr.attrs["corrections"] = {}
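The two `assert isinstance(correction, xr.Dataset)` additions are the usual way to narrow a value whose inferred type is broader than what the following lines need; static checkers such as mypy and pyright treat the assert as a type guard. A minimal sketch of the pattern (the `ensure_dataset` helper is hypothetical, standing in for the correction builders):

```python
import xarray as xr

def ensure_dataset(correction: xr.Dataset | None = None) -> xr.Dataset:
    """Hypothetical helper mirroring the narrowing pattern in this commit."""
    if correction is None:
        correction = xr.Dataset()  # stand-in for build_direct_fermi_edge_correction(...)
    assert isinstance(correction, xr.Dataset)  # narrows the type for checkers and readers
    return correction
```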
3 changes: 1 addition & 2 deletions arpes/fits/utilities.py

@@ -138,7 +138,7 @@ def broadcast_model(  # noqa: PLR0913
 
     Args:
         model_cls: The model specification
-        data: The data to curve fit
+        data: The data to curve fit (Should be DataArray)
         broadcast_dims: Which dimensions of the input should be iterated across as opposed
             to fit across
         params: Parameter hints, consisting of plain values or arrays for interpolation
@@ -242,7 +242,6 @@ def unwrap(result_data: str) -> object:  # (Unpickler)
             template.loc[coords] = np.array(fit_result)
             residual.loc[coords] = fit_residual
 
-    logger.debug("Bundling into dataset")
     return xr.Dataset(
         {
             "results": template,
4 changes: 2 additions & 2 deletions arpes/laue/__init__.py

@@ -14,6 +14,7 @@
 byte 65823 / 131780 / 3004 = kV
 byte 131664 / 592 = index file name
 """
+
 from __future__ import annotations
 
 from pathlib import Path
@@ -42,9 +43,8 @@ def load_laue(path: Path | str) -> xr.DataArray:
     if isinstance(path, str):
         path = Path(path)
 
-    binary_data = path.read_bytes()
+    binary_data: bytes = path.read_bytes()
     table, header = binary_data[:131072], binary_data[131072:]
-
     table = np.fromstring(table, dtype=np.uint16).reshape(256, 256)
     header = np.fromstring(header, dtype=northstar_62_69_dtype).item()
 
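One note on the unchanged `np.fromstring` context lines: for binary input that function has long been deprecated in NumPy, and `np.frombuffer` is the drop-in replacement. A small sketch of the same split-and-decode pattern on synthetic bytes (sizes mirror the 256 x 256 uint16 table; the header dtype here is a hypothetical stand-in for `northstar_62_69_dtype`):

```python
import numpy as np

# Synthetic stand-in for the Laue file layout: 131072 bytes of uint16 table
# followed by a trailing header blob.
header_dtype = np.dtype([("index_file", "S592")])  # hypothetical header layout
blob = bytes(131072) + bytes(header_dtype.itemsize)

table_bytes, header_bytes = blob[:131072], blob[131072:]
table = np.frombuffer(table_bytes, dtype=np.uint16).reshape(256, 256)
header = np.frombuffer(header_bytes, dtype=header_dtype).item()
```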
12 changes: 6 additions & 6 deletions arpes/preparation/tof_preparation.py

@@ -13,7 +13,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Callable
 
     from numpy.typing import NDArray
 
@@ -28,7 +28,7 @@
 @update_provenance("Convert ToF data from timing signal to kinetic energy")
 def convert_to_kinetic_energy(
     dataarray: xr.DataArray,
-    kinetic_energy_axis: Sequence[float],
+    kinetic_energy_axis: NDArray[np.float_],
 ) -> xr.DataArray:
     """Convert the ToF timing information into an energy histogram.
 
@@ -112,7 +112,7 @@ def energy_to_time(conv: float, energy: float) -> float:
 
 
 def build_KE_coords_to_time_pixel_coords(
     dataset: xr.Dataset,
-    interpolation_axis: Sequence[float],
+    interpolation_axis: NDArray[np.float_],
 ) -> Callable[..., tuple[xr.DataArray]]:
     """Constructs a coordinate conversion function from kinetic energy to time pixels."""
@@ -222,12 +222,12 @@ def convert_SToF_to_energy(dataset: xr.Dataset) -> xr.Dataset:
     spacing = dataset.attrs.get("dE", 0.005)
     ke_axis = np.linspace(e_min, e_max, int((e_max - e_min) / spacing))
 
-    drs = {k: v for k, v in dataset.data_vars.items() if "time" in v.dims}
+    drs = {k: spectrum for k, spectrum in dataset.data_vars.items() if "time" in spectrum.dims}
 
-    new_dataarrays = [convert_to_kinetic_energy(dr, ke_axis) for dr in drs.values()]
+    new_dataarrays = [convert_to_kinetic_energy(dr, ke_axis) for dr in drs]
 
     for v in new_dataarrays:
-        dataset[v.name.replace("t_", "")] = v
+        dataset[str(v.name).replace("t_", "")] = v
 
     return dataset
 
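A subtlety worth flagging in the last hunk: iterating a dict directly (`for dr in drs`) yields its keys, whereas the removed `drs.values()` yielded the DataArrays themselves, so the two loop forms are not interchangeable. A quick plain-Python illustration:

```python
drs = {"t_up": [1, 2], "t_down": [3, 4]}  # stand-in for the data_vars dict

print(list(drs))           # ['t_up', 't_down'] -- the keys (variable names)
print(list(drs.values()))  # [[1, 2], [3, 4]]   -- the values themselves
```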
14 changes: 9 additions & 5 deletions arpes/workflow.py

@@ -33,7 +33,7 @@
 from functools import wraps
 from pathlib import Path
 from pprint import pprint
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, ParamSpec, TypeVar
 
 from logging import INFO, Formatter, StreamHandler, getLogger
 
@@ -73,13 +73,17 @@
 logger.propagate = False
 
 
-def with_workspace(f: Callable) -> Callable:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def with_workspace(f: Callable[P, R]) -> Callable[P, R]:
     @wraps(f)
     def wrapped_with_workspace(
-        *args,
+        *args: P.args,
         workspace: str | None = None,
-        **kwargs: Incomplete,
-    ):
+        **kwargs: P.kwargs,
+    ) -> R:
         """[TODO:summary].
 
         Args:
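The `ParamSpec`/`TypeVar` pair introduced here is the standard recipe (in `typing` since Python 3.10) for annotating a decorator so that the wrapped function keeps its parameter and return types. A self-contained sketch of the pattern, using a hypothetical `logged` decorator:

```python
from functools import wraps
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def logged(f: Callable[P, R]) -> Callable[P, R]:
    """Hypothetical decorator: checkers see the wrapped function's true signature."""
    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {f.__name__}")
        return f(*args, **kwargs)
    return wrapper

@logged
def add(x: int, y: int) -> int:
    return x + y

add(1, 2)  # type-checks as (int, int) -> int
```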
