Skip to content

Commit

Permalink
💬 Update type hints: Introduce XrTypes as xr.DataArray | xr.Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 8, 2024
1 parent 033f089 commit 8ebd89d
Show file tree
Hide file tree
Showing 29 changed files with 319 additions and 256 deletions.
106 changes: 71 additions & 35 deletions arpes/_typing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Specialized type annotations for use in PyARPES.
In particular, we frequently allow using the `DataType` annotation,
which refers to either an xarray.DataArray or xarray.Dataset.
which refers to either an xarray.DataArray|xarray.Dataset.
Additionally, we often use `NormalizableDataType` which
means essentially anything that can be turned into a dataset,
for instance by loading from the cache using an ID, or which is
for instance by loading from the cache using an ID,|which is
literally already data.
"""

Expand Down Expand Up @@ -49,17 +49,22 @@
MarkEveryType,
)
from numpy.typing import ArrayLike, NDArray
from PySide6.QtCore.Qt import Orientation, WindowType
from PySide6 import QtCore
from PySide6.QtGui import QIcon, QPixmap
from PySide6.QtWidgets import (
QWidget,
)

DataType = TypeVar("DataType", xr.DataArray, xr.Dataset)
NormalizableDataType: TypeAlias = DataType | str | uuid.UUID

XrTypes: TypeAlias = xr.DataArray | xr.Dataset


__all__ = [
"DataType",
"NormalizableDataType",
"xr_types",
"XrTypes",
"SPECTROMETER",
"MOMENTUM",
"EMISSION_ANGLE",
Expand All @@ -69,16 +74,39 @@
"ANALYZERINFO",
]

DataType = TypeVar("DataType", xr.DataArray, xr.Dataset)
NormalizableDataType: TypeAlias = DataType | str | uuid.UUID

xr_types = (xr.DataArray, xr.Dataset)


MOMENTUM = Literal["kp", "kx", "ky", "kz"]
EMISSION_ANGLE = Literal["phi", "psi"]
ANGLE = Literal["alpha", "beta", "chi", "theta"] | EMISSION_ANGLE

LegendLocation = (
Literal[
"best",
0,
"upper right",
1,
"upper left",
2,
"lower left",
3,
"lower right",
4,
"right",
5,
"center left",
6,
"center right",
7,
"lower center",
8,
"upper center",
9,
"center",
10,
]
| tuple[float, float]
)


class KspaceCoords(TypedDict, total=False):
eV: NDArray[np.float_]
Expand Down Expand Up @@ -110,8 +138,8 @@ class CURRENTCONTEXT(TypedDict, total=False):
map_data: Incomplete
selector: Incomplete
integration_region: dict[Incomplete, Incomplete]
original_data: xr.DataArray | xr.Dataset
data: xr.DataArray | xr.Dataset
original_data: XrTypes
data: XrTypes
widgets: list[mpl.widgets.AxesWidget]
points: list[Incomplete]
rect_next: bool
Expand Down Expand Up @@ -316,13 +344,13 @@ class ARPESAttrs(SPECTROMETER, LIGHTSOURCEINFO, SAMPLEINFO, total=False):


class QSliderARGS(TypedDict, total=False):
orientation: Orientation
orientation: QtCore.Qt.Orientation
parent: QWidget | None


class QWidgetARGS(TypedDict, total=False):
parent: QWidget | None
f: WindowType
f: QtCore.Qt.WindowType


class QPushButtonARGS(TypedDict, total=False):
Expand All @@ -334,34 +362,50 @@ class QPushButtonARGS(TypedDict, total=False):
#
# TypedDict for plotting
#
class MPLPlotKwargs(TypedDict, total=False):
scalex: bool
scaley: bool


class MPLPlotKwagsBasic(TypedDict, total=False):
"""Kwargs for Axes.plot & Axes.fill_between."""

agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]]
alpha: float | None
animated: bool
antialiased: bool
aa: bool
antialiased: bool | list[bool]
aa: bool | list[bool]
clip_box: BboxBase | None
clip_on: bool
# clip_path: Path | None color: ColorType
color: ColorType
c: ColorType
dash_capstyle: CapStyleType
dash_joinstyle: JoinStyleType
dashes: Sequence[float | None]
data: NDArray[np.float_]
drawstyle: DrawStyleType
figure: Figure
fillstyle: FillStyleType
gapcolor: ColorType | None
gid: str
in_layout: bool
label: str
linestyle: LineStyleType
ls: LineStyleType
linestyle: LineStyleType
linewidth: float
lw: float
mouseover: bool
path_effects: list[AbstractPathEffect]
pickradius: float
rasterized: bool
sketch_params: tuple[float, float, float]
snap: bool | None
transform: Transform
url: str
visible: bool


class MPLPlotKwargs(MPLPlotKwagsBasic, total=False):
scalex: bool
scaley: bool

dash_capstyle: CapStyleType
dash_joinstyle: JoinStyleType
dashes: Sequence[float | None]
data: NDArray[np.float_]
drawstyle: DrawStyleType
fillstyle: FillStyleType
gapcolor: ColorType | None
marker: MarkerType
markeredgecolor: ColorType
mec: ColorType
Expand All @@ -374,20 +418,12 @@ class MPLPlotKwargs(TypedDict, total=False):
markersize: float
ms: float
markevery: MarkEveryType
mouseover: bool
path_effects: list[AbstractPathEffect]
picker: float | Callable[[Artist, Event], tuple[bool, dict]]
pickradius: float
rasterized: bool
sketch_params: tuple[float, float, float]
scale: float
length: float
randomness: float
snap: bool | None
solid_capstyle: CapStyleType
solid_joinstyle: JoinStyleType
url: str
visible: bool
xdata: NDArray[np.float_]
ydata: NDArray[np.float_]
zorder: float
Expand Down
6 changes: 4 additions & 2 deletions arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import XrTypes

__all__ = (
"fit_bands",
"fit_for_effective_mass",
)


def fit_for_effective_mass(
data: xr.Dataset | xr.DataArray,
data: XrTypes,
fit_kwargs: dict | None = None,
) -> float:
"""Fits for the effective mass in a piece of data.
Expand Down Expand Up @@ -256,7 +258,7 @@ def fit_patterned_bands(
background: bool = True,
interactive: bool = True,
dataset: bool = True,
) -> xr.DataArray | xr.Dataset:
) -> XrTypes:
"""Fits bands and determines dispersion in some region of a spectrum.
The dimensions of the dataset are partitioned into three types:
Expand Down
1 change: 0 additions & 1 deletion arpes/analysis/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def rebin(
Returns:
The rebinned data.
"""
assert isinstance(data, xr.DataArray | xr.Dataset)
if bin_width is None:
bin_width = {}
for k in kwargs:
Expand Down
12 changes: 6 additions & 6 deletions arpes/analysis/pocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from _typeshed import Incomplete

from arpes._typing import DataType
from arpes._typing import DataType, XrTypes
__all__ = (
"curves_along_pocket",
"edcs_along_pocket",
Expand Down Expand Up @@ -84,7 +84,7 @@ def pocket_parameters(

@update_provenance("Collect EDCs projected at an angle from pocket")
def radial_edcs_along_pocket(
data: xr.DataArray | xr.Dataset,
data: XrTypes,
angle: float,
radii: tuple[float, float] = (0.0, 5.0),
n_points: int = 0,
Expand All @@ -101,7 +101,7 @@ def radial_edcs_along_pocket(
>>> radial_edcs_along_pocket(spectrum, np.pi / 4, (1, 4), phi=0.1, beta=0)
Args:
data (xr.DataArray | xr.Dataset): ARPES Spectrum.
data (XrTypes): ARPES Spectrum.
angle (float): Angle along the FS to cut against.
radii (tuple[float, float]): The min and max for the angle/momentum equivalent radial
coordinate.
Expand Down Expand Up @@ -158,7 +158,7 @@ def radial_edcs_along_pocket(


def curves_along_pocket(
data: xr.DataArray | xr.Dataset,
data: XrTypes,
n_points: int = 0,
inner_radius: float = 0.0,
outer_radius: float = 5.0,
Expand Down Expand Up @@ -237,7 +237,7 @@ def slice_at_angle(theta: float) -> xr.DataArray:


def find_kf_by_mdc(
slice_data: xr.Dataset | xr.DataArray,
slice_data: XrTypes,
offset: float = 0,
**kwargs: Incomplete,
) -> float:
Expand Down Expand Up @@ -271,7 +271,7 @@ def find_kf_by_mdc(

@update_provenance("Collect EDCs around pocket edge")
def edcs_along_pocket(
data: xr.DataArray | xr.Dataset,
data: XrTypes,
kf_method: Callable[..., float] | None = None,
select_radius: dict[str, float] | None = None,
sel: dict[str, slice] | None = None,
Expand Down
1 change: 0 additions & 1 deletion arpes/analysis/self_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


BareBandType: TypeAlias = xr.DataArray | str | lf.model.ModelResult
DispersionType: TypeAlias = xr.DataArray | xr.Dataset


def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray:
Expand Down
4 changes: 2 additions & 2 deletions arpes/analysis/tarpes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
from arpes._typing import DataType
from arpes._typing import DataType, XrTypes

__all__ = ("find_t0", "relative_change", "normalized_relative_change")


@update_provenance("Normalized subtraction map")
def normalized_relative_change(
data: xr.DataArray | xr.Dataset,
data: XrTypes,
t0: float | None = None,
buffer: float = 0.3,
*,
Expand Down
7 changes: 4 additions & 3 deletions arpes/corrections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
earlier in a dataset which can be used to furnish equivalent references.
"""

from __future__ import annotations

from collections import OrderedDict
Expand All @@ -28,7 +29,7 @@
)

if TYPE_CHECKING:
from arpes._typing import DataType
from arpes._typing import XrTypes

__all__ = (
"reference_key",
Expand All @@ -46,14 +47,14 @@ def __hash__(self):
return hash(frozenset(self.items()))


def reference_key(data: DataType) -> HashableDict:
def reference_key(data: XrTypes) -> HashableDict:
"""Calculates a key/hash for data determining reference/correction equality."""
data_array = normalize_to_dataset(data)
assert isinstance(data_array, xr.DataArray)
return HashableDict(data_array.S.reference_settings)


def correction_from_reference_set(data: DataType, reference_set):
def correction_from_reference_set(data: XrTypes, reference_set):
"""Determines which correction to use from a set of references."""
data_array = normalize_to_dataset(data)
correction = None
Expand Down
7 changes: 6 additions & 1 deletion arpes/corrections/background.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""For estimating the above Fermi level incoherent background."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand All @@ -16,7 +17,11 @@


@update_provenance("Remove incoherent background from above Fermi level")
def remove_incoherent_background(data: DataType, *, set_zero: bool = True) -> xr.DataArray:
def remove_incoherent_background(
data: DataType,
*,
set_zero: bool = True,
) -> xr.DataArray:
"""Removes counts above the Fermi level.
Sometimes spectra are contaminated by data above the Fermi level for
Expand Down
Loading

0 comments on commit 8ebd89d

Please sign in to comment.