Skip to content

Commit

Permalink
💬 update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Mar 23, 2024
1 parent 2db5eef commit c5d820e
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 45 deletions.
17 changes: 10 additions & 7 deletions src/arpes/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,6 @@ def bootstrap_counts(
class Distribution:
DEFAULT_N_SAMPLES = 1000

def draw_samples(self, n_samples: int = DEFAULT_N_SAMPLES) -> None:
"""Draws samples from this distribution."""
raise NotImplementedError


@dataclass
class Normal(Distribution):
Expand All @@ -244,9 +240,16 @@ class Normal(Distribution):
center: float
stderr: float

def draw_samples(self, n_samples: int = Distribution.DEFAULT_N_SAMPLES) -> NDArray[np.int_]:
def draw_samples(
self,
n_samples: int = Distribution.DEFAULT_N_SAMPLES,
) -> NDArray[np.int_]:
"""Draws samples from this distribution."""
return scipy.stats.norm.rvs(self.center, scale=self.stderr, size=n_samples)
return scipy.stats.norm.rvs(
self.center,
scale=self.stderr,
size=n_samples,
)

@classmethod
def from_param(cls: type, model_param: lf.Model.Parameter) -> Incomplete:
Expand Down Expand Up @@ -360,7 +363,7 @@ def bootstrapped(
for i, arg in enumerate(args)
if isinstance(arg, xr.DataArray | xr.Dataset) and i not in skip
]
data_is_arraylike = False
data_is_arraylike: bool = False

runs = []

Expand Down
39 changes: 37 additions & 2 deletions src/arpes/fits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,44 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, TypedDict
from typing import Literal, TypedDict

from .fit_models import *
from .fit_models.backgrounds import AffineBackgroundModel
from .fit_models.decay import ExponentialDecayCModel, TwoExponentialDecayCModel
from .fit_models.dirac import DiracDispersionModel
from .fit_models.fermi_edge import (
AffineBroadenedFD,
BandEdgeBGModel,
BandEdgeBModel,
FermiDiracAffGaussModel,
FermiDiracModel,
FermiLorentzianModel,
GStepBModel,
GStepBStandardModel,
GStepBStdevModel,
TwoBandEdgeBModel,
TwoLorEdgeModel,
)
from .fit_models.misc import (
FermiVelocityRenormalizationModel,
LogRenormalizationModel,
QuadraticModel,
)
from .fit_models.peaks import TwoGaussianModel, TwoLorModel
from .fit_models.two_dimensional import EffectiveMassModel, Gaussian2DModel
from .fit_models.wrapped import (
ConstantModel,
GaussianModel,
LinearModel,
LogisticModel,
LorentzianModel,
SineModel,
SkewedVoigtModel,
SplitLorentzianModel,
StepModel,
VoigtModel,
)
from .fit_models.x_model_mixin import XModelMixin, gaussian_convolve
from .utilities import broadcast_model, result_to_hints

NAN_POLICY = Literal["raise", "propagate", "omit"]
Expand Down
5 changes: 3 additions & 2 deletions src/arpes/plotting/fit_tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from PySide6 import QtCore, QtGui, QtWidgets
from PySide6.QtWidgets import QLayout, QWidget

from arpes.constants import TWO_DIMENSION
from arpes.fits.utilities import result_to_hints
from arpes.plotting.qt_tool.BinningInfoWidget import BinningInfoWidget
from arpes.utilities.qt import (
Expand Down Expand Up @@ -231,7 +232,7 @@ def configure_image_widgets(self) -> None:
The 1D marginal will have a cursor and binning controls on that cursor.
"""
if len(self.data.dims) == 2: # noqa: PLR2004
if len(self.data.dims) == TWO_DIMENSION:
self.generate_marginal_for((), 0, 0, "xy", cursors=True, layout=self.content_layout)
self.generate_fit_marginal_for(
(0, 1),
Expand All @@ -243,7 +244,7 @@ def configure_image_widgets(self) -> None:
)
self.views["xy"].view.setYLink(self.views["fit"].inner_plot)

if len(self.data.dims) == 3: # noqa: PLR2004
if len(self.data.dims) == TWO_DIMENSION + 1:
self.generate_marginal_for((2,), 1, 0, "xy", cursors=True, layout=self.content_layout)
self.generate_fit_marginal_for(
(0, 1, 2),
Expand Down
3 changes: 2 additions & 1 deletion src/arpes/plotting/qt_tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from PySide6 import QtCore, QtWidgets
from PySide6.QtWidgets import QGridLayout

from arpes.constants import TWO_DIMENSION
from arpes.utilities import normalize_to_spectrum
from arpes.utilities.qt import (
BasicHelpDialog,
Expand Down Expand Up @@ -295,7 +296,7 @@ def configure_image_widgets(self) -> None:
An additional complexity is that we also handle the cursor registration here.
"""
if len(self.data.dims) == 2: # noqa: PLR2004
if len(self.data.dims) == TWO_DIMENSION:
self.generate_marginal_for((), 1, 0, "xy", cursors=True, layout=self.content_layout)
self.generate_marginal_for(
(1,),
Expand Down
2 changes: 1 addition & 1 deletion src/arpes/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ def path_for_plot(desired_path: str | Path) -> Path:
filename = (
Path(figure_path)
/ workspace["name"]
/ datetime.datetime.now(tz=datetime.timezone.utc).date().isoformat()
/ datetime.datetime.now(tz=datetime.UTC).date().isoformat()
/ desired_path
)
filename = Path(filename).absolute()
Expand Down
4 changes: 2 additions & 2 deletions src/arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def parse_single_path(path: str) -> list[SpecialPoint]:
ToDo: Shold be removed. Use ase.
"""
# first tokenize
tokens = [name for name in re.split(r"([A-Z][a-z0-9]*(?:\([0-9,\s]+\))?)", path) if name]
tokens:list[str] = [name for name in re.split(r"([A-Z][a-z0-9]*(?:\([0-9,\s]+\))?)", path) if name]

# normalize Gamma to G
tokens = [token.replace("Gamma", "G") for token in tokens]
Expand Down Expand Up @@ -147,7 +147,7 @@ def _parse_path(paths: str | list[str]) -> list[list[SpecialPoint]]:
"""
if isinstance(paths, str):
# some manual string work in order to make sure we do not split on commas inside BZ indices
idxs = []
idxs:list[int] = []
for i, p in enumerate(paths):
if p == ",":
c = Counter(paths[:i])
Expand Down
10 changes: 6 additions & 4 deletions src/arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import xarray as xr
from scipy.interpolate import RegularGridInterpolator

from arpes.constants import TWO_DIMENSION
from arpes.provenance import Provenance, provenance, update_provenance
from arpes.utilities import normalize_to_spectrum

Expand Down Expand Up @@ -166,9 +167,10 @@ def slice_along_path( # noqa: PLR0913
Returns:
xr.DataArray containing the interpolated data.
"""
if interpolation_points is None:
msg = "You must provide points specifying an interpolation path"
raise ValueError(msg)
assert isinstance(
interpolation_points,
np.ndarray,
), "You must provide points specifying an interpolation path"

parsed_interpolation_points = [
(
Expand Down Expand Up @@ -302,7 +304,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[
assert isinstance(converted_ds, xr.Dataset)

if (
axis_name in arr.dims and len(parsed_interpolation_points) == 2 # noqa: PLR2004
axis_name in arr.dims and len(parsed_interpolation_points) == TWO_DIMENSION
) and parsed_interpolation_points[1][axis_name] < parsed_interpolation_points[0][axis_name]:
# swap the sign on this axis as a convenience to the caller
converted_ds.coords[axis_name].data = -converted_ds.coords[axis_name].data
Expand Down
3 changes: 2 additions & 1 deletion src/arpes/utilities/qt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from PySide6 import QtWidgets

import arpes.config
from arpes.constants import TWO_DIMENSION
from arpes.utilities.ui import CursorRegion

from .data_array_image_view import DataArrayImageView, DataArrayPlot
Expand Down Expand Up @@ -184,7 +185,7 @@ def generate_marginal_for( # noqa: PLR0913
widget.addItem(cursor, ignoreBounds=False)
self.connect_cursor(remaining_dims[0], cursor)
else:
assert len(remaining_dims) == 2 # noqa: PLR2004
assert len(remaining_dims) == TWO_DIMENSION
widget = DataArrayImageView(name=name)
widget.view.setAspectLocked(lock=False)
self.views[name] = widget
Expand Down
46 changes: 21 additions & 25 deletions src/arpes/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from collections.abc import Sequence
from functools import wraps
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeAlias, TypeVar

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -506,7 +506,7 @@ def compute_parameters() -> dict:
"""
renamed = [
{f"{prefix}_{k}": v for k, v in m_setting.items()}
for m_setting, prefix in zip(model_settings, prefixes, strict=True)
for m_setting, prefix in zip(model_settings, prefixes, strict=False)
]
return dict(itertools.chain(*[list(d.items()) for d in renamed]))

Expand Down Expand Up @@ -757,7 +757,7 @@ def kspace_tool(
overplot_bz: Callable[[Axes], None] | list[Callable[[Axes], None]] | None = None,
bounds: dict[MOMENTUM, tuple[float, float]] | None = None,
resolution: dict | None = None,
coords: dict[str, NDArray[np.float_] | xr.DataArray] | None = None,
coords: dict[Literal["kp", "kx", "ky", "kz"], NDArray[np.float_]] | None = None,
**kwargs: Incomplete,
) -> CurrentContext:
"""A utility for assigning coordinate offsets using a live momentum conversion.
Expand All @@ -777,22 +777,21 @@ def kspace_tool(
ValueError: [TODO:description]
"""
original_data = data
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)

assert isinstance(data_array, xr.DataArray)
if len(data_array.dims) > TWO_DIMENSION:
data_array = data_array.sel(eV=slice(-0.05, 0.05)).sum("eV", keep_attrs=True)
data_array.coords["eV"] = 0
assert isinstance(data, xr.DataArray)
if len(data.dims) > TWO_DIMENSION:
data = data.sel(eV=slice(-0.05, 0.05)).sum("eV", keep_attrs=True)
data.coords["eV"] = 0

if "eV" in data_array.dims:
data_array.S.transpose_to_front("eV")
data_array = data_array.copy(deep=True)
if "eV" in data.dims:
data.S.transpose_to_front("eV")
data = data.copy(deep=True)

ctx: CurrentContext = {"original_data": original_data, "data": data_array, "widgets": []}
ctx: CurrentContext = {"original_data": original_data, "data": data, "widgets": []}
arpes.config.CONFIG["CURRENT_CONTEXT"] = ctx
gs = gridspec.GridSpec(4, 3)
ax_initial = plt.subplot(gs[0:2, 0:2])
ax_converted = plt.subplot(gs[2:, 0:2])
ax_initial, ax_converted = plt.subplot(gs[0:2, 0:2]), plt.subplot(gs[2:, 0:2])

if overplot_bz is not None:
if not isinstance(overplot_bz, Sequence):
Expand All @@ -813,17 +812,14 @@ def kspace_tool(

skip_dims = {"x", "X", "y", "Y", "z", "Z", "T"}
for dim in skip_dims:
if dim in data_array.dims:
msg = f"Please provide data without the {dim} dimension"
raise ValueError(msg)
assert dim not in data.dims, f"Please provide data without the {dim} dimension"

convert_dims = ["theta", "beta", "phi", "psi"]
if "eV" not in data_array.dims:
if "eV" not in data.dims:
convert_dims += ["chi"]
if "hv" in data_array.dims:
if "hv" in data.dims:
convert_dims += ["hv"]

ang_range = (np.deg2rad(-45), np.deg2rad(45), 0.01)
default_ranges = {
"eV": [-0.05, 0.05, 0.001],
"hv": [-20, 20, 0.5],
Expand All @@ -833,12 +829,12 @@ def kspace_tool(

def update_kspace_plot() -> None:
for name, slider in sliders.items():
data_array.attrs[f"{name}_offset"] = slider.val
data.attrs[f"{name}_offset"] = slider.val

with warnings.catch_warnings():
warnings.simplefilter("ignore")
converted_view.data = convert_to_kspace(
data_array,
data,
bounds=bounds,
resolution=resolution,
coords=coords,
Expand All @@ -848,8 +844,8 @@ def update_kspace_plot() -> None:
axes = iter(widget_axes)
for convert_dim in convert_dims:
widget_ax = next(axes)
low, high, delta = default_ranges.get(convert_dim, ang_range)
init = data_array.S.lookup_offset(convert_dim)
low, high, delta = default_ranges.get(convert_dim, (np.deg2rad(-45), np.deg2rad(45), 0.01))
init = data.S.lookup_offset(convert_dim)
sliders[convert_dim] = Slider(
widget_ax,
convert_dim,
Expand Down Expand Up @@ -907,7 +903,7 @@ def apply_offsets(event: Event) -> None:
data_view = DataArrayView(ax_initial)
converted_view = DataArrayView(ax_converted)

data_view.data = data_array
data_view.data = data
update_kspace_plot()

plt.tight_layout()
Expand Down

0 comments on commit c5d820e

Please sign in to comment.