From c5d820e17bfed7d60af18aef8ec796ddb4d306f4 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sat, 23 Mar 2024 09:45:47 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/bootstrap.py | 17 +++++---- src/arpes/fits/__init__.py | 39 +++++++++++++++++++-- src/arpes/plotting/fit_tool/__init__.py | 5 +-- src/arpes/plotting/qt_tool/__init__.py | 3 +- src/arpes/plotting/utils.py | 2 +- src/arpes/utilities/bz.py | 4 +-- src/arpes/utilities/conversion/core.py | 10 +++--- src/arpes/utilities/qt/app.py | 3 +- src/arpes/widgets.py | 46 +++++++++++-------------- 9 files changed, 84 insertions(+), 45 deletions(-) diff --git a/src/arpes/bootstrap.py b/src/arpes/bootstrap.py index f231e0e8..21200591 100644 --- a/src/arpes/bootstrap.py +++ b/src/arpes/bootstrap.py @@ -227,10 +227,6 @@ def bootstrap_counts( class Distribution: DEFAULT_N_SAMPLES = 1000 - def draw_samples(self, n_samples: int = DEFAULT_N_SAMPLES) -> None: - """Draws samples from this distribution.""" - raise NotImplementedError - @dataclass class Normal(Distribution): @@ -244,9 +240,16 @@ class Normal(Distribution): center: float stderr: float - def draw_samples(self, n_samples: int = Distribution.DEFAULT_N_SAMPLES) -> NDArray[np.int_]: + def draw_samples( + self, + n_samples: int = Distribution.DEFAULT_N_SAMPLES, + ) -> NDArray[np.int_]: """Draws samples from this distribution.""" - return scipy.stats.norm.rvs(self.center, scale=self.stderr, size=n_samples) + return scipy.stats.norm.rvs( + self.center, + scale=self.stderr, + size=n_samples, + ) @classmethod def from_param(cls: type, model_param: lf.Model.Parameter) -> Incomplete: @@ -360,7 +363,7 @@ def bootstrapped( for i, arg in enumerate(args) if isinstance(arg, xr.DataArray | xr.Dataset) and i not in skip ] - data_is_arraylike = False + data_is_arraylike: bool = False runs = [] diff --git a/src/arpes/fits/__init__.py b/src/arpes/fits/__init__.py index bc31257c..332e893f 100644 --- a/src/arpes/fits/__init__.py +++ b/src/arpes/fits/__init__.py @@ -2,9 +2,44 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import Literal, TypedDict -from .fit_models import * +from .fit_models.backgrounds import AffineBackgroundModel +from .fit_models.decay import ExponentialDecayCModel, TwoExponentialDecayCModel +from .fit_models.dirac import DiracDispersionModel +from .fit_models.fermi_edge import ( + AffineBroadenedFD, + BandEdgeBGModel, + BandEdgeBModel, + FermiDiracAffGaussModel, + FermiDiracModel, + FermiLorentzianModel, + GStepBModel, + GStepBStandardModel, + GStepBStdevModel, + TwoBandEdgeBModel, + TwoLorEdgeModel, +) +from .fit_models.misc import ( + FermiVelocityRenormalizationModel, + LogRenormalizationModel, + QuadraticModel, +) +from .fit_models.peaks import TwoGaussianModel, TwoLorModel +from .fit_models.two_dimensional import EffectiveMassModel, Gaussian2DModel +from .fit_models.wrapped import ( + ConstantModel, + GaussianModel, + LinearModel, + LogisticModel, + LorentzianModel, + SineModel, + SkewedVoigtModel, + SplitLorentzianModel, + StepModel, + VoigtModel, +) +from .fit_models.x_model_mixin import XModelMixin, gaussian_convolve from .utilities import broadcast_model, result_to_hints NAN_POLICY = Literal["raise", "propagate", "omit"] diff --git a/src/arpes/plotting/fit_tool/__init__.py b/src/arpes/plotting/fit_tool/__init__.py index 11ad2839..9154511a 100644 --- a/src/arpes/plotting/fit_tool/__init__.py +++ b/src/arpes/plotting/fit_tool/__init__.py @@ -17,6 +17,7 @@ from PySide6 import QtCore, QtGui, QtWidgets from PySide6.QtWidgets import QLayout, QWidget +from arpes.constants import TWO_DIMENSION from arpes.fits.utilities import result_to_hints from arpes.plotting.qt_tool.BinningInfoWidget import BinningInfoWidget from arpes.utilities.qt import ( @@ -231,7 +232,7 @@ def configure_image_widgets(self) -> None: The 1D marginal will have a cursor and binning controls on that cursor. """ - if len(self.data.dims) == 2: # noqa: PLR2004 + if len(self.data.dims) == TWO_DIMENSION: self.generate_marginal_for((), 0, 0, "xy", cursors=True, layout=self.content_layout) self.generate_fit_marginal_for( (0, 1), @@ -243,7 +244,7 @@ def configure_image_widgets(self) -> None: ) self.views["xy"].view.setYLink(self.views["fit"].inner_plot) - if len(self.data.dims) == 3: # noqa: PLR2004 + if len(self.data.dims) == TWO_DIMENSION + 1: self.generate_marginal_for((2,), 1, 0, "xy", cursors=True, layout=self.content_layout) self.generate_fit_marginal_for( (0, 1, 2), diff --git a/src/arpes/plotting/qt_tool/__init__.py b/src/arpes/plotting/qt_tool/__init__.py index 8f210789..4da0edab 100644 --- a/src/arpes/plotting/qt_tool/__init__.py +++ b/src/arpes/plotting/qt_tool/__init__.py @@ -17,6 +17,7 @@ from PySide6 import QtCore, QtWidgets from PySide6.QtWidgets import QGridLayout +from arpes.constants import TWO_DIMENSION from arpes.utilities import normalize_to_spectrum from arpes.utilities.qt import ( BasicHelpDialog, @@ -295,7 +296,7 @@ def configure_image_widgets(self) -> None: An additional complexity is that we also handle the cursor registration here. """ - if len(self.data.dims) == 2: # noqa: PLR2004 + if len(self.data.dims) == TWO_DIMENSION: self.generate_marginal_for((), 1, 0, "xy", cursors=True, layout=self.content_layout) self.generate_marginal_for( (1,), diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index af92a676..9924ff9d 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -1362,7 +1362,7 @@ def path_for_plot(desired_path: str | Path) -> Path: filename = ( Path(figure_path) / workspace["name"] - / datetime.datetime.now(tz=datetime.timezone.utc).date().isoformat() + / datetime.datetime.now(tz=datetime.UTC).date().isoformat() / desired_path ) filename = Path(filename).absolute() diff --git a/src/arpes/utilities/bz.py b/src/arpes/utilities/bz.py index 97139109..2c0ba8f8 100644 --- a/src/arpes/utilities/bz.py +++ b/src/arpes/utilities/bz.py @@ -104,7 +104,7 @@ def parse_single_path(path: str) -> list[SpecialPoint]: ToDo: Shold be removed. Use ase. """ # first tokenize - tokens = [name for name in re.split(r"([A-Z][a-z0-9]*(?:\([0-9,\s]+\))?)", path) if name] + tokens:list[str] = [name for name in re.split(r"([A-Z][a-z0-9]*(?:\([0-9,\s]+\))?)", path) if name] # normalize Gamma to G tokens = [token.replace("Gamma", "G") for token in tokens] @@ -147,7 +147,7 @@ def _parse_path(paths: str | list[str]) -> list[list[SpecialPoint]]: """ if isinstance(paths, str): # some manual string work in order to make sure we do not split on commas inside BZ indices - idxs = [] + idxs:list[int] = [] for i, p in enumerate(paths): if p == ",": c = Counter(paths[:i]) diff --git a/src/arpes/utilities/conversion/core.py b/src/arpes/utilities/conversion/core.py index bf89ba89..272d4df8 100644 --- a/src/arpes/utilities/conversion/core.py +++ b/src/arpes/utilities/conversion/core.py @@ -34,6 +34,7 @@ import xarray as xr from scipy.interpolate import RegularGridInterpolator +from arpes.constants import TWO_DIMENSION from arpes.provenance import Provenance, provenance, update_provenance from arpes.utilities import normalize_to_spectrum @@ -166,9 +167,10 @@ def slice_along_path( # noqa: PLR0913 Returns: xr.DataArray containing the interpolated data. """ - if interpolation_points is None: - msg = "You must provide points specifying an interpolation path" - raise ValueError(msg) + assert isinstance( + interpolation_points, + np.ndarray, + ), "You must provide points specifying an interpolation path" parsed_interpolation_points = [ ( @@ -302,7 +304,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ assert isinstance(converted_ds, xr.Dataset) if ( - axis_name in arr.dims and len(parsed_interpolation_points) == 2 # noqa: PLR2004 + axis_name in arr.dims and len(parsed_interpolation_points) == TWO_DIMENSION ) and parsed_interpolation_points[1][axis_name] < parsed_interpolation_points[0][axis_name]: # swap the sign on this axis as a convenience to the caller converted_ds.coords[axis_name].data = -converted_ds.coords[axis_name].data diff --git a/src/arpes/utilities/qt/app.py b/src/arpes/utilities/qt/app.py index 1cfffb36..d49f5fa1 100644 --- a/src/arpes/utilities/qt/app.py +++ b/src/arpes/utilities/qt/app.py @@ -15,6 +15,7 @@ from PySide6 import QtWidgets import arpes.config +from arpes.constants import TWO_DIMENSION from arpes.utilities.ui import CursorRegion from .data_array_image_view import DataArrayImageView, DataArrayPlot @@ -184,7 +185,7 @@ def generate_marginal_for( # noqa: PLR0913 widget.addItem(cursor, ignoreBounds=False) self.connect_cursor(remaining_dims[0], cursor) else: - assert len(remaining_dims) == 2 # noqa: PLR2004 + assert len(remaining_dims) == TWO_DIMENSION widget = DataArrayImageView(name=name) widget.view.setAspectLocked(lock=False) self.views[name] = widget diff --git a/src/arpes/widgets.py b/src/arpes/widgets.py index 116fcd0b..b82e3a2c 100644 --- a/src/arpes/widgets.py +++ b/src/arpes/widgets.py @@ -35,7 +35,7 @@ from collections.abc import Sequence from functools import wraps from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeAlias, TypeVar import matplotlib as mpl import matplotlib.pyplot as plt @@ -506,7 +506,7 @@ def compute_parameters() -> dict: """ renamed = [ {f"{prefix}_{k}": v for k, v in m_setting.items()} - for m_setting, prefix in zip(model_settings, prefixes, strict=True) + for m_setting, prefix in zip(model_settings, prefixes, strict=False) ] return dict(itertools.chain(*[list(d.items()) for d in renamed])) @@ -757,7 +757,7 @@ def kspace_tool( overplot_bz: Callable[[Axes], None] | list[Callable[[Axes], None]] | None = None, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict | None = None, - coords: dict[str, NDArray[np.float_] | xr.DataArray] | None = None, + coords: dict[Literal["kp", "kx", "ky", "kz"], NDArray[np.float_]] | None = None, **kwargs: Incomplete, ) -> CurrentContext: """A utility for assigning coordinate offsets using a live momentum conversion. @@ -777,22 +777,21 @@ def kspace_tool( ValueError: [TODO:description] """ original_data = data - data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) - assert isinstance(data_array, xr.DataArray) - if len(data_array.dims) > TWO_DIMENSION: - data_array = data_array.sel(eV=slice(-0.05, 0.05)).sum("eV", keep_attrs=True) - data_array.coords["eV"] = 0 + assert isinstance(data, xr.DataArray) + if len(data.dims) > TWO_DIMENSION: + data = data.sel(eV=slice(-0.05, 0.05)).sum("eV", keep_attrs=True) + data.coords["eV"] = 0 - if "eV" in data_array.dims: - data_array.S.transpose_to_front("eV") - data_array = data_array.copy(deep=True) + if "eV" in data.dims: + data.S.transpose_to_front("eV") + data = data.copy(deep=True) - ctx: CurrentContext = {"original_data": original_data, "data": data_array, "widgets": []} + ctx: CurrentContext = {"original_data": original_data, "data": data, "widgets": []} arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx gs = gridspec.GridSpec(4, 3) - ax_initial = plt.subplot(gs[0:2, 0:2]) - ax_converted = plt.subplot(gs[2:, 0:2]) + ax_initial, ax_converted = plt.subplot(gs[0:2, 0:2]), plt.subplot(gs[2:, 0:2]) if overplot_bz is not None: if not isinstance(overplot_bz, Sequence): @@ -813,17 +812,14 @@ def kspace_tool( skip_dims = {"x", "X", "y", "Y", "z", "Z", "T"} for dim in skip_dims: - if dim in data_array.dims: - msg = f"Please provide data without the {dim} dimension" - raise ValueError(msg) + assert dim not in data.dims, f"Please provide data without the {dim} dimension" convert_dims = ["theta", "beta", "phi", "psi"] - if "eV" not in data_array.dims: + if "eV" not in data.dims: convert_dims += ["chi"] - if "hv" in data_array.dims: + if "hv" in data.dims: convert_dims += ["hv"] - ang_range = (np.deg2rad(-45), np.deg2rad(45), 0.01) default_ranges = { "eV": [-0.05, 0.05, 0.001], "hv": [-20, 20, 0.5], @@ -833,12 +829,12 @@ def kspace_tool( def update_kspace_plot() -> None: for name, slider in sliders.items(): - data_array.attrs[f"{name}_offset"] = slider.val + data.attrs[f"{name}_offset"] = slider.val with warnings.catch_warnings(): warnings.simplefilter("ignore") converted_view.data = convert_to_kspace( - data_array, + data, bounds=bounds, resolution=resolution, coords=coords, @@ -848,8 +844,8 @@ def update_kspace_plot() -> None: axes = iter(widget_axes) for convert_dim in convert_dims: widget_ax = next(axes) - low, high, delta = default_ranges.get(convert_dim, ang_range) - init = data_array.S.lookup_offset(convert_dim) + low, high, delta = default_ranges.get(convert_dim, (np.deg2rad(-45), np.deg2rad(45), 0.01)) + init = data.S.lookup_offset(convert_dim) sliders[convert_dim] = Slider( widget_ax, convert_dim, @@ -907,7 +903,7 @@ def apply_offsets(event: Event) -> None: data_view = DataArrayView(ax_initial) converted_view = DataArrayView(ax_converted) - data_view.data = data_array + data_view.data = data update_kspace_plot() plt.tight_layout()