diff --git a/README.rst b/README.rst index b75f5ff6..c900d5bc 100644 --- a/README.rst +++ b/README.rst @@ -20,7 +20,9 @@ :target: https://github.com/psf/black .. |code fromat| image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json :target: https://github.com/astral-sh/ruff - +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/rye/main/artwork/badge.json + :target: https://rye-up.com + :alt: Rye PyARPES corrected (V4) ======================= diff --git a/src/arpes/analysis/decomposition.py b/src/arpes/analysis/decomposition.py index 05277752..d02ee693 100644 --- a/src/arpes/analysis/decomposition.py +++ b/src/arpes/analysis/decomposition.py @@ -3,17 +3,19 @@ from __future__ import annotations from functools import wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import xarray as xr +from sklearn.decomposition import FactorAnalysis, FastICA from arpes.constants import TWO_DIMENSION from arpes.provenance import PROVENANCE, provenance from arpes.utilities import normalize_to_spectrum if TYPE_CHECKING: + import numpy as np import sklearn - from _typeshed import Incomplete + from numpy.typing import NDArray __all__ = ( "nmf_along", @@ -23,13 +25,69 @@ ) +class PCAParam(TypedDict, total=False): + n_composition: float | Literal["mle", "auto"] | None + copy: bool + whiten: str | bool + svd_solver: Literal["auto", "full", "arpack", "randomiozed"] + tol: float + iterated_power: int | Literal["auto"] + n_oversamples: int + power_interation_normalizer: Literal["auto", "QR", "LU", "none"] + random_state: int | None + + +class FastICAParam(TypedDict, total=False): + n_composition: float | None + algorithm: Literal["Parallel", "deflation"] + whiten: bool | Literal["unit-variance", "arbitrary-variance"] + fun: Literal["logosh", "exp", "cube"] + fun_args: dict[str, float] | None + max_iter: int + tol: float + w_int: NDArray[np.float_] + whiten_solver: Literal["eigh", "svd"] + random_state: int | None + + +class NMFParam(TypedDict, total=False): + n_composition: int | Literal["auto"] | None + init: Literal["random", "nndsvd", "nndsvda", "nndsvdar", "custom", None] + solver: Literal["cd", "mu"] + beta_loss: float | Literal["frobenius", "kullback-leibler", "itakura-saito"] + tol: float + max_iter: int + random_state: int | None + alpha_W: float + alpha_H: float + l1_ratio: float + verbose: int + shuffle: bool + + +class FactorAnalysisParam(TypedDict, total=False): + n_composition: int | None + tol: float + copy: bool + max_iter: int + noise_variance_init: NDArray[np.float_] | None + svd_method: Literal["lapack", "randomized"] + iterated_power: int + rotation: Literal["varimax", "quartimax", None] + random_state: int | None + + +class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam): + pass + + def decomposition_along( data: xr.DataArray, axes: list[str], *, decomposition_cls: type[sklearn.decomposition], correlation: bool = False, - **kwargs: Incomplete, + **kwargs: Unpack[DecompositionParam], ) -> tuple[xr.DataArray, sklearn.base.BaseEstimator]: """Change the basis of multidimensional data according to `sklearn` decomposition classes. @@ -119,8 +177,8 @@ def decomposition_along( @wraps(decomposition_along) def pca_along( - *args: Incomplete, - **kwargs: Incomplete, + *args: xr.DataArray | list[str], + **kwargs: Unpack[PCAParam], ) -> tuple[xr.DataArray, sklearn.decomposition.PCA]: """Specializes `decomposition_along` with `sklearn.decomposition.PCA`.""" from sklearn.decomposition import PCA @@ -130,30 +188,26 @@ def pca_along( @wraps(decomposition_along) def factor_analysis_along( - *args: Incomplete, - **kwargs: Incomplete, + *args: xr.DataArray | list[str], + **kwargs: Unpack[FactorAnalysisParam], ) -> tuple[xr.DataArray, sklearn.decomposition.FactorAnalysis]: """Specializes `decomposition_along` with `sklearn.decomposition.FactorAnalysis`.""" - from sklearn.decomposition import FactorAnalysis - return decomposition_along(*args, **kwargs, decomposition_cls=FactorAnalysis) @wraps(decomposition_along) def ica_along( - *args: Incomplete, - **kwargs: Incomplete, + *args: xr.DataArray | list[str], + **kwargs: Unpack[FastICAParam], ) -> tuple[xr.DataArray, sklearn.decomposition.FastICA]: """Specializes `decomposition_along` with `sklearn.decomposition.FastICA`.""" - from sklearn.decomposition import FastICA - return decomposition_along(*args, **kwargs, decomposition_cls=FastICA) @wraps(decomposition_along) def nmf_along( - *args: Incomplete, - **kwargs: Incomplete, + *args: xr.DataArray | list[str], + **kwargs: Unpack[NMFParam], ) -> tuple[xr.DataArray, sklearn.decomposition.NMF]: """Specializes `decomposition_along` with `sklearn.decomposition.NMF`.""" from sklearn.decomposition import NMF diff --git a/src/arpes/bootstrap.py b/src/arpes/bootstrap.py index ffa00be9..f697af7a 100644 --- a/src/arpes/bootstrap.py +++ b/src/arpes/bootstrap.py @@ -19,7 +19,7 @@ import random from dataclasses import dataclass from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar import numpy as np import scipy.stats @@ -253,7 +253,11 @@ def from_param(cls: type, model_param: lf.Model.Parameter): return cls(center=model_param.value, stderr=model_param.stderr) -def propagate_errors(f: Callable) -> Callable: +P = ParamSpec("P") +R = TypeVar("R") + + +def propagate_errors(f: Callable[P, R]) -> Callable[P, R]: """A decorator which provides transparent propagation of statistical errors. The way that this is accommodated is that the inner function is turned into one which @@ -270,7 +274,7 @@ def propagate_errors(f: Callable) -> Callable: """ @functools.wraps(f) - def operates_on_distributions(*args: Incomplete, **kwargs: Incomplete): + def operates_on_distributions(*args: P.args, **kwargs: P.kwargs) -> R: exclude = set( [i for i, arg in enumerate(args) if not isinstance(arg, Distribution)] + [k for k, arg in kwargs.items() if not isinstance(arg, Distribution)], diff --git a/src/arpes/plotting/bands.py b/src/arpes/plotting/bands.py index 9ce04046..ffec9068 100644 --- a/src/arpes/plotting/bands.py +++ b/src/arpes/plotting/bands.py @@ -17,7 +17,7 @@ from matplotlib.image import AxesImage - from arpes._typing import DataType, PColorMeshKwargs + from arpes._typing import PColorMeshKwargs, XrTypes from arpes.models.band import Band __all__ = ("plot_with_bands",) @@ -25,7 +25,7 @@ @save_plot_provenance def plot_with_bands( - data: DataType, + data: XrTypes, bands: Sequence[Band], title: str = "", ax: Axes | None = None, diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index 2ca316a8..be76abde 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -1261,7 +1261,7 @@ def load_data_for_figure(p: str | Path) -> None: def savefig( desired_path: str | Path, dpi: int = 400, - data: list[DataType] | tuple[DataType, ...] | set[DataType] | None = None, + data: list[XrTypes] | tuple[XrTypes, ...] | set[XrTypes] | None = None, save_data=None, *, paper: bool = False, @@ -1493,7 +1493,7 @@ def unit_for_dim(dim_name: str, *, escaped: bool = True) -> str: return unit -def label_for_colorbar(data: DataType) -> str: +def label_for_colorbar(data: XrTypes) -> str: """Returns an appropriate label for an ARPES intensity colorbar.""" if not data.S.is_differentiated: return r"Spectrum Intensity (arb.)" diff --git a/src/arpes/utilities/ui.py b/src/arpes/utilities/ui.py index ce7a8139..95e53d92 100644 --- a/src/arpes/utilities/ui.py +++ b/src/arpes/utilities/ui.py @@ -434,7 +434,6 @@ def numeric_input( input_type: type = float, *args: Incomplete, validator_settings: dict[str, float] | None = None, - **kwargs: Incomplete, ) -> QWidget: """A numeric input with input validation.""" validators = { diff --git a/src/arpes/workflow.py b/src/arpes/workflow.py index 7ba597a7..0911c88b 100644 --- a/src/arpes/workflow.py +++ b/src/arpes/workflow.py @@ -80,7 +80,6 @@ def with_workspace(f: Callable[P, R]) -> Callable[P, R]: @wraps(f) def wrapped_with_workspace( *args: P.args, - workspace_name: str = "", **kwargs: P.kwargs, ) -> R: """[TODO:summary]. @@ -90,6 +89,7 @@ def wrapped_with_workspace( workspace (str | None): [TODO:description] kwargs: [TODO:description] """ + workspace_name: str = kwargs.pop("workspace_name", "") with WorkspaceManager(workspace_name=workspace_name): import arpes.config diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index 131041b2..7ffb44d0 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -566,7 +566,7 @@ def select_around( radius: dict[Hashable, float] | float, *, mode: Literal["sum", "mean"] = "sum", - **kwargs: Incomplete, + **kwargs: float, ) -> xr.DataArray: """Selects and integrates a region around a one dimensional point. @@ -1731,8 +1731,7 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]: min_hv = float(v.min()) max_hv = float(v.max()) transformed_dict[k] = ( - f" from {min_hv} " - f" to {max_hv} eV" + f" from {min_hv} to {max_hv} eV" ) elif isinstance(v, float) and not np.isnan(v): transformed_dict[k] = f"{v} eV"