diff --git a/arpes/analysis/__init__.py b/arpes/analysis/__init__.py
index 403919df..2c847034 100644
--- a/arpes/analysis/__init__.py
+++ b/arpes/analysis/__init__.py
@@ -1,2 +1,32 @@
"""Contains common ARPES analysis routines."""
from __future__ import annotations
+
+from .band_analysis import fit_bands, fit_for_effective_mass
+from .decomposition import (
+ decomposition_along,
+ factor_analysis_along,
+ ica_along,
+ nmf_along,
+ pca_along,
+)
+from .deconvolution import deconvolve_ice, deconvolve_rl, make_psf1d
+from .derivative import curvature1d, curvature2d, d1_along_axis, d2_along_axis, minimum_gradient
+from .filters import boxcar_filter, boxcar_filter_arr, gaussian_filter, gaussian_filter_arr
+from .gap import determine_broadened_fermi_distribution, normalize_by_fermi_dirac, symmetrize
+from .general import (
+ condense,
+ fit_fermi_edge,
+ normalize_by_fermi_distribution,
+ rebin,
+ symmetrize_axis,
+)
+from .kfermi import kfermi_from_mdcs
+from .mask import apply_mask, apply_mask_to_coords, polys_to_mask, raw_poly_to_mask
+from .pocket import (
+ curves_along_pocket,
+ edcs_along_pocket,
+ pocket_parameters,
+ radial_edcs_along_pocket,
+)
+from .tarpes import normalized_relative_change, relative_change
+from .xps import approximate_core_levels
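With these re-exports in place, the common analysis routines resolve from the package root as well as from their defining submodules. A quick sanity check of the flattened namespace (assuming PyARPES is installed and importable):

```python
import arpes.analysis
import arpes.analysis.general

# The re-export added above is the same object as the original definition.
assert arpes.analysis.rebin is arpes.analysis.general.rebin
```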
diff --git a/arpes/analysis/derivative.py b/arpes/analysis/derivative.py
index 8e57e186..c52ce095 100644
--- a/arpes/analysis/derivative.py
+++ b/arpes/analysis/derivative.py
@@ -19,8 +19,8 @@
from arpes._typing import DataType
__all__ = (
- "curvature",
- "dn_along_axis",
+ "curvature2d",
+ "curvature1d",
"d2_along_axis",
"d1_along_axis",
"minimum_gradient",
@@ -136,6 +136,7 @@ def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray:
return data_copy
+@update_provenance("Maximum Curvature 1D")
def curvature1d(
arr: xr.DataArray,
dim: str = "",
@@ -184,6 +185,7 @@ def warpped_filter(arr: xr.DataArray):
return filterd_arr
+@update_provenance("Maximum Curvature 2D")
def curvature2d(
arr: xr.DataArray,
directions: tuple[str, str] = ("phi", "eV"),
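Since `curvature` and `dn_along_axis` leave the public API here, call sites need a one-line migration. A sketch, assuming `cut` is a 2D angle-energy `xr.DataArray`; the old call shown in the comment is illustrative:

```python
from arpes.analysis.derivative import curvature2d

# Previously: result = curvature(cut, directions=("phi", "eV"))
# The renamed function also records provenance via @update_provenance.
result = curvature2d(cut, directions=("phi", "eV"))
```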
diff --git a/arpes/analysis/mask.py b/arpes/analysis/mask.py
index 9930d405..93d05274 100644
--- a/arpes/analysis/mask.py
+++ b/arpes/analysis/mask.py
@@ -12,6 +12,7 @@
if TYPE_CHECKING:
import xarray as xr
from _typeshed import Incomplete
+ from numpy.typing import NDArray
from arpes._typing import DataType
diff --git a/arpes/analysis/pocket.py b/arpes/analysis/pocket.py
index db4b1de0..380cea0f 100644
--- a/arpes/analysis/pocket.py
+++ b/arpes/analysis/pocket.py
@@ -308,7 +308,7 @@ def edcs_along_pocket(
for kf, ss in zip(kfs, slices)
]
- edcs = [data.S.select_around(l, radius=select_radius, fast=True) for l in locations]
+ edcs = [data.S.select_around(loc, radius=select_radius, fast=True) for loc in locations]
data_vars = {}
index = np.array(angles)
diff --git a/arpes/corrections/fermi_edge_corrections.py b/arpes/corrections/fermi_edge_corrections.py
index 231eece8..a11d8092 100644
--- a/arpes/corrections/fermi_edge_corrections.py
+++ b/arpes/corrections/fermi_edge_corrections.py
@@ -18,8 +18,8 @@
def _exclude_from_set(excluded):
- def exclude(l):
- return list(set(l).difference(excluded))
+ def exclude(items):
+ return list(set(items).difference(excluded))
return exclude
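The closure returned by `_exclude_from_set` filters any collection against a fixed exclusion set; note that the set round-trip does not preserve element order. For example:

```python
exclude = _exclude_from_set({"eV"})
print(exclude(["eV", "phi", "theta"]))  # ['phi', 'theta'] (order not guaranteed)
```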
diff --git a/arpes/endstations/__init__.py b/arpes/endstations/__init__.py
index d4873ad1..06a07aa4 100644
--- a/arpes/endstations/__init__.py
+++ b/arpes/endstations/__init__.py
@@ -369,25 +369,25 @@ def postprocess_final(
data.spectrum.attrs["spectrum_type"] = spectrum_type
ls = [data, *data.S.spectra]
- for l in ls:
+ for dat in ls:
for k, key_fn in self.ATTR_TRANSFORMS.items():
- if k in l.attrs:
- transformed = key_fn(l.attrs[k])
+ if k in dat.attrs:
+ transformed = key_fn(dat.attrs[k])
if isinstance(transformed, dict):
- l.attrs.update(transformed)
+ dat.attrs.update(transformed)
else:
- l.attrs[k] = transformed
+ dat.attrs[k] = transformed
- for l in ls:
+ for dat in ls:
for k, v in self.MERGE_ATTRS.items():
- if k not in l.attrs:
- l.attrs[k] = v
+ if k not in dat.attrs:
+ dat.attrs[k] = v
- for l in ls:
+ for dat in ls:
for c in self.ENSURE_COORDS_EXIST:
- if c not in l.coords:
- if c in l.attrs:
- l.coords[c] = l.attrs[c]
+ if c not in dat.coords:
+ if c in dat.attrs:
+ dat.coords[c] = dat.attrs[c]
else:
warnings_msg = f"Could not assign coordinate {c} from attributes,"
warnings_msg += "assigning np.nan instead."
@@ -395,11 +395,11 @@ def postprocess_final(
warnings_msg,
stacklevel=2,
)
- l.coords[c] = np.nan
+ dat.coords[c] = np.nan
- for l in ls:
- if "chi" in l.coords and "chi_offset" not in l.attrs:
- l.attrs["chi_offset"] = l.coords["chi"].item()
+ for dat in ls:
+ if "chi" in dat.coords and "chi_offset" not in dat.attrs:
+ dat.attrs["chi_offset"] = dat.coords["chi"].item()
# go and change endianness and datatypes to something reasonable
# this is done for performance reasons in momentum space conversion, primarily
diff --git a/arpes/endstations/plugin/ALG_main.py b/arpes/endstations/plugin/ALG_main.py
index 2f7176e2..57dc0805 100644
--- a/arpes/endstations/plugin/ALG_main.py
+++ b/arpes/endstations/plugin/ALG_main.py
@@ -28,7 +28,7 @@ class ALGMainChamber(HemisphericalEndstation, FITSEndstation):
]
ATTR_TRANSFORMS: ClassVar[dict] = {
- "START_T": lambda l: {"time": " ".join(l.split(" ")[1:]).lower(), "date": l.split(" ")[0]},
+ "START_T": lambda _: {"time": " ".join(_.split(" ")[1:]).lower(), "date": _.split(" ")[0]},
}
RENAME_KEYS: ClassVar[dict[str, str]] = {
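The `START_T` transform splits a timestamp attribute on spaces into separate `date` and `time` attributes. A sketch of its behavior on a hypothetical raw value (the exact on-disk format is beamline-specific):

```python
def transform(s: str) -> dict:
    # Mirrors the START_T lambda above.
    return {"time": " ".join(s.split(" ")[1:]).lower(), "date": s.split(" ")[0]}

print(transform("10/14/2019 3:21:42 PM"))
# {'time': '3:21:42 pm', 'date': '10/14/2019'}
```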
diff --git a/arpes/endstations/plugin/ALG_spin_ToF.py b/arpes/endstations/plugin/ALG_spin_ToF.py
index 6d5c424a..d8d080f1 100644
--- a/arpes/endstations/plugin/ALG_spin_ToF.py
+++ b/arpes/endstations/plugin/ALG_spin_ToF.py
@@ -248,7 +248,7 @@ def load_SToF_fits(self, scan_desc: dict | None = None, **kwargs: Incomplete):
hdulist.close()
relevant_dimensions = {
- k for k in coords if k in set(itertools.chain(*[l[0] for l in data_vars.values()]))
+ k for k in coords if k in set(itertools.chain(*[v[0] for v in data_vars.values()]))
}
relevant_coords = {k: v for k, v in coords.items() if k in relevant_dimensions}
diff --git a/arpes/endstations/plugin/ANTARES.py b/arpes/endstations/plugin/ANTARES.py
index 3c2022f3..ac10bcfd 100644
--- a/arpes/endstations/plugin/ANTARES.py
+++ b/arpes/endstations/plugin/ANTARES.py
@@ -306,8 +306,8 @@ def check_attrs(s):
pass
ls = [data, *data.S.spectra]
- for l in ls:
- check_attrs(l)
+ for dat in ls:
+ check_attrs(dat)
# attempt to determine whether the energy is likely a kinetic energy
# if so, we will subtract the photon energy
diff --git a/arpes/endstations/plugin/BL10_SARPES.py b/arpes/endstations/plugin/BL10_SARPES.py
index fa8808a3..49f02c90 100644
--- a/arpes/endstations/plugin/BL10_SARPES.py
+++ b/arpes/endstations/plugin/BL10_SARPES.py
@@ -155,9 +155,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: SCANDESC | None = None)
data.coords[c] = np.deg2rad(data.coords[c])
for angle_attr in deg_to_rad_attrs:
- for l in ls:
- if angle_attr in l.attrs:
- l.attrs[angle_attr] = np.deg2rad(float(l.attrs[angle_attr]))
+ for dat in ls:
+ if angle_attr in dat.attrs:
+ dat.attrs[angle_attr] = np.deg2rad(float(dat.attrs[angle_attr]))
data.attrs["alpha"] = np.pi / 2
data.attrs["psi"] = 0
@@ -169,9 +169,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: SCANDESC | None = None)
# .spectrum.attrs, for now just paste them over
necessary_coord_names = {"theta", "beta", "chi", "phi"}
ls = data.S.spectra
- for l in ls:
+ for dat in ls:
for cname in necessary_coord_names:
- if cname not in l.attrs and cname not in l.coords and cname in data.attrs:
- l.attrs[cname] = data.attrs[cname]
+ if cname not in dat.attrs and cname not in dat.coords and cname in data.attrs:
+ dat.attrs[cname] = data.attrs[cname]
return super().postprocess_final(data, scan_desc)
diff --git a/arpes/endstations/plugin/HERS.py b/arpes/endstations/plugin/HERS.py
index ead70dc9..3b3ea084 100644
--- a/arpes/endstations/plugin/HERS.py
+++ b/arpes/endstations/plugin/HERS.py
@@ -81,7 +81,7 @@ def load(self, scan_desc: dict | None = None, **kwargs: Incomplete):
hdulist.close()
relevant_dimensions = {
- k for k in coords if k in set(itertools.chain(*[l[0] for l in data_vars.values()]))
+ k for k in coords if k in set(itertools.chain(*[v[0] for v in data_vars.values()]))
}
relevant_coords = {k: v for k, v in coords.items() if k in relevant_dimensions}
diff --git a/arpes/endstations/plugin/MAESTRO.py b/arpes/endstations/plugin/MAESTRO.py
index 9e47bd3e..a721c39e 100644
--- a/arpes/endstations/plugin/MAESTRO.py
+++ b/arpes/endstations/plugin/MAESTRO.py
@@ -63,13 +63,13 @@ def fix_prebinned_coordinates(self):
def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None):
ls = [data, *data.S.spectra]
- for l in ls:
- l.attrs.update(self.ANALYZER_INFORMATION)
+ for dat in ls:
+ dat.attrs.update(self.ANALYZER_INFORMATION)
- if "GRATING" in l.attrs:
- l.attrs["grating_lines_per_mm"] = {
+ if "GRATING" in _.attrs:
+ _.attrs["grating_lines_per_mm"] = {
"G201b": 600,
- }.get(l.attrs["GRATING"])
+ }.get(_.attrs["GRATING"])
return super().postprocess_final(data, scan_desc)
@@ -126,14 +126,14 @@ class MAESTROMicroARPESEndstation(MAESTROARPESEndstationBase):
}
ATTR_TRANSFORMS: ClassVar[dict] = {
- "START_T": lambda l: {
- "time": " ".join(l.split(" ")[1:]).lower(),
- "date": l.split(" ")[0],
+ "START_T": lambda _: {
+ "time": " ".join(_.split(" ")[1:]).lower(),
+ "date": _.split(" ")[0],
},
- "SF_SLITN": lambda l: {
- "slit_number": int(l.split(" ")[0]),
- "slit_shape": l.split(" ")[-1].lower(),
- "slit_width": float(l.split(" ")[2]),
+ "SF_SLITN": lambda _: {
+ "slit_number": int(_.split(" ")[0]),
+ "slit_shape": _.split(" ")[-1].lower(),
+ "slit_width": float(_.split(" ")[2]),
},
}
@@ -243,14 +243,14 @@ class MAESTRONanoARPESEndstation(MAESTROARPESEndstationBase):
}
ATTR_TRANSFORMS: ClassVar[dict] = {
- "START_T": lambda l: {
- "time": " ".join(l.split(" ")[1:]).lower(),
- "date": l.split(" ")[0],
+ "START_T": lambda _: {
+ "time": " ".join(_.split(" ")[1:]).lower(),
+ "date": _.split(" ")[0],
},
- "SF_SLITN": lambda l: {
- "slit_number": int(l.split(" ")[0]),
- "slit_shape": l.split(" ")[-1].lower(),
- "slit_width": float(l.split(" ")[2]),
+ "SF_SLITN": lambda _: {
+ "slit_number": int(_.split(" ")[0]),
+ "slit_shape": _.split(" ")[-1].lower(),
+ "slit_width": float(_.split(" ")[2]),
},
}
diff --git a/arpes/endstations/plugin/kaindl.py b/arpes/endstations/plugin/kaindl.py
index e63cec97..6dbd8dce 100644
--- a/arpes/endstations/plugin/kaindl.py
+++ b/arpes/endstations/plugin/kaindl.py
@@ -156,7 +156,7 @@ def concatenate_frames(
axis_name = lines[0].strip()
axis_name = self.RENAME_KEYS.get(axis_name, axis_name)
- values = [float(l.strip()) for l in lines[1 : len(frames) + 1]]
+ values = [float(line.strip()) for line in lines[1 : len(frames) + 1]]
for v, f in zip(values, frames):
f.coords[axis_name] = v
@@ -239,9 +239,9 @@ def attach_attr(data, attr_name, as_name):
data.attrs[angle_attr] = np.deg2rad(float(data.attrs[angle_attr]))
ls = [data, *data.S.spectra]
- for l in ls:
- l.coords["x"] = np.nan
- l.coords["y"] = np.nan
- l.coords["z"] = np.nan
+ for dat in ls:
+ dat.coords["x"] = np.nan
+ dat.coords["y"] = np.nan
+ dat.coords["z"] = np.nan
return super().postprocess_final(data, scan_desc)
diff --git a/arpes/endstations/plugin/merlin.py b/arpes/endstations/plugin/merlin.py
index 8de14b9c..7639530b 100644
--- a/arpes/endstations/plugin/merlin.py
+++ b/arpes/endstations/plugin/merlin.py
@@ -89,15 +89,15 @@ class BL403ARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SESEn
}
ATTR_TRANSFORMS: ClassVar[dict[str, Any]] = {
- "acquisition_mode": lambda l: l.lower(),
- "lens_mode": lambda l: {
+ "acquisition_mode": lambda _: _.lower(),
+ "lens_mode": lambda _: {
"lens_mode": None,
- "lens_mode_name": l,
+ "lens_mode_name": _,
},
"undulator_polarization": int,
- "region_name": lambda l: {
- "daq_region_name": l,
- "daq_region": l,
+ "region_name": lambda _: {
+ "daq_region_name": _,
+ "daq_region": _,
},
}
@@ -125,12 +125,12 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N
Path(original_filename).parent / f"{internal_match.groups()[0]}_Motor_Pos.txt",
)
try:
- with open(motors_path) as f:
+ with Path(motors_path).open() as f:
lines = f.readlines()
axis_name = lines[0].strip()
axis_name = self.RENAME_KEYS.get(axis_name, axis_name)
- values = [float(l.strip()) for l in lines[1 : len(frames) + 1]]
+ values = [float(line.strip()) for line in lines[1 : len(frames) + 1]]
for v, f in zip(values, frames):
f.coords[axis_name] = v
@@ -139,10 +139,10 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N
for frame in frames:
# promote x, y, z to coords so they get concatted
- for l in [frame, *frame.S.spectra]:
+ for dat in [frame, *frame.S.spectra]:
for c in ["x", "y", "z"]:
- if c not in l.coords:
- l.coords[c] = l.attrs[c]
+ if c not in dat.coords:
+ dat.coords[c] = dat.attrs[c]
return xr.concat(frames, axis_name, coords="different")
except Exception:
@@ -251,23 +251,23 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None):
"""
ls = [data, *data.S.spectra]
- for l in ls:
- if "slit_number" in l.attrs:
+ for dat in ls:
+ if "slit_number" in dat.attrs:
slit_lookup = {
1: ("straight", 0.1),
7: ("curved", 0.5),
}
- shape, width = slit_lookup.get(l.attrs["slit_number"], (None, None))
- l.attrs["slit_shape"] = shape
- l.attrs["slit_width"] = width
+ shape, width = slit_lookup.get(dat.attrs["slit_number"], (None, None))
+ dat.attrs["slit_shape"] = shape
+ dat.attrs["slit_width"] = width
- if "undulator_polarization" in l.attrs:
+ if "undulator_polarization" in dat.attrs:
phase_angle_lookup = {0: (0, 0), 2: (np.pi / 2, 0)} # LH # LV
polarization_theta, polarization_alpha = phase_angle_lookup[
- int(l.attrs["undulator_polarization"])
+ int(dat.attrs["undulator_polarization"])
]
- l.attrs["probe_polarization_theta"] = polarization_theta
- l.attrs["probe_polarization_alpha"] = polarization_alpha
+ dat.attrs["probe_polarization_theta"] = polarization_theta
+ dat.attrs["probe_polarization_alpha"] = polarization_alpha
deg_to_rad_coords = {"theta", "phi", "beta", "chi", "psi"}
deg_to_rad_attrs = {"theta", "beta", "chi", "psi", "alpha"}
@@ -277,9 +277,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None):
data.coords[c] = data.coords[c] * np.pi / 180
for angle_attr in deg_to_rad_attrs:
- for l in ls:
- if angle_attr in l.attrs:
- l.attrs[angle_attr] = float(l.attrs[angle_attr]) * np.pi / 180
+ for dat in ls:
+ if angle_attr in dat.attrs:
+ dat.attrs[angle_attr] = np.deg2rad(float(dat.attrs[angle_attr]))
data.attrs["alpha"] = np.pi / 2
data.attrs["psi"] = 0
@@ -287,13 +287,17 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None):
s.attrs["alpha"] = np.pi / 2
s.attrs["psi"] = 0
- # TODO: Conrad think more about why sometimes individual attrs don't make it onto
- # .spectrum.attrs, for now just paste them over
+ # TODO: Conrad think more about why sometimes individual attrs don't
+ # make it onto .spectrum.attrs, for now just paste them over
necessary_coord_names = {"theta", "beta", "chi", "phi"}
ls = data.S.spectra
- for l in ls:
+ for spectrum in ls:
for cname in necessary_coord_names:
- if cname not in l.attrs and cname not in l.coords and cname in data.attrs:
- l.attrs[cname] = data.attrs[cname]
+ if (
+ cname not in spectrum.attrs
+ and cname not in spectrum.coords
+ and cname in data.attrs
+ ):
+ spectrum.attrs[cname] = data.attrs[cname]
return super().postprocess_final(data, scan_desc)
diff --git a/arpes/experiment/__init__.py b/arpes/experiment/__init__.py
index c7a2bdf2..5aae1236 100644
--- a/arpes/experiment/__init__.py
+++ b/arpes/experiment/__init__.py
@@ -26,7 +26,7 @@
def flatten(lists):
- return chain.from_iterable([l if np.iterable(l) else [l] for l in lists])
+ return chain.from_iterable([item if np.iterable(item) else [item] for item in lists])
class ExperimentTreeItem:
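`flatten` flattens exactly one level of nesting, passing scalars through untouched; because `np.iterable` reports strings as iterable, string elements would be split into characters. A standalone sketch of the renamed comprehension:

```python
from itertools import chain

import numpy as np

def flatten(lists):
    # Wrap non-iterable elements so one level of nesting can be chained away.
    return chain.from_iterable([item if np.iterable(item) else [item] for item in lists])

print(list(flatten([1, [2, 3], (4,)])))  # [1, 2, 3, 4]
```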
diff --git a/arpes/fits/lmfit_html_repr.py b/arpes/fits/lmfit_html_repr.py
index a0a42061..dd419f1e 100644
--- a/arpes/fits/lmfit_html_repr.py
+++ b/arpes/fits/lmfit_html_repr.py
@@ -30,12 +30,12 @@ def repr_multiline_ModelResult(self: model.Model, **kwargs: Incomplete) -> str:
[(" " * 4) + c._repr_multiline_text_() for c in self.components],
),
parameters="\n".join(
- f" {l}" for l in self.params._repr_multiline_text_(**kwargs).split("\n")
+ f" {l_item}" for l_item in self.params._repr_multiline_text_(**kwargs).split("\n")
),
)
-def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str:
+def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str: # noqa: N802
"""Provides a better Jupyter representation of an `lmfit.ModelResult` instance."""
template = """
@@ -51,7 +51,7 @@ def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str:
)
-def repr_html_Model(self: Incomplete) -> str:
+def repr_html_Model(self: Incomplete) -> str: # noqa: N802
"""Better Jupyter representation of `lmfit.Model` instances."""
template = """
@@ -61,7 +61,7 @@ def repr_html_Model(self: Incomplete) -> str:
return template.format(name=self.name)
-def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str:
+def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str: # noqa: N802
"""Provides a text-based multiline representation used in Qt based interactive tools."""
return self.name
@@ -70,7 +70,7 @@ def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str:
SKIP_ON_SHORT = {"min", "max", "vary", "expr", "brute_step"}
-def repr_html_Parameters(self: Incomplete, *, short: bool = False) -> str:
+def repr_html_Parameters(self: Incomplete, *, short: bool = False) -> str: # noqa: N802
"""HTML representation for `lmfit.Parameters` instances."""
keys = sorted(self.keys())
template = """
diff --git a/arpes/fits/utilities.py b/arpes/fits/utilities.py
index 06c69d86..275c91f6 100644
--- a/arpes/fits/utilities.py
+++ b/arpes/fits/utilities.py
@@ -18,7 +18,6 @@
import dill
import numpy as np
import xarray as xr
-from packaging import version
from tqdm.notebook import tqdm
import arpes.fits.fit_models
@@ -40,34 +39,24 @@
TypeIterable = list[type] | tuple[type]
-XARRAY_REQUIRES_VALUES_WRAPPING = version.parse(xr.__version__) > version.parse("0.10.0")
-
-
-def wrap_for_xarray_values_unpacking(item):
- """This is a shim for https://github.com/pydata/xarray/issues/2097."""
- if XARRAY_REQUIRES_VALUES_WRAPPING:
- return np.array(item, dtype=object)
-
- return item
-
def result_to_hints(
- m: lmfit.model.ModelResult | None,
+ model_result: lmfit.model.ModelResult | None,
defaults=None,
) -> dict[str, dict[str, Any]] | None:
"""Turns an `lmfit.model.ModelResult` into a dictionary with initial guesses.
Args:
- m: The model result to extract parameters from
- defaults: Returned if `m` is None, useful for cell re-evaluation in Jupyter
+ model_result: The model result to extract parameters from
+ defaults: Returned if `model_result` is None, useful for cell re-evaluation in Jupyter
Returns:
A dict containing parameter specifications in key-value rather than `lmfit.Parameter`
format, as you might pass as `params=` to PyARPES fitting code.
"""
- if m is None:
+ if model_result is None:
return defaults
- return {k: {"value": m.params[k].value} for k in m.params}
+ return {k: {"value": model_result.params[k].value} for k in model_result.params}
def parse_model(model):
@@ -98,7 +87,7 @@ def parse_model(model):
special = set(pad_all)
- def read_token(token):
+ def read_token(token: str) -> str | float:
if token in special:
return token
try:
@@ -127,7 +116,7 @@ def broadcast_model(
*,
progress: bool = True,
safe: bool = False,
- trace: Callable = None, # type: ignore # noqa: RUF013
+ trace: Callable | None = None,
) -> xr.Dataset:
"""Perform a fit across a number of dimensions.
@@ -174,7 +163,7 @@ def broadcast_model(
n_fits = np.prod(np.array(list(template.S.dshape.values())))
if parallelize is None:
- parallelize = bool(n_fits > 20)
+ parallelize = bool(n_fits > 20) # noqa: PLR2004
trace("Copying residual")
residual = data_array.copy(deep=True)
@@ -189,7 +178,16 @@ def broadcast_model(
wrap_progress = tqdm
else:
- def wrap_progress(x: Iterable[int], *_, **__) -> Iterable[int]:
+ def wrap_progress(x: Iterable[int], **__: str | float) -> Iterable[int]:
+ """Fake of tqdm.notebook.tqdm.
+
+ Args:
+ x (Iterable[int]): [TODO:description]
+ __: its a dummy parameter, which is not used.
+
+ Returns:
+ Same iterable.
+ """
return x
serialize = parallelize
@@ -231,7 +229,7 @@ def wrap_progress(x: Iterable[int], *_, **__) -> Iterable[int]:
if serialize:
trace("Deserializing...")
- def unwrap(result_data):
+ def unwrap(result_data: bytes) -> object:
# using the lmfit deserialization and serialization seems slower than double pickling
# with dill
return dill.loads(result_data)
@@ -240,7 +238,7 @@ def unwrap(result_data):
trace("Finished running fits Collating")
for fit_result, fit_residual, coords in exe_results:
- template.loc[coords] = wrap_for_xarray_values_unpacking(fit_result)
+ template.loc[coords] = np.array(fit_result, dtype=object)
residual.loc[coords] = fit_residual
trace("Bundling into dataset")
diff --git a/arpes/fits/zones.py b/arpes/fits/zones.py
index 172ead27..be720e1e 100644
--- a/arpes/fits/zones.py
+++ b/arpes/fits/zones.py
@@ -29,15 +29,24 @@
if TYPE_CHECKING:
from numpy.typing import NDArray
+ from arpes._typing import DataType
+
def k_points_residual(
- paramters,
- coords_dataset,
- high_symmetry_points,
+ coords_dataset: DataType,
dimensionality: int = 2,
) -> NDArray[np.float_]:
+ """[TODO:summary].
+
+ Args:
+ coords_dataset: [TODO:description]
+ dimensionality: [TODO:description]
+
+ Returns:
+ [TODO:description]
+ """
momentum_coordinates = convert_coordinates(coords_dataset)
- if dimensionality == 2:
+ if dimensionality == 2: # noqa: PLR2004
return np.asarray(
[
np.diagonal(momentum_coordinates.kx.values),
diff --git a/arpes/laue/__init__.py b/arpes/laue/__init__.py
index 38415d74..7c60090d 100644
--- a/arpes/laue/__init__.py
+++ b/arpes/laue/__init__.py
@@ -19,7 +19,7 @@
from pathlib import Path
import numpy as np
-import xarray
+import xarray as xr
from arpes.provenance import provenance_from_file
@@ -34,7 +34,10 @@
)
-def load_laue(path: Path | str):
+__all__ = ("load_laue",)
+
+
+def load_laue(path: Path | str) -> xr.DataArray:
"""Loads NorthStart Laue backscattering data."""
if isinstance(path, str):
path = Path(path)
@@ -45,7 +48,7 @@ def load_laue(path: Path | str):
table = np.fromstring(table, dtype=np.uint16).reshape(256, 256)
header = np.fromstring(header, dtype=northstar_62_69_dtype).item()
- arr = xarray.DataArray(
+ arr = xr.DataArray(
table,
coords={"x": np.array(range(256)), "y": np.array(range(256))},
dims=[
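A usage sketch for the newly exported loader; the file path is hypothetical:

```python
from arpes.laue import load_laue

laue = load_laue("data/cleave_check.nsl")  # hypothetical NorthStar file
print(laue.shape)  # (256, 256), with 'x' and 'y' pixel coordinates
```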
diff --git a/arpes/plotting/__init__.py b/arpes/plotting/__init__.py
index 1076cd5e..c65df890 100644
--- a/arpes/plotting/__init__.py
+++ b/arpes/plotting/__init__.py
@@ -1,2 +1,40 @@
"""Standard plotting routines and utility code for ARPES analyses."""
-from __future__ import annotations
+from __future__ import annotations # noqa: I001
+
+
+"""
+from .annotations import annotate_cuts, annotate_experimental_conditions, annotate_point
+from .band_tool import BandTool
+from .bands import plot_with_bands
+from .basic import make_reference_plots
+from .curvature_tool import CurvatureTool
+from .comparison_tool import compare
+from .dispersion import (
+ cut_dispersion_plot,
+ fancy_dispersion,
+ hv_reference_scan,
+ labeled_fermi_surface,
+ plot_dispersion,
+ reference_scan_fermi_surface,
+ scan_var_reference_plot,
+)
+from .dos import plot_core_levels, plot_dos
+from .dyn_tool import DynamicTool, dyn
+from .fermi_edge import fermi_edge, plot_fit
+from .fermi_surface import fermi_surface_slices, magnify_circular_regions_plot
+from .fit_inspection_tool import FitCheckTool
+from .interactive import ImageTool
+from .mask_tool import mask
+from .movie import plot_movie
+from .parameter import plot_parameter
+from .path_tool import path_tool
+from .spatial import plot_spatial_reference, reference_scan_spatial
+from .spin import spin_polarized_spectrum, spin_colored_spectrum, spin_difference_spectrum
+from .qt_tool import qt_tool
+from .qt_ktool import ktool
+from .utils import (
+ savefig,
+ remove_colorbars,
+ fancy_labels,
+)
+ """
diff --git a/arpes/plotting/bands.py b/arpes/plotting/bands.py
index 78f5a07f..8a7d8900 100644
--- a/arpes/plotting/bands.py
+++ b/arpes/plotting/bands.py
@@ -13,11 +13,10 @@
from pathlib import Path
from _typeshed import Incomplete
+ from arpes._typing import DataType
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
- from build.lib.arpes.typing import DataType
-
__all__ = ("plot_with_bands",)
diff --git a/arpes/plotting/bz_tool/__init__.py b/arpes/plotting/bz_tool/__init__.py
index d478ba94..72d2534b 100644
--- a/arpes/plotting/bz_tool/__init__.py
+++ b/arpes/plotting/bz_tool/__init__.py
@@ -18,6 +18,7 @@
from arpes.utilities.ui import combo_box, horizontal, tabs
from .CoordinateOffsetWidget import CoordinateOffsetWidget
+from .RangeOrSingleValueWidget import RangeOrSingleValueWidget
__all__ = [
"bz_tool",
diff --git a/arpes/plotting/curvature_tool.py b/arpes/plotting/curvature_tool.py
index 59cc010c..339c8195 100644
--- a/arpes/plotting/curvature_tool.py
+++ b/arpes/plotting/curvature_tool.py
@@ -39,7 +39,6 @@ def tool_handler(self, doc):
default_palette = self.default_palette
- x_coords, y_coords = self.arr.coords[self.arr.dims[1]], self.arr.coords[self.arr.dims[0]]
self.app_context.update(
{
"data": self.arr,
@@ -112,7 +111,6 @@ def tool_handler(self, doc):
figures["curvature"].yaxis.major_label_text_font_size = "0pt"
- # TODO: add support for color mapper
plots["d2"] = figures["d2"].image(
[self.arr.values],
x=data_range["x"][0],
diff --git a/arpes/plotting/fit_tool/__init__.py b/arpes/plotting/fit_tool/__init__.py
index a4edfba8..d4feba13 100644
--- a/arpes/plotting/fit_tool/__init__.py
+++ b/arpes/plotting/fit_tool/__init__.py
@@ -273,7 +273,7 @@ def generate_fit_marginal_for(
if layout is None:
layout = self._layout
- remaining_dims = [l for l in list(range(len(self.data.dims))) if l not in dimensions]
+ remaining_dims = [dim for dim in list(range(len(self.data.dims))) if dim not in dimensions]
# for now, we only allow a single fit dimension
widget = FitInspectionPlot(name=name, root=weakref.ref(self), orientation=orientation)
@@ -390,9 +390,9 @@ def safe_slice(vlow, vhigh, axis=0):
for_plot = for_plot.mean(list(select_coord.keys()))
cursors = [
- l
- for l in reactive.view.getPlotItem().items
- if isinstance(l, CursorRegion)
+ item
+ for item in reactive.view.getPlotItem().items
+ if isinstance(item, CursorRegion)
]
reactive.view.clear()
for c in cursors:
diff --git a/arpes/plotting/qt_ktool/__init__.py b/arpes/plotting/qt_ktool/__init__.py
index c43451d0..0eea42e1 100644
--- a/arpes/plotting/qt_ktool/__init__.py
+++ b/arpes/plotting/qt_ktool/__init__.py
@@ -40,7 +40,7 @@ class KTool(SimpleApp):
DEFAULT_COLORMAP = "viridis"
- def __init__(self, apply_offsets=True, zone=None, **kwargs: Incomplete) -> None:
+ def __init__(self, apply_offsets: bool = True, zone=None, **kwargs: Incomplete) -> None:
"""Set attributes to safe defaults and unwrap the Brillouin zone definition."""
super().__init__()
diff --git a/arpes/plotting/qt_tool/__init__.py b/arpes/plotting/qt_tool/__init__.py
index 1a64f4c9..0ec61173 100644
--- a/arpes/plotting/qt_tool/__init__.py
+++ b/arpes/plotting/qt_tool/__init__.py
@@ -398,9 +398,9 @@ def safe_slice(vlow, vhigh, axis=0):
for_plot = for_plot.mean(list(select_coord.keys()))
cursors = [
- l
- for l in reactive.view.getPlotItem().items
- if isinstance(l, CursorRegion)
+ item
+ for item in reactive.view.getPlotItem().items
+ if isinstance(item, CursorRegion)
]
reactive.view.clear()
for c in cursors:
diff --git a/arpes/utilities/qt/app.py b/arpes/utilities/qt/app.py
index b451971f..a591f705 100644
--- a/arpes/utilities/qt/app.py
+++ b/arpes/utilities/qt/app.py
@@ -146,7 +146,7 @@ def generate_marginal_for(
if layout is None:
layout = self._layout
- remaining_dims = [l for l in list(range(len(self.data.dims))) if l not in dimensions]
+ remaining_dims = [dim for dim in list(range(len(self.data.dims))) if dim not in dimensions]
if len(remaining_dims) == 1:
widget = DataArrayPlot(name=name, root=weakref.ref(self), orientation=orientation)
diff --git a/arpes/workflow.py b/arpes/workflow.py
index 5f3ebceb..10d8aa9c 100644
--- a/arpes/workflow.py
+++ b/arpes/workflow.py
@@ -127,25 +127,25 @@ def _read_pickled(self, name: str, default=None):
except FileNotFoundError:
return default
- def _write_pickled(self, name: str, value):
- with open(str(self.path / f"{name}.pickle"), "wb") as f:
+ def _write_pickled(self, name: str, value) -> None:
+ with Path(self.path / f"{name}.pickle").open("wb") as f:
dill.dump(value, f)
@property
- def publishers(self):
+ def publishers(self) -> dict[str, list]:
return self._read_pickled("publishers", defaultdict(list))
@publishers.setter
- def publishers(self, new_publishers):
+ def publishers(self, new_publishers) -> None:
assert isinstance(new_publishers, dict)
self._write_pickled("publishers", new_publishers)
@property
- def consumers(self):
+ def consumers(self) -> dict[str, list]:
return self._read_pickled("consumers", defaultdict(list))
@consumers.setter
- def consumers(self, new_consumers):
+ def consumers(self, new_consumers) -> None:
assert isinstance(new_consumers, dict)
self._write_pickled("consumers", new_consumers)
@@ -181,7 +181,7 @@ def publish(self, key, data) -> None:
self.summarize_consumers(key=key)
- def consume(self, key: Hashable, *, subscribe: bool = True):
+ def consume(self, key: str, *, subscribe: bool = True):
if subscribe:
context = get_running_context()
consumers = self.consumers
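The pickled properties amount to a small persistent key-value store rooted at the workspace path, using `dill` rather than `pickle` so closures survive the round-trip. A self-contained sketch of the pattern, with illustrative names:

```python
from collections import defaultdict
from pathlib import Path

import dill

def write_pickled(root: Path, name: str, value) -> None:
    with Path(root / f"{name}.pickle").open("wb") as f:
        dill.dump(value, f)

def read_pickled(root: Path, name: str, default=None):
    try:
        with Path(root / f"{name}.pickle").open("rb") as f:
            return dill.load(f)
    except FileNotFoundError:
        return default  # a missing file falls back to the provided default

state = read_pickled(Path("."), "publishers", defaultdict(list))
```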
diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py
index cb07cd4f..fd9171d7 100644
--- a/arpes/xarray_extensions.py
+++ b/arpes/xarray_extensions.py
@@ -77,7 +77,7 @@
from arpes.utilities.xarray import unwrap_xarray_dict, unwrap_xarray_item
if TYPE_CHECKING:
- from collections.abc import Callable, Generator, Hashable, Iterator
+ from collections.abc import Callable, Generator, Hashable, Iterator, Sequence
from pathlib import Path
import pandas as pd
@@ -378,7 +378,7 @@ def select_around_data(
stored in the dataarray kFs. Then we could select momentum integrated EDCs in a small
window around the fermi momentum for each temperature by using
- >>> edcs_at_fermi_momentum = full_data.S.select_around_data({'kp': kFs}, radius={'kp': 0.04}, fast=True) # doctest: +SKIP
+ >>> edcs = full_data.S.select_around_data({'kp': kFs}, radius={'kp': 0.04}, fast=True) # doctest: +SKIP
The resulting data will be EDCs for each T, in a region of radius 0.04 inverse angstroms
around the Fermi momentum.
@@ -523,14 +523,11 @@ def select_around(
Returns:
The binned selection around the desired point or points.
"""
- if isinstance(self._obj, xr.Dataset):
- msg = "Cannot use select_around on Datasets only DataArrays!"
- raise TypeError(msg)
-
- if mode not in {"sum", "mean"}:
- msg = "mode parameter should be either sum or mean."
- raise ValueError(msg)
-
+ assert isinstance(
+ self._obj,
+ xr.DataArray,
+ ), "Cannot use select_around on Datasets, only DataArrays!"
+ assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean."
if isinstance(point, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
point = dict(zip(point, self._obj.dims, strict=True))
@@ -620,8 +617,6 @@ def _calculate_symmetry_points(
# if it is projected, we need to calculate its projected coordinates
"""[TODO:summary].
- [TODO:description]
-
Args:
symmetry_points: [TODO:description]
projection_distance: [TODO:description]
@@ -808,23 +803,29 @@ def original_parent_scan_name(self) -> str:
first_modification = history[-3]
df: pd.DataFrame = self._obj.attrs["df"] # "df" means DataFrame of pandas
return df[df.id == first_modification["parent_id"]].index[0]
- except:
+ except (KeyError, IndexError):
pass
return ""
@property
- def scan_row(self):
- df: pd.DataFrame = self._obj.attrs["df"]
- sdf = df[df.patl == self._obj.attrs["file"]]
- return next(iter(sdf.iterrows()))
+ def scan_row(self) -> tuple[Hashable, pd.Series] | list:
+ try:
+ df: pd.DataFrame = self._obj.attrs["df"]
+ sdf = df[df.patl == self._obj.attrs["file"]]
+ return next(iter(sdf.iterrows()))
+ except KeyError:
+ return []
@property
- def df_index(self):
+ def df_index(self) -> Hashable:
return self.scan_row[0]
@property
def df_after(self):
- return self._obj.attrs["df"][self._obj.attrs["df"].index > self.df_index]
+ try:
+ return self._obj.attrs["df"][self._obj.attrs["df"].index > self.df_index]
+ except KeyError:
+ return None
def df_until_type(
self,
@@ -1866,7 +1867,7 @@ def _repr_html_(self) -> str:
try:
name = self.df_index
- except:
+ except IndexError:
if "id" in self._obj.attrs:
name = "ID: " + str(self._obj.attrs["id"])[:9] + "..."
else:
@@ -1994,7 +1995,7 @@ def show(self, *, detached: bool = False, **kwargs: Incomplete) -> None:
def show_d2(self, **kwargs: Incomplete) -> None:
"""Opens the Bokeh based second derivative image tool."""
- from arpes.plotting.all import CurvatureTool
+ from arpes.plotting.curvature_tool import CurvatureTool
curve_tool = CurvatureTool(**kwargs)
return curve_tool.make_tool(self._obj)
diff --git a/tests/conftest.py b/tests/conftest.py
index 3dec17f0..9b4b505b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,6 @@
"""Mocks the analysis environment and provides data fixutres for tests."""
from __future__ import annotations
-import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict
@@ -79,7 +78,7 @@ def load(path: str) -> xr.DataArray | xr.Dataset:
pieces = path.split("/")
set_workspace(pieces[0])
return cache_loader.load_test_scan(
- os.path.join(*pieces),
+ str(Path(path)),
location=SCAN_FIXTURE_LOCATIONS[path],
)
@@ -91,4 +90,4 @@ def load(path: str) -> xr.DataArray | xr.Dataset:
arpes.config.load_plugins()
yield sandbox
arpes.config.CONFIG["WORKSPACE"] = None
- arpes.endstations._ENDSTATION_ALIASES = {}
+ arpes.endstations._ENDSTATION_ALIASES = {} # noqa: SLF001
diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py
index 44935ace..e7a68b84 100644
--- a/tests/test_basic_data_loading.py
+++ b/tests/test_basic_data_loading.py
@@ -11,10 +11,12 @@
from arpes.utilities.conversion import convert_to_kspace
if TYPE_CHECKING:
+ from collections.abc import Sequence
+
from _typeshed import Incomplete
-def pytest_generate_tests(metafunc: Incomplete):
+def pytest_generate_tests(metafunc: Incomplete) -> None:
"""[TODO:summary].
[TODO:description]
@@ -641,7 +643,7 @@ def test_load_file_and_basic_attributes(
assert isinstance(data, xr.Dataset)
# assert basic dataset attributes
- for attr in ["location"]: # TODO: add spectrum type requirement
+ for attr in ["location"]:
assert attr in data.attrs
# assert that all necessary coordinates are present
@@ -679,7 +681,7 @@ def test_load_file_and_basic_attributes(
assert k
assert pytest.approx(data.coords[k].item(), 1e-3) == v
- def safefirst(x):
+ def safefirst(x: Sequence[float]) -> float | None:
with contextlib.suppress(TypeError, IndexError):
return x[0]
diff --git a/tests/test_direct_and_example_data_loading.py b/tests/test_direct_and_example_data_loading.py
index e272f2bb..46b56cb9 100644
--- a/tests/test_direct_and_example_data_loading.py
+++ b/tests/test_direct_and_example_data_loading.py
@@ -11,10 +11,13 @@
from arpes.io import load_data, load_example_data
if TYPE_CHECKING:
- from _typeshed import Incomplete
+ from .conftest import Sandbox
-def test_load_data(sandbox_configuration: Incomplete) -> None:
+
+def test_load_data(
+ sandbox_configuration: Sandbox, # noqa: ARG001
+) -> None:
"""[TODO:summary].
[TODO:description]
@@ -32,7 +36,9 @@ def test_load_data(sandbox_configuration: Incomplete) -> None:
assert data.spectrum.shape == (240, 240)
-def test_load_data_with_plugin_specified(sandbox_configuration: Incomplete) -> None:
+def test_load_data_with_plugin_specified(
+ sandbox_configuration: Sandbox, # noqa: ARG001
+) -> None:
"""[TODO:summary].
[TODO:description]
@@ -52,7 +58,9 @@ def test_load_data_with_plugin_specified(sandbox_configuration: Incomplete) -> N
assert np.all(data.spectrum.values == directly_specified_data.spectrum.values)
-def test_load_example_data(sandbox_configuration: Incomplete) -> None:
+def test_load_example_data(
+ sandbox_configuration: Sandbox, # noqa: ARG001
+) -> None:
"""[TODO:summary].
[TODO:description]
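The fixture annotations above are written as the injected value rather than the generator: pytest calls a `yield`-style fixture and hands the test the yielded object, so `Sandbox` is the accurate parameter type even though the fixture function itself is a generator. A minimal illustration with hypothetical names:

```python
import pytest

class Sandbox: ...

@pytest.fixture()
def sandbox_configuration():
    yield Sandbox()  # pytest injects this value; teardown runs after the yield

def test_uses_sandbox(sandbox_configuration: Sandbox) -> None:
    assert isinstance(sandbox_configuration, Sandbox)
```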
diff --git a/tests/test_time_configuration.py b/tests/test_time_configuration.py
index faf6d018..590f47b2 100644
--- a/tests/test_time_configuration.py
+++ b/tests/test_time_configuration.py
@@ -1,10 +1,18 @@
"""test for time configuration."""
+from __future__ import annotations
+
import os.path
+from typing import TYPE_CHECKING
import arpes.config
+if TYPE_CHECKING:
+ from .conftest import Sandbox
+
-def test_patched_config(sandbox_configuration) -> None:
+def test_patched_config(
+ sandbox_configuration: Sandbox,
+) -> None:
"""[TODO:summary].
[TODO:description]
@@ -22,7 +32,9 @@ def test_patched_config(sandbox_configuration) -> None:
assert str(arpes.config.CONFIG["WORKSPACE"]["path"]).split(os.sep)[-2:] == ["datasets", "basic"]
-def test_patched_config_no_workspace(sandbox_configuration) -> None:
+def test_patched_config_no_workspace(
+ sandbox_configuration: Sandbox, # noqa: ARG001
+) -> None:
"""[TODO:summary].
[TODO:description]
diff --git a/tests/utils.py b/tests/utils.py
index 70c0743d..dfbf0b82 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -26,11 +26,9 @@ def path_to_datasets() -> Path:
class CachingDataLoader:
cache: dict[str, xr.Dataset] = field(default_factory=dict)
- def load_test_scan(self, example_name: str | Path, **kwargs: Incomplete) -> xr.Dataset:
+ def load_test_scan(self, example_name: str, **kwargs: Incomplete) -> xr.Dataset:
"""[TODO:summary].
- [TODO:description]
-
Args:
example_name (str): Name of the example scan to load.
kwargs: Pass to load_data function
@@ -49,7 +47,7 @@ def load_test_scan(self, example_name: str | Path, **kwargs: Incomplete) -> xr.D
raise ValueError(msg)
data = load_data(str(path_to_data.absolute()), **kwargs)
- self.cache[example_name] = data
+ self.cache[str(example_name)] = data
return data.copy(deep=True)