Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 5, 2023
1 parent aa9b9fd commit 9db5896
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 91 deletions.
4 changes: 4 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Original author does not seem to maintain this package. Every bug must be fixed
- Check the functionality

- ConversionKxKy class

- .band_analysis_utils import param_getter, param_stderr_getter
- Check type of the argument set at lf.Mmodel: Is it really lf.Model? lf.ModelResult is better?


- rye for packaging
- tidiy up yaml files
33 changes: 18 additions & 15 deletions arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from collections.abc import Generator

import lmfit as lf
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import DataType
Expand Down Expand Up @@ -130,8 +131,8 @@ def unpack_bands_from_fit(

identified_band_results = copy.deepcopy(band_results)

def as_vector(model_fit: lf.Model, prefix: str = "") -> NDArray[np.float_]:
"""Convert lf.Model to NDArray.
def as_vector(model_fit: lf.ModelResult, prefix: str = "") -> NDArray[np.float_]:
"""Convert lf.ModelResult to NDArray.
Args:
model_fit ([TODO:type]): [TODO:description]
Expand Down Expand Up @@ -251,7 +252,7 @@ def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.Da
@update_provenance("Fit bands from pattern")
def fit_pktterned_bands(
arr: xr.DataArray,
band_set,
band_set: Incomplete,
fit_direction: str = "",
stray: float | None = None,
*,
Expand Down Expand Up @@ -296,14 +297,14 @@ def fit_pktterned_bands(
free_directions.remove(fit_direction)

def resolve_partial_bands_from_description(
coord_dict,
coord_dict: dict[str, Incomplete],
name: str = "",
band=arpes.models.band.Band,
dims: list[str] | tuple[str, ...] | None = None,
params=None,
points=None,
marginal=None,
):
params: Incomplete = None,
points: Incomplete = None,
marginal: Incomplete = None,
) -> list[dict[str, Any]]:
# You don't need to supply a marginal, but it is useful because it allows estimation of the
# initial value for the amplitude from the approximate peak location

Expand Down Expand Up @@ -411,7 +412,7 @@ def _is_between(x: float, y0: float, y1: float) -> bool:
return y0 <= x <= y1


def _instantiate_band(partial_band: dict[str, ...]):
def _instantiate_band(partial_band: dict[str, ...]) -> lf.Model:
phony_band = partial_band["band"](partial_band["name"])
built = phony_band.fit_cls(prefix=partial_band["name"], missing="drop")
for constraint_coord, params in partial_band["params"].items():
Expand All @@ -423,11 +424,11 @@ def _instantiate_band(partial_band: dict[str, ...]):

def fit_bands(
arr: xr.DataArray,
band_description,
band_description: Incomplete,
direction: Literal["edc", "mdc", "EDC", "MDC"] = "mdc",
preferred_k_direction=None,
step=None,
):
preferred_k_direction: str = "",
step: Literal["initial", None] = None,
) -> tuple[xr.DataArray | None, None, lf.ModelResult | None]:
"""Fits bands and determines dispersion in some region of a spectrum.
Args:
Expand All @@ -445,7 +446,9 @@ def fit_bands(

broadcast_direction = "eV"

if direction == "mdc" and preferred_k_direction is None:
if (
direction == "mdc" and not preferred_k_direction
): # TODO: Need to check (Is preferred_k_direction is required?)
possible_directions = set(directions).intersection({"kp", "kx", "ky", "phi"})
broadcast_direction = next(iter(possible_directions))

Expand Down Expand Up @@ -538,7 +541,7 @@ def fit_bands(
unpacked_bands = None
residual = None

return band_results, unpacked_bands, residual
return band_results, unpacked_bands, residual # Memo bunt_result is xr.DataArray


def _interpolate_intersecting_fragments(coord, coord_index, points):
Expand Down
10 changes: 5 additions & 5 deletions arpes/analysis/band_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

if TYPE_CHECKING:
from numpy.typing import NDArray
from collections.abc import Callable


class ParamType(NamedTuple):
Expand All @@ -16,7 +16,7 @@ class ParamType(NamedTuple):
stderr: float


def param_getter(param_name: ParamType, *, safe: bool = True) -> NDArray[np.float_]:
def param_getter(param_name: ParamType, *, safe: bool = True) -> Callabe[..., float]:
"""Constructs a function to extract a parameter value by name.
Useful to extract data from inside an array of `lmfit.ModelResult` instances.
Expand All @@ -34,7 +34,7 @@ def param_getter(param_name: ParamType, *, safe: bool = True) -> NDArray[np.floa
if safe:
safe_param = ParamType(value=np.nan, stderr=np.nan)

def getter(x) -> NDArray[np.float_]:
def getter(x: lf.ModelResult) -> float:
try:
return x.params.get(param_name, safe_param).value
except:
Expand All @@ -45,7 +45,7 @@ def getter(x) -> NDArray[np.float_]:
return lambda x: x.params[param_name].value


def param_stderr_getter(param_name: ParamType, *, safe: bool = True) -> NDArray[np.float_]:
def param_stderr_getter(param_name: ParamType, *, safe: bool = True) -> Callable[..., float]:
"""Constructs a function to extract a parameter value by name.
Useful to extract data from inside an array of `lmfit.ModelResult` instances.
Expand All @@ -63,7 +63,7 @@ def param_stderr_getter(param_name: ParamType, *, safe: bool = True) -> NDArray[
if safe:
safe_param = ParamType(value=np.nan, stderr=np.nan)

def getter(x) -> NDArray[np.float_]:
def getter(x: lf.MdoelResult) -> float:
try:
return x.params.get(param_name, safe_param).stderr
except:
Expand Down
17 changes: 8 additions & 9 deletions arpes/fits/fit_models/fermi_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def guess(
pars["%slorcenter" % self.prefix].set(value=0)
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%sconst_bkg" % self.prefix].set(value=data.min())
# TODO: we can do better than this

pars["%swidth" % self.prefix].set(0.02)
pars["%serf_amp" % self.prefix].set(value=data.mean() - data.min())

Expand Down Expand Up @@ -280,7 +280,7 @@ def guess(
pars["%scenter" % self.prefix].set(value=0)
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%sconst_bkg" % self.prefix].set(value=data.min())
# TODO: we can do better than this

pars["%swidth" % self.prefix].set(0.02)
pars["%serf_amp" % self.prefix].set(value=data.mean() - data.min())

Expand Down Expand Up @@ -355,7 +355,7 @@ def guess(
pars["%soffset" % self.prefix].set(value=data.min())

pars["%scenter" % self.prefix].set(value=0)
# TODO: we can do better than this

pars["%swidth" % self.prefix].set(0.02)

return update_param_vals(pars, self.prefix, **kwargs)
Expand Down Expand Up @@ -415,7 +415,7 @@ def guess(
pars["%soffset" % self.prefix].set(value=data.min())

pars["%scenter" % self.prefix].set(value=0)
# TODO: we can do better than this

pars["%swidth" % self.prefix].set(0.02)

return update_param_vals(pars, self.prefix, **kwargs)
Expand Down Expand Up @@ -484,7 +484,7 @@ def guess(
Args:
data: ARPES data
x (NONE):
x (NDArray[np._float],NONE): as variable "x"
kwargs: [TODO:description]
Returns:
Expand All @@ -505,7 +505,6 @@ def guess(
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%soffset" % self.prefix].set(value=data.min())

# TODO: we can do better than this
pars["%swidth" % self.prefix].set(0.02)

return update_param_vals(pars, self.prefix, **kwargs)
Expand Down Expand Up @@ -644,7 +643,7 @@ def guess(
pars["%scenter" % self.prefix].set(value=0)
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%sconst_bkg" % self.prefix].set(value=data.min())
# TODO: we can do better than this

pars["%ssigma" % self.prefix].set(0.02)
pars["%serf_amp" % self.prefix].set(value=data.mean() - data.min())

Expand Down Expand Up @@ -711,7 +710,7 @@ def guess(
pars["%scenter" % self.prefix].set(value=0)
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%sconst_bkg" % self.prefix].set(value=data.min())
# TODO: we can do better than this

pars["%ssigma" % self.prefix].set(0.02)
pars["%samplitude" % self.prefix].set(value=data.mean() - data.min())

Expand Down Expand Up @@ -787,7 +786,7 @@ def guess(
pars["%sg_center" % self.prefix].set(value=0)
pars["%slin_bkg" % self.prefix].set(value=0)
pars["%sconst_bkg" % self.prefix].set(value=data.min())
# TODO: we can do better than this

pars["%sgamma" % self.prefix].set(0.02)
pars["%st_gamma" % self.prefix].set(0.02)
pars["%ssigma" % self.prefix].set(0.02)
Expand Down
55 changes: 39 additions & 16 deletions arpes/models/band.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
"""Rudimentary band analyis code."""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import scipy.ndimage.filters
import xarray as xr

import arpes.fits
from arpes.analysis.band_analysis_utils import param_getter, param_stderr_getter

if TYPE_CHECKING:
import lmfit as lf
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = [
"Band",
"MultifitBand",
Expand All @@ -19,7 +27,12 @@
class Band:
"""Representation of an ARPES band which supports some calculations after fitting."""

def __init__(self, label: str, display_label: str | None = None, data=None) -> None:
def __init__(
self,
label: str,
display_label: str | None = None,
data: DataType | None = None,
) -> None:
"""Set the data but don't perform any calculation eagerly."""
self.label = label
self._display_label = display_label
Expand Down Expand Up @@ -79,17 +92,23 @@ def self_energy(self):
return

@property
def fit_cls(self):
def fit_cls(self) -> lf.Model:
"""Describes which fit class to use for band fitting, default Lorentzian."""
return arpes.fits.LorentzianModel

def get_dataarray(self, var_name, *, clean: bool = True):
def get_dataarray(
self,
var_name: str,
*,
clean: bool = True,
) -> xr.DataArray | NDArray[np.float_]:
"""Converts the underlying data into an array representation."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
if not clean:
return self._data[var_name].values

output = np.copy(self._data[var_name].values)
output[self._data[var_name + "_stderr"].values > 0.01] = float("nan")
output[self._data[var_name + "_stderr"].values > 0.01] = np.nan

return xr.DataArray(
output,
Expand All @@ -98,46 +117,50 @@ def get_dataarray(self, var_name, *, clean: bool = True):
)

@property
def center(self):
def center(self) -> xr.DataArray:
"""Gets the peak location along the band."""
return self.get_dataarray("center")

@property
def center_stderr(self):
def center_stderr(self) -> xr.DataArray:
"""Gets the peak location stderr along the band."""
return self.get_dataarray("center_stderr", False)
return self.get_dataarray("center_stderr", clean=False)

@property
def sigma(self):
def sigma(self) -> xr.DataArray:
"""Gets the peak width along the band."""
return self.get_dataarray("sigma", True)
return self.get_dataarray("sigma", clean=True)

@property
def amplitude(self):
def amplitude(self) -> xr.DataArray:
"""Gets the peak amplitude along the band."""
return self.get_dataarray("amplitude", True)
return self.get_dataarray("amplitude", clean=True)

@property
def indexes(self):
"""Fetches the indices of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
return self._data.center.indexes

@property
def coords(self):
def coords(self) -> xr.DataArray:
"""Fetches the coordinates of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
return self._data.center.coords

@property
def dims(self):
def dims(self) -> tuple[str, ...]:
"""Fetches the dimensions of the originating data (after fit reduction)."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
return self._data.center.dims


class MultifitBand(Band):
"""Convenience class that reimplements reading data out of a composite fit result."""

def get_dataarray(self, var_name, clean=True):
def get_dataarray(self, var_name: str, *, clean: bool = True):
"""Converts the underlying data into an array representation."""
assert isinstance(self._data, xr.DataArray | xr.Dataset)
full_var_name = self.label + var_name

if "stderr" in full_var_name:
Expand All @@ -150,7 +173,7 @@ class VoigtBand(Band):
"""Uses a Voigt lineshape."""

@property
def fit_cls(self):
def fit_cls(self) -> lf.Model:
"""Fit using `arpes.fits.VoigtModel`."""
return arpes.fits.VoigtModel

Expand All @@ -159,6 +182,6 @@ class BackgroundBand(Band):
"""Uses a Gaussian lineshape."""

@property
def fit_cls(self):
def fit_cls(self) -> lf.Model:
"""Fit using `arpes.fits.GaussianModel`."""
return arpes.fits.GaussianModel
Loading

0 comments on commit 9db5896

Please sign in to comment.