Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 27, 2024
1 parent b6ff770 commit e0d6485
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 28 deletions.
4 changes: 3 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
:target: https://github.com/psf/black
.. |code fromat| image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
:target: https://github.com/astral-sh/ruff

.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/rye/main/artwork/badge.json
:target: https://rye-up.com
:alt: Rye
PyARPES corrected (V4)
=======================

Expand Down
84 changes: 69 additions & 15 deletions src/arpes/analysis/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack

import xarray as xr
from sklearn.decomposition import FactorAnalysis, FastICA

from arpes.constants import TWO_DIMENSION
from arpes.provenance import PROVENANCE, provenance
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
import numpy as np
import sklearn
from _typeshed import Incomplete
from numpy.typing import NDArray

__all__ = (
"nmf_along",
Expand All @@ -23,13 +25,69 @@
)


class PCAParam(TypedDict, total=False):
    """Keyword arguments accepted by :class:`sklearn.decomposition.PCA`.

    Field names must match the PCA constructor parameters exactly, because
    values typed with this TypedDict are forwarded as ``**kwargs`` to the
    estimator; a misspelled key raises ``TypeError`` at call time.
    """

    # Fixed typo: sklearn's parameter is ``n_components`` (was ``n_composition``).
    # PCA accepts an int, a float in (0, 1], or "mle"; "auto" is not valid here.
    n_components: int | float | Literal["mle"] | None
    copy: bool
    # PCA's ``whiten`` is a plain bool (no string variants).
    whiten: bool
    # Fixed typo: "randomized" (was "randomiozed").
    svd_solver: Literal["auto", "full", "arpack", "randomized"]
    tol: float
    iterated_power: int | Literal["auto"]
    n_oversamples: int
    # Fixed typo: sklearn spells this ``power_iteration_normalizer``.
    power_iteration_normalizer: Literal["auto", "QR", "LU", "none"]
    random_state: int | None


class FastICAParam(TypedDict, total=False):
    """Keyword arguments accepted by :class:`sklearn.decomposition.FastICA`.

    Field names must match the FastICA constructor parameters exactly, because
    values typed with this TypedDict are forwarded as ``**kwargs`` to the
    estimator; a misspelled key raises ``TypeError`` at call time.
    """

    # Fixed typo: sklearn's parameter is ``n_components`` and it is an int.
    n_components: int | None
    # sklearn expects lowercase "parallel" (was "Parallel").
    algorithm: Literal["parallel", "deflation"]
    whiten: bool | Literal["unit-variance", "arbitrary-variance"]
    # Fixed typo: "logcosh" (was "logosh").
    fun: Literal["logcosh", "exp", "cube"]
    fun_args: dict[str, float] | None
    max_iter: int
    tol: float
    # Fixed typo: sklearn spells this ``w_init`` (was ``w_int``).
    # ``np.float64`` replaces ``np.float_``, which NumPy 2.0 removed.
    w_init: NDArray[np.float64]
    whiten_solver: Literal["eigh", "svd"]
    random_state: int | None


class NMFParam(TypedDict, total=False):
    """Keyword arguments accepted by :class:`sklearn.decomposition.NMF`.

    Field names must match the NMF constructor parameters exactly, because
    values typed with this TypedDict are forwarded as ``**kwargs`` to the
    estimator; a misspelled key raises ``TypeError`` at call time.
    """

    # Fixed typo: sklearn's parameter is ``n_components`` (was ``n_composition``).
    n_components: int | Literal["auto"] | None
    # ``X | None`` is the idiomatic spelling for an optional field
    # (was ``Literal[..., None]``).
    init: Literal["random", "nndsvd", "nndsvda", "nndsvdar", "custom"] | None
    solver: Literal["cd", "mu"]
    beta_loss: float | Literal["frobenius", "kullback-leibler", "itakura-saito"]
    tol: float
    max_iter: int
    random_state: int | None
    alpha_W: float
    alpha_H: float
    l1_ratio: float
    verbose: int
    shuffle: bool


class FactorAnalysisParam(TypedDict, total=False):
    """Keyword arguments accepted by :class:`sklearn.decomposition.FactorAnalysis`.

    Field names must match the FactorAnalysis constructor parameters exactly,
    because values typed with this TypedDict are forwarded as ``**kwargs`` to
    the estimator; a misspelled key raises ``TypeError`` at call time.
    """

    # Fixed typo: sklearn's parameter is ``n_components`` (was ``n_composition``).
    n_components: int | None
    tol: float
    copy: bool
    max_iter: int
    # ``np.float64`` replaces ``np.float_``, which NumPy 2.0 removed.
    noise_variance_init: NDArray[np.float64] | None
    svd_method: Literal["lapack", "randomized"]
    iterated_power: int
    # ``X | None`` is the idiomatic spelling for an optional field
    # (was ``Literal[..., None]``).
    rotation: Literal["varimax", "quartimax"] | None
    random_state: int | None


class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam):
    """Union of the kwargs of all supported sklearn decomposition estimators.

    Used as ``Unpack[DecompositionParam]`` on the generic ``decomposition_along``
    entry point so that any estimator-specific keyword passes the type check.

    NOTE(review): the four parent TypedDicts declare some of the same keys with
    different value types (e.g. their first field is a float-or-literal in one
    parent and an int in another); strict type checkers may reject this merge —
    confirm against mypy/pyright before relying on it.
    """

    pass


def decomposition_along(
data: xr.DataArray,
axes: list[str],
*,
decomposition_cls: type[sklearn.decomposition],
correlation: bool = False,
**kwargs: Incomplete,
**kwargs: Unpack[DecompositionParam],
) -> tuple[xr.DataArray, sklearn.base.BaseEstimator]:
"""Change the basis of multidimensional data according to `sklearn` decomposition classes.
Expand Down Expand Up @@ -119,8 +177,8 @@ def decomposition_along(

@wraps(decomposition_along)
def pca_along(
*args: Incomplete,
**kwargs: Incomplete,
*args: xr.DataArray | list[str],
**kwargs: Unpack[PCAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.PCA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.PCA`."""
from sklearn.decomposition import PCA
Expand All @@ -130,30 +188,26 @@ def pca_along(

@wraps(decomposition_along)
def factor_analysis_along(
*args: Incomplete,
**kwargs: Incomplete,
*args: xr.DataArray | list[str],
**kwargs: Unpack[FactorAnalysisParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FactorAnalysis]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FactorAnalysis`."""
from sklearn.decomposition import FactorAnalysis

return decomposition_along(*args, **kwargs, decomposition_cls=FactorAnalysis)


@wraps(decomposition_along)
def ica_along(
*args: Incomplete,
**kwargs: Incomplete,
*args: xr.DataArray | list[str],
**kwargs: Unpack[FastICAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FastICA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FastICA`."""
from sklearn.decomposition import FastICA

return decomposition_along(*args, **kwargs, decomposition_cls=FastICA)


@wraps(decomposition_along)
def nmf_along(
*args: Incomplete,
**kwargs: Incomplete,
*args: xr.DataArray | list[str],
**kwargs: Unpack[NMFParam],
) -> tuple[xr.DataArray, sklearn.decomposition.NMF]:
"""Specializes `decomposition_along` with `sklearn.decomposition.NMF`."""
from sklearn.decomposition import NMF
Expand Down
10 changes: 7 additions & 3 deletions src/arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
from dataclasses import dataclass
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

import numpy as np
import scipy.stats
Expand Down Expand Up @@ -253,7 +253,11 @@ def from_param(cls: type, model_param: lf.Model.Parameter):
return cls(center=model_param.value, stderr=model_param.stderr)


def propagate_errors(f: Callable) -> Callable:
P = ParamSpec("P")
R = TypeVar("R")


def propagate_errors(f: Callable[P, R]) -> Callable[P, R]:
"""A decorator which provides transparent propagation of statistical errors.
The way that this is accommodated is that the inner function is turned into one which
Expand All @@ -270,7 +274,7 @@ def propagate_errors(f: Callable) -> Callable:
"""

@functools.wraps(f)
def operates_on_distributions(*args: Incomplete, **kwargs: Incomplete):
def operates_on_distributions(*args: P.args, **kwargs: P.kwargs) -> R:
exclude = set(
[i for i, arg in enumerate(args) if not isinstance(arg, Distribution)]
+ [k for k, arg in kwargs.items() if not isinstance(arg, Distribution)],
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/plotting/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

from matplotlib.image import AxesImage

from arpes._typing import DataType, PColorMeshKwargs
from arpes._typing import PColorMeshKwargs, XrTypes
from arpes.models.band import Band

__all__ = ("plot_with_bands",)


@save_plot_provenance
def plot_with_bands(
data: DataType,
data: XrTypes,
bands: Sequence[Band],
title: str = "",
ax: Axes | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ def load_data_for_figure(p: str | Path) -> None:
def savefig(
desired_path: str | Path,
dpi: int = 400,
data: list[DataType] | tuple[DataType, ...] | set[DataType] | None = None,
data: list[XrTypes] | tuple[XrTypes, ...] | set[XrTypes] | None = None,
save_data=None,
*,
paper: bool = False,
Expand Down Expand Up @@ -1493,7 +1493,7 @@ def unit_for_dim(dim_name: str, *, escaped: bool = True) -> str:
return unit


def label_for_colorbar(data: DataType) -> str:
def label_for_colorbar(data: XrTypes) -> str:
"""Returns an appropriate label for an ARPES intensity colorbar."""
if not data.S.is_differentiated:
return r"Spectrum Intensity (arb.)"
Expand Down
1 change: 0 additions & 1 deletion src/arpes/utilities/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,6 @@ def numeric_input(
input_type: type = float,
*args: Incomplete,
validator_settings: dict[str, float] | None = None,
**kwargs: Incomplete,
) -> QWidget:
"""A numeric input with input validation."""
validators = {
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def with_workspace(f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def wrapped_with_workspace(
*args: P.args,
workspace_name: str = "",
**kwargs: P.kwargs,
) -> R:
"""[TODO:summary].
Expand All @@ -90,6 +89,7 @@ def wrapped_with_workspace(
workspace (str | None): [TODO:description]
kwargs: [TODO:description]
"""
workspace_name: str = kwargs.pop("workspace_name", "")
with WorkspaceManager(workspace_name=workspace_name):
import arpes.config

Expand Down
5 changes: 2 additions & 3 deletions src/arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def select_around(
radius: dict[Hashable, float] | float,
*,
mode: Literal["sum", "mean"] = "sum",
**kwargs: Incomplete,
**kwargs: float,
) -> xr.DataArray:
"""Selects and integrates a region around a one dimensional point.
Expand Down Expand Up @@ -1731,8 +1731,7 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]:
min_hv = float(v.min())
max_hv = float(v.max())
transformed_dict[k] = (
f"<strong> from </strong> {min_hv} "
f"<strong> to </strong> {max_hv} eV"
f"<strong> from </strong> {min_hv} <strong> to </strong> {max_hv} eV"
)
elif isinstance(v, float) and not np.isnan(v):
transformed_dict[k] = f"{v} eV"
Expand Down

0 comments on commit e0d6485

Please sign in to comment.