Skip to content

Commit

Permalink
🔨 add and revise type hints
Browse files Browse the repository at this point in the history
🔨  remove all.py
💬  add type of kwargs for matplotlib (plot, pcolormesh)
  • Loading branch information
arafune committed Sep 26, 2023
1 parent afcee0c commit d17f435
Show file tree
Hide file tree
Showing 19 changed files with 242 additions and 143 deletions.
128 changes: 125 additions & 3 deletions arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import uuid
from typing import TYPE_CHECKING, Literal, Required, TypedDict, TypeVar
from typing import TYPE_CHECKING, Literal, Required, TypeAlias, TypedDict, TypeVar

import xarray as xr

Expand All @@ -20,11 +20,28 @@
##

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from pathlib import Path

import numpy as np
from _typeshed import Incomplete
from numpy.typing import NDArray
from matplotlib.artist import Artist
from matplotlib.backend_bases import Event
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.transforms import BboxBase, Transform
from matplotlib.typing import (
CapStyleType,
ColorType,
DrawStyleType,
FillStyleType,
JoinStyleType,
LineStyleType,
MarkerType,
MarkEveryType,
)
from numpy.typing import ArrayLike, NDArray

__all__ = [
"DataType",
Expand All @@ -41,7 +58,7 @@
]

DataType = TypeVar("DataType", xr.DataArray, xr.Dataset)
NormalizableDataType = DataType | str | uuid.UUID
NormalizableDataType: TypeAlias = DataType | str | uuid.UUID

xr_types = (xr.DataArray, xr.Dataset)

Expand Down Expand Up @@ -241,3 +258,108 @@ class SPECTROMETER(ANALYZERINFO, COORDINATES, total=False):

class ARPESAttrs(TypedDict, total=False):
pass


class MPLPlotKwargs(TypedDict, total=False):
scalex: bool
scaley: bool

agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
alpha: float | None
animated: bool
antialiased: bool
aa: bool
clip_box: BboxBase | None
clip_on: bool
# clip_path: Path | None color: ColorType
c: ColorType
dash_capstyle: CapStyleType
dash_joinstyle: JoinStyleType
dashes: Sequence[float | None]
data: NDArray[np.float_]
drawstyle: DrawStyleType
figure: Figure
fillstyle: FillStyleType
gapcolor: ColorType | None
gid: str
in_layout: bool
label: str
linestyle: LineStyleType
ls: LineStyleType
linewidth: float
lw: float
marker: MarkerType
markeredgecolor: ColorType
mec: ColorType
markeredgewidth: float
mew: float
markerfacecolor: ColorType
mfc: ColorType
markerfacecoloralt: ColorType
mfcalt: ColorType
markersize: float
ms: float
markevery: MarkEveryType
mouseover: bool
path_effects: list[AbstractPathEffect]
picker: float | Callable[[Artist, Event], tuple[bool, dict]]
pickradius: float
rasterized: bool
sketch_params: tuple[float, float, float]
scale: float
length: float
randomness: float
snap: bool | None
solid_capstyle: CapStyleType
solid_joinstyle: JoinStyleType
url: str
visible: bool
xdata: NDArray[np.float_]
ydata: NDArray[np.float_]
zorder: float


class PColorMeshKwargs(TypedDict, total=False):
agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
alpha: float | None
animated: bool
antialiased: bool
aa: bool
array: ArrayLike

capstyle: CapStyleType

clim: tuple[float, float]
clip_box: BboxBase | None
clip_on: bool
cmap: Colormap | str | None
color: ColorType
edgecolor: ColorType
ec: ColorType
facecolor: ColorType
facecolors: ColorType
fc: ColorType
figure: Figure
gid: str
hatch: Literal["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"]
in_layout: bool
joinstyle: JoinStyleType
label: str
linestyle: LineStyleType
linewidth: float | list[float]
linewidths: float | list[float]
lw: float | list[float]
mouseover: bool
offsets: NDArray[np.float_]
path_effects: list[AbstractPathEffect]
picker: None | bool | float
rasterized: bool
sketch_params: tuple[float, float, float]
scale: float
randomness: float
snap: bool | None
transform: Transform
url: str
urls: list[str | None]
visible: bool
zorder: float
11 changes: 0 additions & 11 deletions arpes/all.py

This file was deleted.

14 changes: 0 additions & 14 deletions arpes/analysis/all.py

This file was deleted.

4 changes: 2 additions & 2 deletions arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl
assert isinstance(fit_kwargs, dict)
data_array = normalize_to_spectrum(data)
mom_dim = next(
d for d in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if d in data_array.dims
dim for dim in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if dim in data_array.dims
)

results = broadcast_model(
Expand All @@ -61,7 +61,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl
)
if mom_dim in {"phi", "beta", "theta"}:
forward = convert_coordinates_to_kspace_forward(data_array)
final_mom = next(d for d in ["kx", "ky", "kp", "kz"] if d in forward)
final_mom = next(dim for dim in ["kx", "ky", "kp", "kz"] if dim in forward)
eVs = results.F.p("a_center").values
kps = [
forward[final_mom].sel(eV=eV, **dict([[mom_dim, ang]]), method="nearest")
Expand Down
15 changes: 8 additions & 7 deletions arpes/analysis/self_energy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Contains self-energy analysis routines."""
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal, TypeAlias

import lmfit as lf
import numpy as np
Expand All @@ -22,8 +22,8 @@
)


BareBandType = xr.DataArray | str | lf.model.ModelResult
DispersionType = xr.DataArray | xr.Dataset
BareBandType: TypeAlias = xr.DataArray | str | lf.model.ModelResult
DispersionType: TypeAlias = xr.DataArray | xr.Dataset


def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray:
Expand Down Expand Up @@ -62,7 +62,7 @@ def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray:
)


def local_fermi_velocity(bare_band: xr.DataArray):
def local_fermi_velocity(bare_band: xr.DataArray) -> float:
"""Calculates the band velocity under assumptions of a linear bare band."""
fitted_model = LinearModel().guess_fit(bare_band)
raw_velocity = fitted_model.params["slope"].value
Expand Down Expand Up @@ -172,9 +172,9 @@ def quasiparticle_mean_free_path(
def to_self_energy(
dispersion: xr.DataArray,
bare_band: BareBandType | None = None,
fermi_velocity: float | None = None,
*,
k_independent: bool = True,
fermi_velocity=None,
) -> xr.Dataset:
r"""Converts MDC fit results into the self energy.
Expand All @@ -198,8 +198,8 @@ def to_self_energy(
Args:
dispersion
bare_band
k_independent
fermi_velocity
k_independent: bool
Returns:
The equivalent self energy from the bare band and the measured dispersion.
Expand All @@ -219,6 +219,7 @@ def to_self_energy(

if fermi_velocity is None:
fermi_velocity = local_fermi_velocity(estimated_bare_band)
assert isinstance(fermi_velocity, float)

imaginary_part = get_peak_parameter(dispersion, "fwhm") / 2
centers = get_peak_parameter(dispersion, "center")
Expand All @@ -243,7 +244,7 @@ def to_self_energy(

def fit_for_self_energy(
data: xr.DataArray,
method="mdc",
method: Literal["mdc", "edc"] = "mdc",
bare_band: BareBandType | None = None,
**kwargs: Incomplete,
) -> xr.Dataset:
Expand Down
5 changes: 3 additions & 2 deletions arpes/endstations/fits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import warnings
from ast import literal_eval
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeAlias

import numpy as np
from numpy import ndarray
from numpy._typing import NDArray

from arpes.trace import traceable
from arpes.utilities.funcutils import collect_leaves, iter_leaves
Expand All @@ -33,7 +34,7 @@
"Z": "z",
}

CoordsDict = dict[str, ndarray]
CoordsDict: TypeAlias = dict[str, NDArray[np.float_]]
Dimension = str


Expand Down
1 change: 1 addition & 0 deletions arpes/endstations/plugin/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

if TYPE_CHECKING:
from _typeshed import Incomplete

from arpes.endstations import SCANDESC
__all__ = ("FallbackEndstation",)

Expand Down
5 changes: 4 additions & 1 deletion arpes/fits/zones.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@


def k_points_residual(
paramters, coords_dataset, high_symmetry_points, dimensionality: int = 2
paramters,
coords_dataset,
high_symmetry_points,
dimensionality: int = 2,
) -> NDArray[np.float_]:
momentum_coordinates = convert_coordinates(coords_dataset)
if dimensionality == 2:
Expand Down
26 changes: 0 additions & 26 deletions arpes/plotting/all.py

This file was deleted.

Loading

0 comments on commit d17f435

Please sign in to comment.