diff --git a/README.rst b/README.rst
index b75f5ff6..c900d5bc 100644
--- a/README.rst
+++ b/README.rst
@@ -20,7 +20,9 @@
:target: https://github.com/psf/black
.. |code fromat| image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
:target: https://github.com/astral-sh/ruff
-
+.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/rye/main/artwork/badge.json
+ :target: https://rye-up.com
+ :alt: Rye
PyARPES corrected (V4)
=======================
diff --git a/src/arpes/analysis/decomposition.py b/src/arpes/analysis/decomposition.py
index 05277752..d02ee693 100644
--- a/src/arpes/analysis/decomposition.py
+++ b/src/arpes/analysis/decomposition.py
@@ -3,17 +3,19 @@
from __future__ import annotations
from functools import wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import xarray as xr
+from sklearn.decomposition import FactorAnalysis, FastICA
from arpes.constants import TWO_DIMENSION
from arpes.provenance import PROVENANCE, provenance
from arpes.utilities import normalize_to_spectrum
if TYPE_CHECKING:
+ import numpy as np
import sklearn
- from _typeshed import Incomplete
+ from numpy.typing import NDArray
__all__ = (
"nmf_along",
@@ -23,13 +25,69 @@
)
+class PCAParam(TypedDict, total=False):
+    n_components: float | Literal["mle", "auto"] | None
+    copy: bool
+    whiten: str | bool
+    svd_solver: Literal["auto", "full", "arpack", "randomized"]
+    tol: float
+    iterated_power: int | Literal["auto"]
+    n_oversamples: int
+    power_iteration_normalizer: Literal["auto", "QR", "LU", "none"]
+    random_state: int | None
+
+
+class FastICAParam(TypedDict, total=False):
+    n_components: int | None
+    algorithm: Literal["parallel", "deflation"]
+    whiten: bool | Literal["unit-variance", "arbitrary-variance"]
+    fun: Literal["logcosh", "exp", "cube"]
+    fun_args: dict[str, float] | None
+    max_iter: int
+    tol: float
+    w_init: NDArray[np.float_]
+    whiten_solver: Literal["eigh", "svd"]
+    random_state: int | None
+
+
+class NMFParam(TypedDict, total=False):
+    n_components: int | Literal["auto"] | None
+ init: Literal["random", "nndsvd", "nndsvda", "nndsvdar", "custom", None]
+ solver: Literal["cd", "mu"]
+ beta_loss: float | Literal["frobenius", "kullback-leibler", "itakura-saito"]
+ tol: float
+ max_iter: int
+ random_state: int | None
+ alpha_W: float
+ alpha_H: float
+ l1_ratio: float
+ verbose: int
+ shuffle: bool
+
+
+class FactorAnalysisParam(TypedDict, total=False):
+    n_components: int | None
+ tol: float
+ copy: bool
+ max_iter: int
+ noise_variance_init: NDArray[np.float_] | None
+ svd_method: Literal["lapack", "randomized"]
+ iterated_power: int
+ rotation: Literal["varimax", "quartimax", None]
+ random_state: int | None
+
+
+class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam):
+ pass
+
+
def decomposition_along(
data: xr.DataArray,
axes: list[str],
*,
decomposition_cls: type[sklearn.decomposition],
correlation: bool = False,
- **kwargs: Incomplete,
+ **kwargs: Unpack[DecompositionParam],
) -> tuple[xr.DataArray, sklearn.base.BaseEstimator]:
"""Change the basis of multidimensional data according to `sklearn` decomposition classes.
@@ -119,8 +177,8 @@ def decomposition_along(
@wraps(decomposition_along)
def pca_along(
- *args: Incomplete,
- **kwargs: Incomplete,
+ *args: xr.DataArray | list[str],
+ **kwargs: Unpack[PCAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.PCA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.PCA`."""
from sklearn.decomposition import PCA
@@ -130,30 +188,26 @@ def pca_along(
@wraps(decomposition_along)
def factor_analysis_along(
- *args: Incomplete,
- **kwargs: Incomplete,
+ *args: xr.DataArray | list[str],
+ **kwargs: Unpack[FactorAnalysisParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FactorAnalysis]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FactorAnalysis`."""
- from sklearn.decomposition import FactorAnalysis
-
return decomposition_along(*args, **kwargs, decomposition_cls=FactorAnalysis)
@wraps(decomposition_along)
def ica_along(
- *args: Incomplete,
- **kwargs: Incomplete,
+ *args: xr.DataArray | list[str],
+ **kwargs: Unpack[FastICAParam],
) -> tuple[xr.DataArray, sklearn.decomposition.FastICA]:
"""Specializes `decomposition_along` with `sklearn.decomposition.FastICA`."""
- from sklearn.decomposition import FastICA
-
return decomposition_along(*args, **kwargs, decomposition_cls=FastICA)
@wraps(decomposition_along)
def nmf_along(
- *args: Incomplete,
- **kwargs: Incomplete,
+ *args: xr.DataArray | list[str],
+ **kwargs: Unpack[NMFParam],
) -> tuple[xr.DataArray, sklearn.decomposition.NMF]:
"""Specializes `decomposition_along` with `sklearn.decomposition.NMF`."""
from sklearn.decomposition import NMF
diff --git a/src/arpes/bootstrap.py b/src/arpes/bootstrap.py
index ffa00be9..f697af7a 100644
--- a/src/arpes/bootstrap.py
+++ b/src/arpes/bootstrap.py
@@ -19,7 +19,7 @@
import random
from dataclasses import dataclass
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import numpy as np
import scipy.stats
@@ -253,7 +253,11 @@ def from_param(cls: type, model_param: lf.Model.Parameter):
return cls(center=model_param.value, stderr=model_param.stderr)
-def propagate_errors(f: Callable) -> Callable:
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def propagate_errors(f: Callable[P, R]) -> Callable[P, R]:
"""A decorator which provides transparent propagation of statistical errors.
The way that this is accommodated is that the inner function is turned into one which
@@ -270,7 +274,7 @@ def propagate_errors(f: Callable) -> Callable:
"""
@functools.wraps(f)
- def operates_on_distributions(*args: Incomplete, **kwargs: Incomplete):
+ def operates_on_distributions(*args: P.args, **kwargs: P.kwargs) -> R:
exclude = set(
[i for i, arg in enumerate(args) if not isinstance(arg, Distribution)]
+ [k for k, arg in kwargs.items() if not isinstance(arg, Distribution)],
diff --git a/src/arpes/plotting/bands.py b/src/arpes/plotting/bands.py
index 9ce04046..ffec9068 100644
--- a/src/arpes/plotting/bands.py
+++ b/src/arpes/plotting/bands.py
@@ -17,7 +17,7 @@
from matplotlib.image import AxesImage
- from arpes._typing import DataType, PColorMeshKwargs
+ from arpes._typing import PColorMeshKwargs, XrTypes
from arpes.models.band import Band
__all__ = ("plot_with_bands",)
@@ -25,7 +25,7 @@
@save_plot_provenance
def plot_with_bands(
- data: DataType,
+ data: XrTypes,
bands: Sequence[Band],
title: str = "",
ax: Axes | None = None,
diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py
index 2ca316a8..be76abde 100644
--- a/src/arpes/plotting/utils.py
+++ b/src/arpes/plotting/utils.py
@@ -1261,7 +1261,7 @@ def load_data_for_figure(p: str | Path) -> None:
def savefig(
desired_path: str | Path,
dpi: int = 400,
- data: list[DataType] | tuple[DataType, ...] | set[DataType] | None = None,
+ data: list[XrTypes] | tuple[XrTypes, ...] | set[XrTypes] | None = None,
save_data=None,
*,
paper: bool = False,
@@ -1493,7 +1493,7 @@ def unit_for_dim(dim_name: str, *, escaped: bool = True) -> str:
return unit
-def label_for_colorbar(data: DataType) -> str:
+def label_for_colorbar(data: XrTypes) -> str:
"""Returns an appropriate label for an ARPES intensity colorbar."""
if not data.S.is_differentiated:
return r"Spectrum Intensity (arb.)"
diff --git a/src/arpes/utilities/ui.py b/src/arpes/utilities/ui.py
index ce7a8139..95e53d92 100644
--- a/src/arpes/utilities/ui.py
+++ b/src/arpes/utilities/ui.py
@@ -434,7 +434,6 @@ def numeric_input(
input_type: type = float,
*args: Incomplete,
validator_settings: dict[str, float] | None = None,
- **kwargs: Incomplete,
) -> QWidget:
"""A numeric input with input validation."""
validators = {
diff --git a/src/arpes/workflow.py b/src/arpes/workflow.py
index 7ba597a7..0911c88b 100644
--- a/src/arpes/workflow.py
+++ b/src/arpes/workflow.py
@@ -80,7 +80,6 @@ def with_workspace(f: Callable[P, R]) -> Callable[P, R]:
@wraps(f)
def wrapped_with_workspace(
*args: P.args,
- workspace_name: str = "",
**kwargs: P.kwargs,
) -> R:
"""[TODO:summary].
@@ -90,6 +89,7 @@ def wrapped_with_workspace(
workspace (str | None): [TODO:description]
kwargs: [TODO:description]
"""
+ workspace_name: str = kwargs.pop("workspace_name", "")
with WorkspaceManager(workspace_name=workspace_name):
import arpes.config
diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py
index 131041b2..7ffb44d0 100644
--- a/src/arpes/xarray_extensions.py
+++ b/src/arpes/xarray_extensions.py
@@ -566,7 +566,7 @@ def select_around(
radius: dict[Hashable, float] | float,
*,
mode: Literal["sum", "mean"] = "sum",
- **kwargs: Incomplete,
+ **kwargs: float,
) -> xr.DataArray:
"""Selects and integrates a region around a one dimensional point.
@@ -1731,8 +1731,7 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]:
min_hv = float(v.min())
max_hv = float(v.max())
transformed_dict[k] = (
- f" from {min_hv} "
- f" to {max_hv} eV"
+ f" from {min_hv} to {max_hv} eV"
)
elif isinstance(v, float) and not np.isnan(v):
transformed_dict[k] = f"{v} eV"