From 3d459d6042cfc52c31c77c796126425fa80cdbcd Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 9 Feb 2024 11:43:31 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/analysis/self_energy.py | 10 ++++++---- arpes/corrections/fermi_edge_corrections.py | 4 +++- arpes/fits/utilities.py | 3 +-- arpes/laue/__init__.py | 4 ++-- arpes/preparation/tof_preparation.py | 12 ++++++------ arpes/workflow.py | 14 +++++++++----- 6 files changed, 27 insertions(+), 20 deletions(-) diff --git a/arpes/analysis/self_energy.py b/arpes/analysis/self_energy.py index f5a53807..4af5817d 100644 --- a/arpes/analysis/self_energy.py +++ b/arpes/analysis/self_energy.py @@ -27,7 +27,10 @@ BareBandType: TypeAlias = xr.DataArray | str | lf.model.ModelResult -def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray: +def get_peak_parameter( + data: xr.DataArray, # values is used + parameter_name: str, +) -> xr.DataArray: """Extracts a parameter from a potentially prefixed peak-like component. Works so long as there is only a single peak defined in the model. @@ -66,8 +69,7 @@ def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray: def local_fermi_velocity(bare_band: xr.DataArray) -> float: """Calculates the band velocity under assumptions of a linear bare band.""" fitted_model = LinearModel().guess_fit(bare_band) - raw_velocity = fitted_model.params["slope"].value - + raw_velocity: float = fitted_model.params["slope"].value if "eV" in bare_band.dims: # the "y" values are in `bare_band` are momenta and the "x" values are energy, therefore # the slope is dy/dx = dk/dE @@ -173,7 +175,7 @@ def quasiparticle_mean_free_path( def to_self_energy( - dispersion: xr.DataArray, + dispersion: xr.Dataset, bare_band: BareBandType | None = None, fermi_velocity: float = 0, *, diff --git a/arpes/corrections/fermi_edge_corrections.py b/arpes/corrections/fermi_edge_corrections.py index 8ab172a8..81dda952 100644 --- a/arpes/corrections/fermi_edge_corrections.py +++ b/arpes/corrections/fermi_edge_corrections.py @@ -18,6 +18,7 @@ from _typeshed import Incomplete + def _exclude_from_set(excluded): def exclude(_): return list(set(_).difference(excluded)) @@ -96,6 +97,7 @@ def apply_direct_fermi_edge_correction( if correction is None: correction = build_direct_fermi_edge_correction(arr, *args, **kwargs) + assert isinstance(correction, xr.Dataset) shift_amount = ( -correction / arr.G.stride(generic_dim_names=False)["eV"] ) # pylint: disable=invalid-unary-operand-type @@ -242,7 +244,7 @@ def apply_photon_energy_fermi_edge_correction( """ if correction is None: correction = build_photon_energy_fermi_edge_correction(arr, **kwargs) - + assert isinstance(correction, xr.Dataset) correction_values = correction.G.map(lambda x: x.params["center"].value) if "corrections" not in arr.attrs: arr.attrs["corrections"] = {} diff --git a/arpes/fits/utilities.py b/arpes/fits/utilities.py index 3d7a2183..178968fe 100644 --- a/arpes/fits/utilities.py +++ b/arpes/fits/utilities.py @@ -138,7 +138,7 @@ def broadcast_model( # noqa: PLR0913 Args: model_cls: The model specification - data: The data to curve fit + data: The data to curve fit (Should be DataArray) broadcast_dims: Which dimensions of the input should be iterated across as opposed to fit across params: Parameter hints, consisting of plain values or arrays for interpolation @@ -242,7 +242,6 @@ def unwrap(result_data: str) -> object: # (Unpickler) template.loc[coords] = np.array(fit_result) residual.loc[coords] = fit_residual - logger.debug("Bundling into dataset") return xr.Dataset( { "results": template, diff --git a/arpes/laue/__init__.py b/arpes/laue/__init__.py index 7c60090d..007b5586 100644 --- a/arpes/laue/__init__.py +++ b/arpes/laue/__init__.py @@ -14,6 +14,7 @@ byte 65823 / 131780 / 3004 = kV byte 131664 / 592 = index file name """ + from __future__ import annotations from pathlib import Path @@ -42,9 +43,8 @@ def load_laue(path: Path | str) -> xr.DataArray: if isinstance(path, str): path = Path(path) - binary_data = path.read_bytes() + binary_data: bytes = path.read_bytes() table, header = binary_data[:131072], binary_data[131072:] - table = np.fromstring(table, dtype=np.uint16).reshape(256, 256) header = np.fromstring(header, dtype=northstar_62_69_dtype).item() diff --git a/arpes/preparation/tof_preparation.py b/arpes/preparation/tof_preparation.py index 643177c6..abfce961 100644 --- a/arpes/preparation/tof_preparation.py +++ b/arpes/preparation/tof_preparation.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable from numpy.typing import NDArray @@ -28,7 +28,7 @@ @update_provenance("Convert ToF data from timing signal to kinetic energy") def convert_to_kinetic_energy( dataarray: xr.DataArray, - kinetic_energy_axis: Sequence[float], + kinetic_energy_axis: NDArray[np.float_], ) -> xr.DataArray: """Convert the ToF timing information into an energy histogram. @@ -112,7 +112,7 @@ def energy_to_time(conv: float, energy: float) -> float: def build_KE_coords_to_time_pixel_coords( dataset: xr.Dataset, - interpolation_axis: Sequence[float], + interpolation_axis: NDArray[np.float_], ) -> Callable[..., tuple[xr.DataArray]]: """Constructs a coordinate conversion function from kinetic energy to time pixels.""" conv = ( @@ -222,12 +222,12 @@ def convert_SToF_to_energy(dataset: xr.Dataset) -> xr.Dataset: spacing = dataset.attrs.get("dE", 0.005) ke_axis = np.linspace(e_min, e_max, int((e_max - e_min) / spacing)) - drs = {k: v for k, v in dataset.data_vars.items() if "time" in v.dims} + drs = {k: spectrum for k, spectrum in dataset.data_vars.items() if "time" in spectrum.dims} - new_dataarrays = [convert_to_kinetic_energy(dr, ke_axis) for dr in drs.values()] + new_dataarrays = [convert_to_kinetic_energy(dr, ke_axis) for dr in drs] for v in new_dataarrays: - dataset[v.name.replace("t_", "")] = v + dataset[str(v.name).replace("t_", "")] = v return dataset diff --git a/arpes/workflow.py b/arpes/workflow.py index 91582fb8..52131604 100644 --- a/arpes/workflow.py +++ b/arpes/workflow.py @@ -33,7 +33,7 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ParamSpec, TypeVar from logging import INFO, Formatter, StreamHandler, getLogger @@ -73,13 +73,17 @@ logger.propagate = False -def with_workspace(f: Callable) -> Callable: +P = ParamSpec("P") +R = TypeVar("R") + + +def with_workspace(f: Callable[P, R]) -> Callable[P, R]: @wraps(f) def wrapped_with_workspace( - *args, + *args: P.args, workspace: str | None = None, - **kwargs: Incomplete, - ): + **kwargs: P.kwargs, + ) -> R: """[TODO:summary]. Args: