diff --git a/arpes/analysis/__init__.py b/arpes/analysis/__init__.py index 403919df..2c847034 100644 --- a/arpes/analysis/__init__.py +++ b/arpes/analysis/__init__.py @@ -1,2 +1,32 @@ """Contains common ARPES analysis routines.""" from __future__ import annotations + +from .band_analysis import fit_bands, fit_for_effective_mass +from .decomposition import ( + decomposition_along, + factor_analysis_along, + ica_along, + nmf_along, + pca_along, +) +from .deconvolution import deconvolve_ice, deconvolve_rl, make_psf1d +from .derivative import curvature1d, curvature2d, d1_along_axis, d2_along_axis, minimum_gradient +from .filters import boxcar_filter, boxcar_filter_arr, gaussian_filter, gaussian_filter_arr +from .gap import determine_broadened_fermi_distribution, normalize_by_fermi_dirac, symmetrize +from .general import ( + condense, + fit_fermi_edge, + normalize_by_fermi_distribution, + rebin, + symmetrize_axis, +) +from .kfermi import kfermi_from_mdcs +from .mask import apply_mask, apply_mask_to_coords, polys_to_mask, raw_poly_to_mask +from .pocket import ( + curves_along_pocket, + edcs_along_pocket, + pocket_parameters, + radial_edcs_along_pocket, +) +from .tarpes import normalized_relative_change, relative_change +from .xps import approximate_core_levels diff --git a/arpes/analysis/derivative.py b/arpes/analysis/derivative.py index 8e57e186..c52ce095 100644 --- a/arpes/analysis/derivative.py +++ b/arpes/analysis/derivative.py @@ -19,8 +19,8 @@ from arpes._typing import DataType __all__ = ( - "curvature", - "dn_along_axis", + "curvature2d", + "curvature1d", "d2_along_axis", "d1_along_axis", "minimum_gradient", @@ -136,6 +136,7 @@ def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray: return data_copy +@update_provenance("Maximum Curvature 1D") def curvature1d( arr: xr.DataArray, dim: str = "", @@ -184,6 +185,7 @@ def warpped_filter(arr: xr.DataArray): return filterd_arr +@update_provenance("Maximum Curvature 2D") def curvature2d( arr: xr.DataArray, directions: tuple[str, str] = ("phi", "eV"), diff --git a/arpes/analysis/mask.py b/arpes/analysis/mask.py index 9930d405..93d05274 100644 --- a/arpes/analysis/mask.py +++ b/arpes/analysis/mask.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: import xarray as xr from _typeshed import Incomplete + from numpy.typing import NDArray from arpes._typing import DataType diff --git a/arpes/analysis/pocket.py b/arpes/analysis/pocket.py index db4b1de0..380cea0f 100644 --- a/arpes/analysis/pocket.py +++ b/arpes/analysis/pocket.py @@ -308,7 +308,7 @@ def edcs_along_pocket( for kf, ss in zip(kfs, slices) ] - edcs = [data.S.select_around(l, radius=select_radius, fast=True) for l in locations] + edcs = [data.S.select_around(_, radius=select_radius, fast=True) for _ in locations] data_vars = {} index = np.array(angles) diff --git a/arpes/corrections/fermi_edge_corrections.py b/arpes/corrections/fermi_edge_corrections.py index 231eece8..a11d8092 100644 --- a/arpes/corrections/fermi_edge_corrections.py +++ b/arpes/corrections/fermi_edge_corrections.py @@ -18,8 +18,8 @@ def _exclude_from_set(excluded): - def exclude(l): - return list(set(l).difference(excluded)) + def exclude(_): + return list(set(_).difference(excluded)) return exclude diff --git a/arpes/endstations/__init__.py b/arpes/endstations/__init__.py index d4873ad1..06a07aa4 100644 --- a/arpes/endstations/__init__.py +++ b/arpes/endstations/__init__.py @@ -369,25 +369,25 @@ def postprocess_final( data.spectrum.attrs["spectrum_type"] = spectrum_type ls = [data, *data.S.spectra] - for l in ls: + for _ in ls: for k, key_fn in self.ATTR_TRANSFORMS.items(): - if k in l.attrs: - transformed = key_fn(l.attrs[k]) + if k in _.attrs: + transformed = key_fn(_.attrs[k]) if isinstance(transformed, dict): - l.attrs.update(transformed) + _.attrs.update(transformed) else: - l.attrs[k] = transformed + _.attrs[k] = transformed - for l in ls: + for _ in ls: for k, v in self.MERGE_ATTRS.items(): - if k not in l.attrs: - l.attrs[k] = v + if k not in _.attrs: + _.attrs[k] = v - for l in ls: + for _ in ls: for c in self.ENSURE_COORDS_EXIST: - if c not in l.coords: - if c in l.attrs: - l.coords[c] = l.attrs[c] + if c not in _.coords: + if c in _.attrs: + _.coords[c] = _.attrs[c] else: warnings_msg = f"Could not assign coordinate {c} from attributes," warnings_msg += "assigning np.nan instead." @@ -395,11 +395,11 @@ def postprocess_final( warnings_msg, stacklevel=2, ) - l.coords[c] = np.nan + _.coords[c] = np.nan - for l in ls: - if "chi" in l.coords and "chi_offset" not in l.attrs: - l.attrs["chi_offset"] = l.coords["chi"].item() + for _ in ls: + if "chi" in _.coords and "chi_offset" not in _.attrs: + _.attrs["chi_offset"] = _.coords["chi"].item() # go and change endianness and datatypes to something reasonable # this is done for performance reasons in momentum space conversion, primarily diff --git a/arpes/endstations/plugin/ALG_main.py b/arpes/endstations/plugin/ALG_main.py index 2f7176e2..57dc0805 100644 --- a/arpes/endstations/plugin/ALG_main.py +++ b/arpes/endstations/plugin/ALG_main.py @@ -28,7 +28,7 @@ class ALGMainChamber(HemisphericalEndstation, FITSEndstation): ] ATTR_TRANSFORMS: ClassVar[dict] = { - "START_T": lambda l: {"time": " ".join(l.split(" ")[1:]).lower(), "date": l.split(" ")[0]}, + "START_T": lambda _: {"time": " ".join(_.split(" ")[1:]).lower(), "date": _.split(" ")[0]}, } RENAME_KEYS: ClassVar[dict[str, str]] = { diff --git a/arpes/endstations/plugin/ALG_spin_ToF.py b/arpes/endstations/plugin/ALG_spin_ToF.py index 6d5c424a..d8d080f1 100644 --- a/arpes/endstations/plugin/ALG_spin_ToF.py +++ b/arpes/endstations/plugin/ALG_spin_ToF.py @@ -248,7 +248,7 @@ def load_SToF_fits(self, scan_desc: dict | None = None, **kwargs: Incomplete): hdulist.close() relevant_dimensions = { - k for k in coords if k in set(itertools.chain(*[l[0] for l in data_vars.values()])) + k for k in coords if k in set(itertools.chain(*[_[0] for _ in data_vars.values()])) } relevant_coords = {k: v for k, v in coords.items() if k in relevant_dimensions} diff --git a/arpes/endstations/plugin/ANTARES.py b/arpes/endstations/plugin/ANTARES.py index 3c2022f3..ac10bcfd 100644 --- a/arpes/endstations/plugin/ANTARES.py +++ b/arpes/endstations/plugin/ANTARES.py @@ -306,8 +306,8 @@ def check_attrs(s): pass ls = [data, *data.S.spectra] - for l in ls: - check_attrs(l) + for _ in ls: + check_attrs(_) # attempt to determine whether the energy is likely a kinetic energy # if so, we will subtract the photon energy diff --git a/arpes/endstations/plugin/BL10_SARPES.py b/arpes/endstations/plugin/BL10_SARPES.py index fa8808a3..49f02c90 100644 --- a/arpes/endstations/plugin/BL10_SARPES.py +++ b/arpes/endstations/plugin/BL10_SARPES.py @@ -155,9 +155,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: SCANDESC | None = None) data.coords[c] = np.deg2rad(data.coords[c]) for angle_attr in deg_to_rad_attrs: - for l in ls: - if angle_attr in l.attrs: - l.attrs[angle_attr] = np.deg2rad(float(l.attrs[angle_attr])) + for _ in ls: + if angle_attr in _.attrs: + _.attrs[angle_attr] = np.deg2rad(float(_.attrs[angle_attr])) data.attrs["alpha"] = np.pi / 2 data.attrs["psi"] = 0 @@ -169,9 +169,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: SCANDESC | None = None) # .spectrum.attrs, for now just paste them over necessary_coord_names = {"theta", "beta", "chi", "phi"} ls = data.S.spectra - for l in ls: + for _ in ls: for cname in necessary_coord_names: - if cname not in l.attrs and cname not in l.coords and cname in data.attrs: - l.attrs[cname] = data.attrs[cname] + if cname not in _.attrs and cname not in _.coords and cname in data.attrs: + _.attrs[cname] = data.attrs[cname] return super().postprocess_final(data, scan_desc) diff --git a/arpes/endstations/plugin/HERS.py b/arpes/endstations/plugin/HERS.py index ead70dc9..3b3ea084 100644 --- a/arpes/endstations/plugin/HERS.py +++ b/arpes/endstations/plugin/HERS.py @@ -81,7 +81,7 @@ def load(self, scan_desc: dict | None = None, **kwargs: Incomplete): hdulist.close() relevant_dimensions = { - k for k in coords if k in set(itertools.chain(*[l[0] for l in data_vars.values()])) + k for k in coords if k in set(itertools.chain(*[_[0] for _ in data_vars.values()])) } relevant_coords = {k: v for k, v in coords.items() if k in relevant_dimensions} diff --git a/arpes/endstations/plugin/MAESTRO.py b/arpes/endstations/plugin/MAESTRO.py index 9e47bd3e..a721c39e 100644 --- a/arpes/endstations/plugin/MAESTRO.py +++ b/arpes/endstations/plugin/MAESTRO.py @@ -63,13 +63,13 @@ def fix_prebinned_coordinates(self): def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None): ls = [data, *data.S.spectra] - for l in ls: - l.attrs.update(self.ANALYZER_INFORMATION) + for _ in ls: + _.attrs.update(self.ANALYZER_INFORMATION) - if "GRATING" in l.attrs: - l.attrs["grating_lines_per_mm"] = { + if "GRATING" in _.attrs: + _.attrs["grating_lines_per_mm"] = { "G201b": 600, - }.get(l.attrs["GRATING"]) + }.get(_.attrs["GRATING"]) return super().postprocess_final(data, scan_desc) @@ -126,14 +126,14 @@ class MAESTROMicroARPESEndstation(MAESTROARPESEndstationBase): } ATTR_TRANSFORMS: ClassVar[dict] = { - "START_T": lambda l: { - "time": " ".join(l.split(" ")[1:]).lower(), - "date": l.split(" ")[0], + "START_T": lambda _: { + "time": " ".join(_.split(" ")[1:]).lower(), + "date": _.split(" ")[0], }, - "SF_SLITN": lambda l: { - "slit_number": int(l.split(" ")[0]), - "slit_shape": l.split(" ")[-1].lower(), - "slit_width": float(l.split(" ")[2]), + "SF_SLITN": lambda _: { + "slit_number": int(_.split(" ")[0]), + "slit_shape": _.split(" ")[-1].lower(), + "slit_width": float(_.split(" ")[2]), }, } @@ -243,14 +243,14 @@ class MAESTRONanoARPESEndstation(MAESTROARPESEndstationBase): } ATTR_TRANSFORMS: ClassVar[dict] = { - "START_T": lambda l: { - "time": " ".join(l.split(" ")[1:]).lower(), - "date": l.split(" ")[0], + "START_T": lambda _: { + "time": " ".join(_.split(" ")[1:]).lower(), + "date": _.split(" ")[0], }, - "SF_SLITN": lambda l: { - "slit_number": int(l.split(" ")[0]), - "slit_shape": l.split(" ")[-1].lower(), - "slit_width": float(l.split(" ")[2]), + "SF_SLITN": lambda _: { + "slit_number": int(_.split(" ")[0]), + "slit_shape": _.split(" ")[-1].lower(), + "slit_width": float(_.split(" ")[2]), }, } diff --git a/arpes/endstations/plugin/kaindl.py b/arpes/endstations/plugin/kaindl.py index e63cec97..6dbd8dce 100644 --- a/arpes/endstations/plugin/kaindl.py +++ b/arpes/endstations/plugin/kaindl.py @@ -156,7 +156,7 @@ def concatenate_frames( axis_name = lines[0].strip() axis_name = self.RENAME_KEYS.get(axis_name, axis_name) - values = [float(l.strip()) for l in lines[1 : len(frames) + 1]] + values = [float(_.strip()) for _ in lines[1 : len(frames) + 1]] for v, f in zip(values, frames): f.coords[axis_name] = v @@ -239,9 +239,9 @@ def attach_attr(data, attr_name, as_name): data.attrs[angle_attr] = np.deg2rad(float(data.attrs[angle_attr])) ls = [data, *data.S.spectra] - for l in ls: - l.coords["x"] = np.nan - l.coords["y"] = np.nan - l.coords["z"] = np.nan + for _ in ls: + _.coords["x"] = np.nan + _.coords["y"] = np.nan + _.coords["z"] = np.nan return super().postprocess_final(data, scan_desc) diff --git a/arpes/endstations/plugin/merlin.py b/arpes/endstations/plugin/merlin.py index 8de14b9c..7639530b 100644 --- a/arpes/endstations/plugin/merlin.py +++ b/arpes/endstations/plugin/merlin.py @@ -89,15 +89,15 @@ class BL403ARPESEndstation(SynchrotronEndstation, HemisphericalEndstation, SESEn } ATTR_TRANSFORMS: ClassVar[dict[str, Any]] = { - "acquisition_mode": lambda l: l.lower(), - "lens_mode": lambda l: { + "acquisition_mode": lambda _: _.lower(), + "lens_mode": lambda _: { "lens_mode": None, - "lens_mode_name": l, + "lens_mode_name": _, }, "undulator_polarization": int, - "region_name": lambda l: { - "daq_region_name": l, - "daq_region": l, + "region_name": lambda _: { + "daq_region_name": _, + "daq_region": _, }, } @@ -125,12 +125,12 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N Path(original_filename).parent / f"{internal_match.groups()[0]}_Motor_Pos.txt", ) try: - with open(motors_path) as f: + with Path(motors_path).open() as f: lines = f.readlines() axis_name = lines[0].strip() axis_name = self.RENAME_KEYS.get(axis_name, axis_name) - values = [float(l.strip()) for l in lines[1 : len(frames) + 1]] + values = [float(_.strip()) for _ in lines[1 : len(frames) + 1]] for v, f in zip(values, frames): f.coords[axis_name] = v @@ -139,10 +139,10 @@ def concatenate_frames(self, frames=list[xr.Dataset], scan_desc: dict | None = N for frame in frames: # promote x, y, z to coords so they get concatted - for l in [frame, *frame.S.spectra]: + for _ in [frame, *frame.S.spectra]: for c in ["x", "y", "z"]: - if c not in l.coords: - l.coords[c] = l.attrs[c] + if c not in _.coords: + _.coords[c] = _.attrs[c] return xr.concat(frames, axis_name, coords="different") except Exception: @@ -251,23 +251,23 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None): """ ls = [data, *data.S.spectra] - for l in ls: - if "slit_number" in l.attrs: + for dat in ls: + if "slit_number" in dat.attrs: slit_lookup = { 1: ("straight", 0.1), 7: ("curved", 0.5), } - shape, width = slit_lookup.get(l.attrs["slit_number"], (None, None)) - l.attrs["slit_shape"] = shape - l.attrs["slit_width"] = width + shape, width = slit_lookup.get(dat.attrs["slit_number"], (None, None)) + dat.attrs["slit_shape"] = shape + dat.attrs["slit_width"] = width - if "undulator_polarization" in l.attrs: + if "undulator_polarization" in dat.attrs: phase_angle_lookup = {0: (0, 0), 2: (np.pi / 2, 0)} # LH # LV polarization_theta, polarization_alpha = phase_angle_lookup[ - int(l.attrs["undulator_polarization"]) + int(dat.attrs["undulator_polarization"]) ] - l.attrs["probe_polarization_theta"] = polarization_theta - l.attrs["probe_polarization_alpha"] = polarization_alpha + dat.attrs["probe_polarization_theta"] = polarization_theta + dat.attrs["probe_polarization_alpha"] = polarization_alpha deg_to_rad_coords = {"theta", "phi", "beta", "chi", "psi"} deg_to_rad_attrs = {"theta", "beta", "chi", "psi", "alpha"} @@ -277,9 +277,9 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None): data.coords[c] = data.coords[c] * np.pi / 180 for angle_attr in deg_to_rad_attrs: - for l in ls: - if angle_attr in l.attrs: - l.attrs[angle_attr] = float(l.attrs[angle_attr]) * np.pi / 180 + for dat in ls: + if angle_attr in dat.attrs: + dat.attrs[angle_attr] = np.deg2rad(float(dat.attrs[angle_attr])) data.attrs["alpha"] = np.pi / 2 data.attrs["psi"] = 0 @@ -287,13 +287,17 @@ def postprocess_final(self, data: xr.Dataset, scan_desc: dict | None = None): s.attrs["alpha"] = np.pi / 2 s.attrs["psi"] = 0 - # TODO: Conrad think more about why sometimes individual attrs don't make it onto - # .spectrum.attrs, for now just paste them over + # TODO: Conrad think more about why sometimes individual attrs don't + # make it onto .spectrum.attrs, for now just paste them over necessary_coord_names = {"theta", "beta", "chi", "phi"} ls = data.S.spectra - for l in ls: + for spectrum in ls: for cname in necessary_coord_names: - if cname not in l.attrs and cname not in l.coords and cname in data.attrs: - l.attrs[cname] = data.attrs[cname] + if ( + cname not in spectrum.attrs + and cname not in spectrum.coords + and cname in data.attrs + ): + spectrum.attrs[cname] = data.attrs[cname] return super().postprocess_final(data, scan_desc) diff --git a/arpes/experiment/__init__.py b/arpes/experiment/__init__.py index c7a2bdf2..5aae1236 100644 --- a/arpes/experiment/__init__.py +++ b/arpes/experiment/__init__.py @@ -26,7 +26,7 @@ def flatten(lists): - return chain.from_iterable([l if np.iterable(l) else [l] for l in lists]) + return chain.from_iterable([_ if np.iterable(_) else [_] for _ in lists]) class ExperimentTreeItem: diff --git a/arpes/fits/lmfit_html_repr.py b/arpes/fits/lmfit_html_repr.py index a0a42061..dd419f1e 100644 --- a/arpes/fits/lmfit_html_repr.py +++ b/arpes/fits/lmfit_html_repr.py @@ -30,12 +30,12 @@ def repr_multiline_ModelResult(self: model.Model, **kwargs: Incomplete) -> str: [(" " * 4) + c._repr_multiline_text_() for c in self.components], ), parameters="\n".join( - f" {l}" for l in self.params._repr_multiline_text_(**kwargs).split("\n") + f" {l_item}" for l_item in self.params._repr_multiline_text_(**kwargs).split("\n") ), ) -def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str: +def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str: # noqa: N802 """Provides a better Jupyter representation of an `lmfit.ModelResult` instance.""" template = """
@@ -51,7 +51,7 @@ def repr_html_ModelResult(self: Incomplete, **kwargs: Incomplete) -> str: ) -def repr_html_Model(self: Incomplete) -> str: +def repr_html_Model(self: Incomplete) -> str: # noqa: N802 """Better Jupyter representation of `lmfit.Model` instances.""" template = """
@@ -61,7 +61,7 @@ def repr_html_Model(self: Incomplete) -> str: return template.format(name=self.name) -def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str: +def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str: # noqa: N802 """Provides a text-based multiline representation used in Qt based interactive tools.""" return self.name @@ -70,7 +70,7 @@ def repr_multiline_Model(self: Incomplete, **kwargs: Incomplete) -> str: SKIP_ON_SHORT = {"min", "max", "vary", "expr", "brute_step"} -def repr_html_Parameters(self: Incomplete, *, short: bool = False) -> str: +def repr_html_Parameters(self: Incomplete, *, short: bool = False) -> str: # noqa: N802 """HTML representation for `lmfit.Parameters` instances.""" keys = sorted(self.keys()) template = """ diff --git a/arpes/fits/utilities.py b/arpes/fits/utilities.py index 06c69d86..275c91f6 100644 --- a/arpes/fits/utilities.py +++ b/arpes/fits/utilities.py @@ -18,7 +18,6 @@ import dill import numpy as np import xarray as xr -from packaging import version from tqdm.notebook import tqdm import arpes.fits.fit_models @@ -40,34 +39,24 @@ TypeIterable = list[type] | tuple[type] -XARRAY_REQUIRES_VALUES_WRAPPING = version.parse(xr.__version__) > version.parse("0.10.0") - - -def wrap_for_xarray_values_unpacking(item): - """This is a shim for https://github.com/pydata/xarray/issues/2097.""" - if XARRAY_REQUIRES_VALUES_WRAPPING: - return np.array(item, dtype=object) - - return item - def result_to_hints( - m: lmfit.model.ModelResult | None, + model_result: lmfit.model.ModelResult | None, defaults=None, ) -> dict[str, dict[str, Any]] | None: """Turns an `lmfit.model.ModelResult` into a dictionary with initial guesses. Args: - m: The model result to extract parameters from - defaults: Returned if `m` is None, useful for cell re-evaluation in Jupyter + model_result: The model result to extract parameters from + defaults: Returned if `model_result` is None, useful for cell re-evaluation in Jupyter Returns: A dict containing parameter specifications in key-value rathan than `lmfit.Parameter` format, as you might pass as `params=` to PyARPES fitting code. """ - if m is None: + if model_result is None: return defaults - return {k: {"value": m.params[k].value} for k in m.params} + return {k: {"value": model_result.params[k].value} for k in model_result.params} def parse_model(model): @@ -98,7 +87,7 @@ def parse_model(model): special = set(pad_all) - def read_token(token): + def read_token(token: str) -> str | float: if token in special: return token try: @@ -127,7 +116,7 @@ def broadcast_model( *, progress: bool = True, safe: bool = False, - trace: Callable = None, # type: ignore # noqa: RUF013 + trace: Callable = None, # noqa: RUF013 ) -> xr.Dataset: """Perform a fit across a number of dimensions. @@ -174,7 +163,7 @@ def broadcast_model( n_fits = np.prod(np.array(list(template.S.dshape.values()))) if parallelize is None: - parallelize = bool(n_fits > 20) + parallelize = bool(n_fits > 20) # noqa: PLR2004 trace("Copying residual") residual = data_array.copy(deep=True) @@ -189,7 +178,16 @@ def broadcast_model( wrap_progress = tqdm else: - def wrap_progress(x: Iterable[int], *_, **__) -> Iterable[int]: + def wrap_progress(x: Iterable[int], **__: str | float) -> Iterable[int]: + """Fake of tqdm.notebook.tqdm. + + Args: + x (Iterable[int]): [TODO:description] + __: its a dummy parameter, which is not used. + + Returns: + Same iterable. + """ return x serialize = parallelize @@ -231,7 +229,7 @@ def wrap_progress(x: Iterable[int], *_, **__) -> Iterable[int]: if serialize: trace("Deserializing...") - def unwrap(result_data): + def unwrap(result_data) -> object: # using the lmfit deserialization and serialization seems slower than double pickling # with dill return dill.loads(result_data) @@ -240,7 +238,7 @@ def unwrap(result_data): trace("Finished running fits Collating") for fit_result, fit_residual, coords in exe_results: - template.loc[coords] = wrap_for_xarray_values_unpacking(fit_result) + template.loc[coords] = np.array(fit_result) residual.loc[coords] = fit_residual trace("Bundling into dataset") diff --git a/arpes/fits/zones.py b/arpes/fits/zones.py index 172ead27..be720e1e 100644 --- a/arpes/fits/zones.py +++ b/arpes/fits/zones.py @@ -29,15 +29,24 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from arpes._typing import DataType + def k_points_residual( - paramters, - coords_dataset, - high_symmetry_points, + coords_dataset: DataType, dimensionality: int = 2, ) -> NDArray[np.float_]: + """[TODO:summary]. + + Args: + coords_dataset: [TODO:description] + dimensionality: [TODO:description] + + Returns: + [TODO:description] + """ momentum_coordinates = convert_coordinates(coords_dataset) - if dimensionality == 2: + if dimensionality == 2: # noqa: PLR2004 return np.asarray( [ np.diagonal(momentum_coordinates.kx.values), diff --git a/arpes/laue/__init__.py b/arpes/laue/__init__.py index 38415d74..7c60090d 100644 --- a/arpes/laue/__init__.py +++ b/arpes/laue/__init__.py @@ -19,7 +19,7 @@ from pathlib import Path import numpy as np -import xarray +import xarray as xr from arpes.provenance import provenance_from_file @@ -34,7 +34,10 @@ ) -def load_laue(path: Path | str): +__all__ = ("load_laue",) + + +def load_laue(path: Path | str) -> xr.DataArray: """Loads NorthStart Laue backscattering data.""" if isinstance(path, str): path = Path(path) @@ -45,7 +48,7 @@ def load_laue(path: Path | str): table = np.fromstring(table, dtype=np.uint16).reshape(256, 256) header = np.fromstring(header, dtype=northstar_62_69_dtype).item() - arr = xarray.DataArray( + arr = xr.DataArray( table, coords={"x": np.array(range(256)), "y": np.array(range(256))}, dims=[ diff --git a/arpes/plotting/__init__.py b/arpes/plotting/__init__.py index 1076cd5e..c65df890 100644 --- a/arpes/plotting/__init__.py +++ b/arpes/plotting/__init__.py @@ -1,2 +1,40 @@ """Standard plotting routines and utility code for ARPES analyses.""" -from __future__ import annotations +from __future__ import annotations # noqa: I001 + + +""" +from .annotations import annotate_cuts, annotate_experimental_conditions, annotate_point +from .band_tool import BandTool +from .bands import plot_with_bands +from .basic import make_reference_plots +from .curvature_tool import CurvatureTool +from .comparison_tool import compare +from .dispersion import ( + cut_dispersion_plot, + fancy_dispersion, + hv_reference_scan, + labeled_fermi_surface, + plot_dispersion, + reference_scan_fermi_surface, + scan_var_reference_plot, +) +from .dos import plot_core_levels, plot_dos +from .dyn_tool import DynamicTool, dyn +from .fermi_edge import fermi_edge, plot_fit +from .fermi_surface import fermi_surface_slices, magnify_circular_regions_plot +from .fit_inspection_tool import FitCheckTool +from .interactive import ImageTool +from .mask_tool import mask +from .movie import plot_movie +from .parameter import plot_parameter +from .path_tool import path_tool +from .spatial import plot_spatial_reference, reference_scan_spatial +from .spin import spin_polarized_spectrum, spin_colored_spectrum, spin_difference_spectrum +from .qt_tool import qt_tool +from .qt_ktool import ktool +from .utils import ( + savefig, + remove_colorbars, + fancy_labels, +) + """ diff --git a/arpes/plotting/bands.py b/arpes/plotting/bands.py index 78f5a07f..8a7d8900 100644 --- a/arpes/plotting/bands.py +++ b/arpes/plotting/bands.py @@ -13,11 +13,10 @@ from pathlib import Path from _typeshed import Incomplete + from build.lib.arpes.typing import DataType from matplotlib.axes import Axes from matplotlib.colors import Normalize - from build.lib.arpes.typing import DataType - __all__ = ("plot_with_bands",) diff --git a/arpes/plotting/bz_tool/__init__.py b/arpes/plotting/bz_tool/__init__.py index d478ba94..72d2534b 100644 --- a/arpes/plotting/bz_tool/__init__.py +++ b/arpes/plotting/bz_tool/__init__.py @@ -18,6 +18,7 @@ from arpes.utilities.ui import combo_box, horizontal, tabs from .CoordinateOffsetWidget import CoordinateOffsetWidget +from .RangeOrSingleValueWidget import RangeOrSingleValueWidget __all__ = [ "bz_tool", diff --git a/arpes/plotting/curvature_tool.py b/arpes/plotting/curvature_tool.py index 59cc010c..339c8195 100644 --- a/arpes/plotting/curvature_tool.py +++ b/arpes/plotting/curvature_tool.py @@ -39,7 +39,6 @@ def tool_handler(self, doc): default_palette = self.default_palette - x_coords, y_coords = self.arr.coords[self.arr.dims[1]], self.arr.coords[self.arr.dims[0]] self.app_context.update( { "data": self.arr, @@ -112,7 +111,6 @@ def tool_handler(self, doc): figures["curvature"].yaxis.major_label_text_font_size = "0pt" - # TODO: add support for color mapper plots["d2"] = figures["d2"].image( [self.arr.values], x=data_range["x"][0], diff --git a/arpes/plotting/fit_tool/__init__.py b/arpes/plotting/fit_tool/__init__.py index a4edfba8..d4feba13 100644 --- a/arpes/plotting/fit_tool/__init__.py +++ b/arpes/plotting/fit_tool/__init__.py @@ -273,7 +273,7 @@ def generate_fit_marginal_for( if layout is None: layout = self._layout - remaining_dims = [l for l in list(range(len(self.data.dims))) if l not in dimensions] + remaining_dims = [_ for _ in list(range(len(self.data.dims))) if _ not in dimensions] # for now, we only allow a single fit dimension widget = FitInspectionPlot(name=name, root=weakref.ref(self), orientation=orientation) @@ -390,9 +390,9 @@ def safe_slice(vlow, vhigh, axis=0): for_plot = for_plot.mean(list(select_coord.keys())) cursors = [ - l - for l in reactive.view.getPlotItem().items - if isinstance(l, CursorRegion) + _ + for _ in reactive.view.getPlotItem().items + if isinstance(_, CursorRegion) ] reactive.view.clear() for c in cursors: diff --git a/arpes/plotting/qt_ktool/__init__.py b/arpes/plotting/qt_ktool/__init__.py index c43451d0..0eea42e1 100644 --- a/arpes/plotting/qt_ktool/__init__.py +++ b/arpes/plotting/qt_ktool/__init__.py @@ -40,7 +40,7 @@ class KTool(SimpleApp): DEFAULT_COLORMAP = "viridis" - def __init__(self, apply_offsets=True, zone=None, **kwargs: Incomplete) -> None: + def __init__(self, apply_offsets: bool = True, zone=None, **kwargs: Incomplete) -> None: """Set attributes to safe defaults and unwrap the Brillouin zone definition.""" super().__init__() diff --git a/arpes/plotting/qt_tool/__init__.py b/arpes/plotting/qt_tool/__init__.py index 1a64f4c9..0ec61173 100644 --- a/arpes/plotting/qt_tool/__init__.py +++ b/arpes/plotting/qt_tool/__init__.py @@ -398,9 +398,9 @@ def safe_slice(vlow, vhigh, axis=0): for_plot = for_plot.mean(list(select_coord.keys())) cursors = [ - l - for l in reactive.view.getPlotItem().items - if isinstance(l, CursorRegion) + _ + for _ in reactive.view.getPlotItem().items + if isinstance(_, CursorRegion) ] reactive.view.clear() for c in cursors: diff --git a/arpes/utilities/qt/app.py b/arpes/utilities/qt/app.py index b451971f..a591f705 100644 --- a/arpes/utilities/qt/app.py +++ b/arpes/utilities/qt/app.py @@ -146,7 +146,7 @@ def generate_marginal_for( if layout is None: layout = self._layout - remaining_dims = [l for l in list(range(len(self.data.dims))) if l not in dimensions] + remaining_dims = [dim for dim in list(range(len(self.data.dims))) if dim not in dimensions] if len(remaining_dims) == 1: widget = DataArrayPlot(name=name, root=weakref.ref(self), orientation=orientation) diff --git a/arpes/workflow.py b/arpes/workflow.py index 5f3ebceb..10d8aa9c 100644 --- a/arpes/workflow.py +++ b/arpes/workflow.py @@ -127,25 +127,25 @@ def _read_pickled(self, name: str, default=None): except FileNotFoundError: return default - def _write_pickled(self, name: str, value): - with open(str(self.path / f"{name}.pickle"), "wb") as f: + def _write_pickled(self, name: str, value) -> None: + with Path(self.path / f"{name}.pickle").open("wb") as f: dill.dump(value, f) @property - def publishers(self): + def publishers(self) -> object: return self._read_pickled("publishers", defaultdict(list)) @publishers.setter - def publishers(self, new_publishers): + def publishers(self, new_publishers) -> None: assert isinstance(new_publishers, dict) self._write_pickled("publishers", new_publishers) @property - def consumers(self): + def consumers(self) -> object: return self._read_pickled("consumers", defaultdict(list)) @consumers.setter - def consumers(self, new_consumers): + def consumers(self, new_consumers) -> None: assert isinstance(new_consumers, dict) self._write_pickled("consumers", new_consumers) @@ -181,7 +181,7 @@ def publish(self, key, data) -> None: self.summarize_consumers(key=key) - def consume(self, key: Hashable, *, subscribe: bool = True): + def consume(self, key: str, *, subscribe: bool = True): if subscribe: context = get_running_context() consumers = self.consumers diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index cb07cd4f..fd9171d7 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -77,7 +77,7 @@ from arpes.utilities.xarray import unwrap_xarray_dict, unwrap_xarray_item if TYPE_CHECKING: - from collections.abc import Callable, Generator, Hashable, Iterator + from collections.abc import Callable, Generator, Hashable, Iterator, Sequence from pathlib import Path import pandas as pd @@ -378,7 +378,7 @@ def select_around_data( stored in the dataarray kFs. Then we could select momentum integrated EDCs in a small window around the fermi momentum for each temperature by using - >>> edcs_at_fermi_momentum = full_data.S.select_around_data({'kp': kFs}, radius={'kp': 0.04}, fast=True) # doctest: +SKIP + >>> edcs = full_data.S.select_around_data({'kp': kFs}, radius={'kp': 0.04}, fast=True) The resulting data will be EDCs for each T, in a region of radius 0.04 inverse angstroms around the Fermi momentum. @@ -523,14 +523,11 @@ def select_around( Returns: The binned selection around the desired point or points. """ - if isinstance(self._obj, xr.Dataset): - msg = "Cannot use select_around on Datasets only DataArrays!" - raise TypeError(msg) - - if mode not in {"sum", "mean"}: - msg = "mode parameter should be either sum or mean." - raise ValueError(msg) - + assert isinstance( + self._obj, + xr.Dataset, + ), "Cannot use select_around on Datasets only DataArrays!" + assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean." if isinstance(point, tuple | list): warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2) point = dict(zip(point, self._obj.dims, strict=True)) @@ -620,8 +617,6 @@ def _calculate_symmetry_points( # if it is projected, we need to calculate its projected coordinates """[TODO:summary]. - [TODO:description] - Args: symmetry_points: [TODO:description] projection_distance: [TODO:description] @@ -808,23 +803,29 @@ def original_parent_scan_name(self) -> str: first_modification = history[-3] df: pd.DataFrame = self._obj.attrs["df"] # "df" means DataFrame of pandas return df[df.id == first_modification["parent_id"]].index[0] - except: + except KeyError: pass return "" @property - def scan_row(self): - df: pd.DataFrame = self._obj.attrs["df"] - sdf = df[df.patl == self._obj.attrs["file"]] - return next(iter(sdf.iterrows())) + def scan_row(self) -> Sequence[int]: + try: + df: pd.DataFrame = self._obj.attrs["df"] + sdf = df[df.patl == self._obj.attrs["file"]] + return next(iter(sdf.iterrows())) + except KeyError: + return [] @property - def df_index(self): + def df_index(self) -> int | None: return self.scan_row[0] @property def df_after(self): - return self._obj.attrs["df"][self._obj.attrs["df"].index > self.df_index] + try: + return self._obj.attrs["df"][self._obj.attrs["df"].index > self.df_index] + except KeyError: + return None def df_until_type( self, @@ -1866,7 +1867,7 @@ def _repr_html_(self) -> str: try: name = self.df_index - except: + except IndexError: if "id" in self._obj.attrs: name = "ID: " + str(self._obj.attrs["id"])[:9] + "..." else: @@ -1994,7 +1995,7 @@ def show(self, *, detached: bool = False, **kwargs: Incomplete) -> None: def show_d2(self, **kwargs: Incomplete) -> None: """Opens the Bokeh based second derivative image tool.""" - from arpes.plotting.all import CurvatureTool + from arpes.plotting.curvature_tool import CurvatureTool curve_tool = CurvatureTool(**kwargs) return curve_tool.make_tool(self._obj) diff --git a/tests/conftest.py b/tests/conftest.py index 3dec17f0..9b4b505b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ """Mocks the analysis environment and provides data fixutres for tests.""" from __future__ import annotations -import os from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, TypedDict @@ -79,7 +78,7 @@ def load(path: str) -> xr.DataArray | xr.Dataset: pieces = path.split("/") set_workspace(pieces[0]) return cache_loader.load_test_scan( - os.path.join(*pieces), + str(Path(path)), location=SCAN_FIXTURE_LOCATIONS[path], ) @@ -91,4 +90,4 @@ def load(path: str) -> xr.DataArray | xr.Dataset: arpes.config.load_plugins() yield sandbox arpes.config.CONFIG["WORKSPACE"] = None - arpes.endstations._ENDSTATION_ALIASES = {} + arpes.endstations._ENDSTATION_ALIASES = {} # noqa: SLF001 diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index 44935ace..e7a68b84 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -11,10 +11,12 @@ from arpes.utilities.conversion import convert_to_kspace if TYPE_CHECKING: + from collections.abc import Iterable + from _typeshed import Incomplete -def pytest_generate_tests(metafunc: Incomplete): +def pytest_generate_tests(metafunc: Incomplete) -> Incomplete: """[TODO:summary]. [TODO:description] @@ -641,7 +643,7 @@ def test_load_file_and_basic_attributes( assert isinstance(data, xr.Dataset) # assert basic dataset attributes - for attr in ["location"]: # TODO: add spectrum type requirement + for attr in ["location"]: assert attr in data.attrs # assert that all necessary coordinates are present @@ -679,7 +681,7 @@ def test_load_file_and_basic_attributes( assert k assert pytest.approx(data.coords[k].item(), 1e-3) == v - def safefirst(x): + def safefirst(x: Iterable[float]) -> float: with contextlib.suppress(TypeError, IndexError): return x[0] diff --git a/tests/test_direct_and_example_data_loading.py b/tests/test_direct_and_example_data_loading.py index e272f2bb..46b56cb9 100644 --- a/tests/test_direct_and_example_data_loading.py +++ b/tests/test_direct_and_example_data_loading.py @@ -11,10 +11,14 @@ from arpes.io import load_data, load_example_data if TYPE_CHECKING: - from _typeshed import Incomplete + from collections.abc import Generator + from .conftest import Sandbox -def test_load_data(sandbox_configuration: Incomplete) -> None: + +def test_load_data( + sandbox_configuration: Generator[Sandbox, None, None], # noqa: ARG001 +) -> None: """[TODO:summary]. [TODO:description] @@ -32,7 +36,9 @@ def test_load_data(sandbox_configuration: Incomplete) -> None: assert data.spectrum.shape == (240, 240) -def test_load_data_with_plugin_specified(sandbox_configuration: Incomplete) -> None: +def test_load_data_with_plugin_specified( + sandbox_configuration: Generator[Sandbox, None, None], # noqa: ARG001 +) -> None: """[TODO:summary]. [TODO:description] @@ -52,7 +58,9 @@ def test_load_data_with_plugin_specified(sandbox_configuration: Incomplete) -> N assert np.all(data.spectrum.values == directly_specified_data.spectrum.values) -def test_load_example_data(sandbox_configuration: Incomplete) -> None: +def test_load_example_data( + sandbox_configuration: Generator[Sandbox, None, None], # noqa: ARG001 +) -> None: """[TODO:summary]. [TODO:description] diff --git a/tests/test_time_configuration.py b/tests/test_time_configuration.py index faf6d018..590f47b2 100644 --- a/tests/test_time_configuration.py +++ b/tests/test_time_configuration.py @@ -1,10 +1,20 @@ """test for time configuration.""" +from __future__ import annotations + import os.path +from typing import TYPE_CHECKING import arpes.config +if TYPE_CHECKING: + from collections.abc import Generator + + from .conftest import Sandbox + -def test_patched_config(sandbox_configuration) -> None: +def test_patched_config( + sandbox_configuration: Generator[Sandbox, None, None], +) -> None: """[TODO:summary]. [TODO:description] @@ -22,7 +32,9 @@ def test_patched_config(sandbox_configuration) -> None: assert str(arpes.config.CONFIG["WORKSPACE"]["path"]).split(os.sep)[-2:] == ["datasets", "basic"] -def test_patched_config_no_workspace(sandbox_configuration) -> None: +def test_patched_config_no_workspace( + sandbox_configuration: Generator[Sandbox, None, None], # noqa: ARG001 +) -> None: """[TODO:summary]. [TODO:description] diff --git a/tests/utils.py b/tests/utils.py index 70c0743d..dfbf0b82 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,11 +26,9 @@ def path_to_datasets() -> Path: class CachingDataLoader: cache: dict[str, xr.Dataset] = field(default_factory=dict) - def load_test_scan(self, example_name: str | Path, **kwargs: Incomplete) -> xr.Dataset: + def load_test_scan(self, example_name: str, **kwargs: Incomplete) -> xr.Dataset: """[TODO:summary]. - [TODO:description] - Args: example_name ([TODO:type]): [TODO:description] kwargs: Pass to load_data function @@ -49,7 +47,7 @@ def load_test_scan(self, example_name: str | Path, **kwargs: Incomplete) -> xr.D raise ValueError(msg) data = load_data(str(path_to_data.absolute()), **kwargs) - self.cache[example_name] = data + self.cache[str(example_name)] = data return data.copy(deep=True)