Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Sep 29, 2023
1 parent b632560 commit 920f500
Show file tree
Hide file tree
Showing 15 changed files with 81 additions and 50 deletions.
2 changes: 1 addition & 1 deletion arpes/analysis/band_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def param_getter(param_name: ParamType, *, safe: bool = True) -> NDArray[np.floa
if safe:
safe_param = ParamType(value=np.nan, stderr=np.nan)

def getter(x):
def getter(x) -> NDArray[np.float_]:
try:
return x.params.get(param_name, safe_param).value
except:
Expand Down
6 changes: 3 additions & 3 deletions arpes/analysis/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def normalize_by_fermi_distribution(
rigid_shift: float = 0,
instrumental_broadening: float = 0,
total_broadening: float = 0,
):
) -> xr.DataArray:
"""Normalizes a scan by 1/the fermi dirac distribution.
You can control the maximum gain with ``clamp``, and whether
Expand Down Expand Up @@ -107,12 +107,12 @@ def normalize_by_fermi_distribution(

@update_provenance("Symmetrize about axis")
def symmetrize_axis(
data,
data: DataType,
axis_name: str,
flip_axes: list[str] | None = None,
*,
shift_axis: bool = True,
):
) -> xr.DataArray:
"""Symmetrizes data across an axis.
It would be better ultimately to be able
Expand Down
16 changes: 12 additions & 4 deletions arpes/analysis/kfermi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
"""Tools related to finding the Fermi momentum in a cut."""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from arpes._typing import DataType
from arpes.fits import LinearModel

if TYPE_CHECKING:
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = ("kfermi_from_mdcs",)


def kfermi_from_mdcs(mdc_results: DataType, param=None):
def kfermi_from_mdcs(mdc_results: DataType, param: str = "") -> NDArray[np.float_]:
"""Calculates a Fermi momentum using a series of MDCs and the known Fermi level (eV=0).
This is especially useful to isolate an area for analysis.
Expand All @@ -30,13 +38,13 @@ def kfermi_from_mdcs(mdc_results: DataType, param=None):
real_param_name = param
else:
best_names = [p for p in param_names if "center" in p]
if param is not None:
if not param:
best_names = [p for p in best_names if param in p]

assert len(best_names) == 1
real_param_name = best_names[0]

def nan_sieve(_, x):
def nan_sieve(_, x) -> bool:
return not np.isnan(x.item())

return (
Expand Down
15 changes: 10 additions & 5 deletions arpes/analysis/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

if TYPE_CHECKING:
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import DataType

Expand All @@ -20,7 +21,11 @@


@update_provenance("Discretize Path")
def discretize_path(path: xr.Dataset, n_points: int = 0, scaling=None) -> xr.Dataset:
def discretize_path(
path: xr.Dataset,
n_points: int = 0,
scaling: float | xr.Dataset | dict[str, NDArray[np.float_]] | None = None,
) -> xr.Dataset:
"""Discretizes a path into a set of points spaced along the path.
Shares logic with slice_along_path
Expand All @@ -38,7 +43,7 @@ def discretize_path(path: xr.Dataset, n_points: int = 0, scaling=None) -> xr.Dat
if scaling is None:
scaling = 1
elif isinstance(scaling, xr.Dataset):
scaling = {k: scaling[k].item() for k in scaling.data_vars}
scaling = {str(k): scaling[k].item() for k in scaling.data_vars}
else:
assert isinstance(scaling, dict)

Expand Down Expand Up @@ -81,7 +86,7 @@ def distance(a, b):

new_index = np.array(range(len(points)))

def to_dataarray(name):
def to_dataarray(name: str) -> xr.DataArray:
index = order.index(name)
data = [p[index] for p in points]

Expand All @@ -94,11 +99,11 @@ def to_dataarray(name):
def select_along_path(
path: xr.Dataset,
data: DataType,
radius=None,
radius: float = 0,
n_points: int = 0,
*,
fast: bool = True,
scaling=None,
scaling: float | xr.Dataset | dict[str, NDArray[np.float_]] | None = None,
**kwargs: Incomplete,
) -> DataType:
"""Performs integration along a path.
Expand Down
14 changes: 9 additions & 5 deletions arpes/analysis/resolution.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
"""Contains calibrations and information for spectrometer resolution."""
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Any

import numpy as np

from arpes._typing import DataType

# all resolutions are given by (photon energy, entrance slit, exit slit size)
from arpes.constants import K_BOLTZMANN_MEV_KELVIN
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
from arpes._typing import DataType

__all__ = ("total_resolution_estimate",)


# all analyzer dimensions are given in millimeters for convenience as this
# is how slit sizes are typically reported
def r8000(slits):
def r8000(slits) -> dict[str, Any]:
return {
"type": "HEMISPHERE",
"slits": slits,
Expand All @@ -24,9 +28,9 @@ def r8000(slits):


def analyzer_resolution(
analyzer_information,
analyzer_information: dict[str, Any],
slit_width: float | None = None,
slit_number=None,
slit_number: int | None = None,
pass_energy: float = 10,
) -> float:
"""Estimates analyzer resolution from slit dimensioons pass energy, and analyzer radius.
Expand Down
15 changes: 9 additions & 6 deletions arpes/analysis/self_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

if TYPE_CHECKING:
from _typeshed import Incomplete
from numpy.typing import NDArray

__all__ = (
"to_self_energy",
Expand Down Expand Up @@ -144,7 +145,9 @@ def estimate_bare_band(
return xr.DataArray(ys, centers.coords, centers.dims)


def quasiparticle_lifetime(self_energy: xr.DataArray, bare_band: xr.DataArray) -> xr.DataArray:
def quasiparticle_lifetime(
self_energy: xr.DataArray,
) -> NDArray[np.float_]:
"""Calculates the quasiparticle mean free path in meters (meters!).
The bare band is used to calculate the band/Fermi velocity
Expand Down Expand Up @@ -172,7 +175,7 @@ def quasiparticle_mean_free_path(
def to_self_energy(
dispersion: xr.DataArray,
bare_band: BareBandType | None = None,
fermi_velocity: float | None = None,
fermi_velocity: float = 0,
*,
k_independent: bool = True,
) -> xr.Dataset:
Expand All @@ -196,9 +199,9 @@ def to_self_energy(
to the $\gamma$ parameter, which defines the imaginary part of the self energy.
Args:
dispersion
bare_band
fermi_velocity
dispersion ():
bare_band ():
fermi_velocity (float): The fermi velocity. If not set, use local_fermi_velocity
k_independent: bool
Returns:
Expand All @@ -217,7 +220,7 @@ def to_self_energy(
from_mdcs = "eV" in dispersion.dims # if eV is in the dimensions, then we fitted MDCs
estimated_bare_band = estimate_bare_band(dispersion, bare_band)

if fermi_velocity is None:
if not fermi_velocity:
fermi_velocity = local_fermi_velocity(estimated_bare_band)
assert isinstance(fermi_velocity, float)

Expand Down
14 changes: 7 additions & 7 deletions arpes/analysis/shirley.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@update_provenance("Remove Shirley background")
def remove_shirley_background(xps: DataType, **kwargs: float) -> xr.DataArray:
def remove_shirley_background(xps: DataType, **kwargs: float | int) -> xr.DataArray:
"""Calculates and removes a Shirley background from a spectrum.
Only the background corrected spectrum is retrieved.
Expand All @@ -41,9 +41,9 @@ def remove_shirley_background(xps: DataType, **kwargs: float) -> xr.DataArray:

def _calculate_shirley_background_full_range(
xps: NDArray[np.float_],
eps=1e-7,
max_iters=50,
n_samples=5,
eps: float = 1e-7,
max_iters: int = 50,
n_samples: int = 5,
) -> NDArray[np.float_]:
"""Core routine for calculating a Shirley background on np.ndarray data."""
background = np.copy(xps)
Expand Down Expand Up @@ -142,9 +142,9 @@ def calculate_shirley_background_full_range(
def calculate_shirley_background(
xps: DataType,
energy_range: slice | None = None,
eps=1e-7,
max_iters=50,
n_samples=5,
eps: float = 1e-7,
max_iters: int = 50,
n_samples: int = 5,
) -> xr.DataArray:
"""Calculates a shirley background iteratively over the full energy range `energy_range`.
Expand Down
6 changes: 3 additions & 3 deletions arpes/analysis/xps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def local_minima(a: NDArray[np.float_], promenance: int = 3) -> NDArray[np.float
return conditions


def local_maxima(a, promenance=3):
def local_maxima(a: NDArray[np.float_], promenance: int = 3) -> NDArray[np.float_]:
return local_minima(-a, promenance)


Expand All @@ -50,8 +50,8 @@ def approximate_core_levels(
window_size: int = 0,
order: int = 5,
binning: int = 3,
promenance=5,
):
promenance: int = 5,
) -> list[NDArray[np.float_]]:
"""Approximately locates core levels in a spectrum.
Data is first smoothed, and then local maxima with sufficient prominence over
Expand Down
2 changes: 1 addition & 1 deletion arpes/fits/fit_models/backgrounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AffineBackgroundModel(XModelMixin):

def __init__(
self,
independent_vars: list | None = None,
independent_vars: list[str] | None = None,
prefix: str = "",
nan_policy: NAN_POLICY = "raise",
**kwargs: Incomplete,
Expand Down
4 changes: 2 additions & 2 deletions arpes/fits/fit_models/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.set_param_hint("n0", min=0.0)
self.set_param_hint("eps", min=0.0)

def guess(self, data: xr.DataArray, x=None, **kwargs: Incomplete) -> lf.Parameters:
def guess(self, data: xr.DataArray, **kwargs: Incomplete) -> lf.Parameters:
"""Placeholder for parameter estimation."""
pars = self.make_params()

Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
self.set_param_hint("alpha", min=0.0)
self.set_param_hint("vF", min=0.0)

def guess(self, data, x=None, **kwargs: Incomplete) -> lf.Parameters:
def guess(self, data, **kwargs: Incomplete) -> lf.Parameters:
"""Placeholder for actually making parameter estimates here."""
pars = self.make_params()

Expand Down
10 changes: 5 additions & 5 deletions arpes/fits/lmfit_html_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from _typeshed import Incomplete


def repr_multiline_ModelResult(self: model.Model, **kwargs: Incomplete) -> str: # noqa: N802
def repr_multiline_ModelResult(self: model.Model, **kwargs: Incomplete) -> str:
"""Provides a text-based multiline representation used in Qt based interactive tools."""
template = "ModelResult\n Converged: {success}\n "
template += "Components:\n {formatted_components}\n Parameters:\n{parameters}"
Expand All @@ -35,7 +35,7 @@ def repr_multiline_ModelResult(self: model.Model, **kwargs: Incomplete) -> str:
)


def repr_html_ModelResult(self, **kwargs: Incomplete):
def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str:
"""Provides a better Jupyter representation of an `lmfit.ModelResult` instance."""
template = """
<div>
Expand All @@ -51,7 +51,7 @@ def repr_html_ModelResult(self, **kwargs: Incomplete):
)


def repr_html_Model(self):
def repr_html_Model(self: Incomplete) -> str:
"""Better Jupyter representation of `lmfit.Model` instances."""
template = """
<div>
Expand All @@ -61,7 +61,7 @@ def repr_html_Model(self):
return template.format(name=self.name)


def repr_multiline_Model(self, **kwargs: Incomplete):
def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str:
"""Provides a text-based multiline representation used in Qt based interactive tools."""
return self.name

Expand All @@ -70,7 +70,7 @@ def repr_multiline_Model(self, **kwargs: Incomplete):
SKIP_ON_SHORT = {"min", "max", "vary", "expr", "brute_step"}


def repr_html_Parameters(self, *, short: bool = False) -> str:
def repr_html_Parameters(self: Incomplete, *, short: bool = False) -> str:
"""HTML representation for `lmfit.Parameters` instances."""
keys = sorted(self.keys())
template = """
Expand Down
4 changes: 2 additions & 2 deletions arpes/plotting/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def annotate_experimental_conditions(
ax.patch.set_alpha(0)

delta = -1
current = 100
current = 100.0
if orientation == "bottom":
delta = 1
current = 0
Expand Down Expand Up @@ -105,7 +105,7 @@ def render_photon(c: dict[str, float]) -> str:
}

for item in desc:
if isinstance(item, float | int):
if isinstance(item, float):
current += item + delta
continue

Expand Down
7 changes: 4 additions & 3 deletions arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import xarray as xr
from _typeshed import Incomplete
from matplotlib.figure import Figure, FigureBase
from matplotlib.colors import Normalize

from arpes._typing import DataType

Expand All @@ -39,7 +40,7 @@


@save_plot_provenance
def plot_dispersion(spectrum: xr.DataArray, bands, out: str | Path = ""):
def plot_dispersion(spectrum: xr.DataArray, bands, out: str | Path = "") -> Axes | Path:
"""Plots an ARPES cut with bands over it."""
ax = spectrum.plot()

Expand Down Expand Up @@ -454,7 +455,7 @@ def fancy_dispersion(
out: str | Path = "",
*,
include_symmetry_points: bool = True,
norm=None,
norm: Normalize | None = None,
**kwargs: Incomplete,
) -> Axes | Path:
"""Generates a 2D ARPES cut with some fancy annotations for throwing plots together.[TODO:summary].
Expand Down Expand Up @@ -528,7 +529,7 @@ def scan_var_reference_plot(
data: DataType,
title: str = "",
ax: Axes | None = None,
norm=None,
norm: Normalize | None = None,
out: str | Path = "",
**kwargs: Incomplete,
) -> None | Path:
Expand Down
Loading

0 comments on commit 920f500

Please sign in to comment.