From 3175015b487404c4fc4e9b98cdc7e82d4750e21e Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 13 Feb 2024 16:30:03 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/analysis/align.py | 11 +- src/arpes/analysis/deconvolution.py | 193 +------------------ src/arpes/analysis/gap.py | 11 +- src/arpes/analysis/path.py | 3 +- src/arpes/analysis/pocket.py | 39 ++-- src/arpes/analysis/sarpes.py | 9 +- src/arpes/analysis/self_energy.py | 6 +- src/arpes/analysis/shirley.py | 23 ++- src/arpes/endstations/plugin/ALG_spin_ToF.py | 2 - src/arpes/models/band.py | 2 +- src/arpes/preparation/axis_preparation.py | 6 +- src/arpes/utilities/jupyter.py | 13 +- src/arpes/xarray_extensions.py | 19 +- 13 files changed, 81 insertions(+), 256 deletions(-) diff --git a/src/arpes/analysis/align.py b/src/arpes/analysis/align.py index 7d385456..7b9cf0cb 100644 --- a/src/arpes/analysis/align.py +++ b/src/arpes/analysis/align.py @@ -7,6 +7,7 @@ if we need to. I doubt that this is necessary and don't mind the copied code too much at the present. """ + from __future__ import annotations import numpy as np @@ -60,12 +61,9 @@ def align2d(a: xr.DataArray, b: xr.DataArray, *, subpixel: bool = True) -> tuple y, x = true_y, true_x - y = 1.0 * y - a.values.shape[0] / 2.0 - x = 1.0 * x - a.values.shape[1] / 2.0 - return ( - y * a.G.stride(generic_dim_names=False)[a.dims[0]], - x * a.G.stride(generic_dim_names=False)[a.dims[1]], + (float(y) - a.values.shape[0] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[0]], + (float(x) - a.values.shape[1] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[1]], ) @@ -93,8 +91,7 @@ def align1d(a: xr.DataArray, b: xr.DataArray, *, subpixel: bool = True) -> float mod = QuadraticModel().guess_fit(marg) x = x + -mod.params["b"].value / (2 * mod.params["a"].value) - x = 1.0 * x - a.values.shape[0] / 2.0 - return x * a.G.stride(generic_dim_names=False)[a.dims[0]] + return (float(x) - a.values.shape[0] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[0]] def align(a: xr.DataArray, b: xr.DataArray, **kwargs: bool) -> tuple[float, float] | float: diff --git a/src/arpes/analysis/deconvolution.py b/src/arpes/analysis/deconvolution.py index 95a4f754..97db4d04 100644 --- a/src/arpes/analysis/deconvolution.py +++ b/src/arpes/analysis/deconvolution.py @@ -2,25 +2,20 @@ from __future__ import annotations -import contextlib -import warnings -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import numpy as np import scipy import scipy.ndimage import xarray as xr -from tqdm.notebook import tqdm +from skimage.restoration import richardson_lucy -from arpes.constants import TWO_DIMENSION +import arpes.xarray_extensions # noqa: F401 from arpes.fits.fit_models.functional_forms import gaussian from arpes.provenance import update_provenance from arpes.utilities import normalize_to_spectrum if TYPE_CHECKING: - from collections.abc import Iterable - - from _typeshed import Incomplete from numpy.typing import NDArray @@ -83,170 +78,24 @@ def deconvolve_ice( @update_provenance("Lucy Richardson Deconvolution") def deconvolve_rl( data: xr.DataArray, - psf: xr.DataArray | None = None, + psf: xr.DataArray, n_iterations: int = 10, - axis: str = "", - sigma: float = 0, - mode: Literal["reflect", "constant", "nearest", "mirror", "wrap"] = "reflect", - *, - progress: bool = True, ) -> xr.DataArray: """Deconvolves data by a given point spread function using the 
     Richardson-Lucy (RL) method.
 
     Args:
         data: input data
-        axis
-        mode: pass to ndimage.convolve
-        sigma
-        progress
-        psf: for 1d, if not specified, must specify axis and sigma
+        psf: The point spread function.
         n_iterations: the number of convolutions to use for the fit
 
     Returns:
         The Richardson-Lucy deconvolved data.
     """
     arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
-
-    if psf is None and axis != "" and sigma != 0:
-        # if no psf is provided and we have the information to make a 1d one
-        # note: this assumes gaussian psf
-        psf = make_psf1d(data=arr, dim=axis, sigma=sigma)
-
-    if len(data.dims) > 1:
-        if not axis:
-            # perform one-dimensional deconvolution of multidimensional data
-
-            # support for progress bars
-            def wrap_progress(
-                x: Iterable[int],
-                *args: Incomplete,
-                **kwargs: Incomplete,
-            ) -> Iterable[int]:
-                if args:
-                    for arg in args:
-                        warnings.warn(
-                            f"unused args is set in deconvolution.py/wrap_progress: {arg}",
-                            stacklevel=2,
-                        )
-                if kwargs:
-                    for k, v in kwargs.items():
-                        warnings.warn(
-                            f"unused args is set in deconvolution.py/wrap_progress: {k}: {v}",
-                            stacklevel=2,
-                        )
-                return x
-
-            if progress:
-                wrap_progress = tqdm
-
-            # dimensions over which to iterate
-            other_dim = list(data.dims)
-            other_dim.remove(axis)
-
-            if len(other_dim) == 1:
-                # two-dimensional data
-                other_dim = other_dim[0]
-                result = arr.copy(deep=True).transpose(
-                    other_dim,
-                    axis,
-                )
-                # not sure why the dims only seems to work in this order.
-                # seems like I should be able to swap it to (axis,other_dim)
-                # and also change the data collection to result[x_ind,y_ind],
-                # but this gave different results
-
-                for i, (_, iteration) in wrap_progress(
-                    enumerate(arr.G.iterate_axis(other_dim)),
-                    desc="Iterating " + other_dim,
-                    total=len(arr[other_dim]),
-                ):  # TODO: tidy this gross-looking loop
-                    # indices of data being deconvolved
-                    x_ind = xr.DataArray(list(range(len(arr[axis]))), dims=[axis])
-                    y_ind = xr.DataArray([i] * len(x_ind), dims=[other_dim])
-                    # perform deconvolution on this one-dimensional piece
-                    deconv = deconvolve_rl(
-                        data=iteration,
-                        psf=psf,
-                        n_iterations=n_iterations,
-                        axis="",
-                        mode=mode,
-                    )
-                    # build results out of these pieces
-                    result[y_ind, x_ind] = deconv.values
-            elif len(other_dim) == TWO_DIMENSION:
-                # three-dimensional data
-                result = arr.copy(deep=True).transpose(*other_dim, axis)
-                # not sure why the dims only seems to work in this order.
-                # seems like I should be able to swap it to (axis,*other_dim) and also change the
-                # data collection to result[x_ind,y_ind,z_ind], but this gave different results
-                for i, (_od0, iteration0) in wrap_progress(
-                    enumerate(arr.G.iterate_axis(other_dim[0])),
-                    desc="Iterating " + str(other_dim[0]),
-                    total=len(arr[other_dim[0]]),
-                ):  # TODO: tidy this gross-looking loop
-                    for j, (_od1, iteration1) in wrap_progress(
-                        enumerate(iteration0.G.iterate_axis(other_dim[1])),
-                        desc="Iterating " + str(other_dim[1]),
-                        total=len(arr[other_dim[1]]),
-                        leave=False,
-                    ):  # TODO: tidy this gross-looking loop
-                        # indices of data being deconvolved
-                        x_ind = xr.DataArray(list(range(len(arr[axis]))), dims=[axis])
-                        y_ind = xr.DataArray([i] * len(x_ind), dims=[other_dim[0]])
-                        z_ind = xr.DataArray([j] * len(x_ind), dims=[other_dim[1]])
-                        # perform deconvolution on this one-dimensional piece
-                        deconv = deconvolve_rl(
-                            data=iteration1,
-                            psf=psf,
-                            n_iterations=n_iterations,
-                            axis="",
-                            mode=mode,
-                        )
-                        # build results out of these pieces
-                        result[y_ind, z_ind, x_ind] = deconv.values
-            elif len(other_dim) >= TWO_DIMENSION + 1:
-                # four- or higher-dimensional data
-                # TODO: find way to compactify the different dimensionalities rather than having
-                #     separate code
-                msg = "high-dimensional data not yet supported"
-                raise NotImplementedError(msg)
-        elif not axis:
-            # crude attempt to perform multidimensional deconvolution.
-            # not clear if this is currently working
-            # TODO: may be able to do this as a sequence of one-dimensional deconvolutions, assuming
-            #     that the psf is separable (which I think it should be, if we assume it is a
-            #     multivariate gaussian with principle axes aligned with the dimensions)
-            msg = "multi-dimensional convolutions not yet supported"
-            raise NotImplementedError(msg)
-
-        if not isinstance(arr, np.ndarray):
-            arr = arr.values
-
-        u = [arr]
-
-        for i in range(n_iterations):
-            c = scipy.ndimage.convolve(u[-1], psf, mode=mode)
-            u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, None), mode=mode))
-            # careful about which axis (axes) to flip here...!
-            # need to explicitly specify for some versions of numpy
-
-        result = u[-1]
-    else:  # data.dims == 1
-        if not isinstance(arr, np.ndarray):
-            arr = arr.values
-        u = [arr]
-        for _ in range(n_iterations):
-            c = scipy.ndimage.convolve(u[-1], psf, mode=mode)
-            u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, 0), mode=mode))
-            # not yet tested to ensure flip correct for asymmetric psf
-            # note: need to explicitly specify axis number in np.flip in lower versions of numpy
-        if isinstance(data, np.ndarray):
-            result = u[-1].copy()
-        else:
-            result = data.copy(deep=True)
-            result.values = u[-1]
-    with contextlib.suppress(Exception):
-        return result.transpose(*arr.dims)
+    data_image = arr.values
+    psf_ = psf.values
+    im_deconv = richardson_lucy(data_image, psf_, num_iter=n_iterations, filter_epsilon=None)
+    return arr.S.with_values(im_deconv)
 
 
 @update_provenance("Make 1D-Point Spread Function")
@@ -283,27 +132,3 @@ def make_psf(data: xr.DataArray, sigmas: dict[str, float]) -> xr.DataArray:
 
     Returns:
         The PSF to use.
""" - raise NotImplementedError - - arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) - dims = arr.dims - - psf = arr.copy(deep=True) * 0 + 1 - - for dim in dims: - other_dims = list(arr.dims) - other_dims.remove(dim) - - psf1d = arr.copy(deep=True) * 0 + 1 - for od in other_dims: - psf1d = psf1d[{od: 0}] - - if sigmas[dim] == 0: - # TODO: may need to do subpixel correction for when the dimension has an even length - psf1d = psf1d * 0 - psf1d[{dim: len(psf1d.coords[dim]) / 2}] = 1 - else: - psf1d = psf1d * gaussian(psf1d.coords[dim], np.mean(psf1d.coords[dim]), sigmas[dim]) - - psf = psf * psf1d - return psf diff --git a/src/arpes/analysis/gap.py b/src/arpes/analysis/gap.py index fa17a6e1..6d3fc0c9 100644 --- a/src/arpes/analysis/gap.py +++ b/src/arpes/analysis/gap.py @@ -79,7 +79,7 @@ def determine_broadened_fermi_distribution( @update_provenance("Normalize By Fermi Dirac") -def normalize_by_fermi_dirac( +def normalize_by_fermi_dirac( # noqa: PLR0913 data: DataType, reference_data: DataType | None = None, broadening: float = 0, @@ -142,7 +142,7 @@ def normalize_by_fermi_dirac( if (not temperature_axis) and "temp" in data.dims: temperature_axis = "temp" - transpose_order = list(data.dims) + transpose_order: list[str] = [str(dim) for dim in data.dims] transpose_order.remove("eV") if temperature_axis: @@ -190,8 +190,7 @@ def _shift_energy_interpolate( data: xr.DataArray, shift: xr.DataArray | None = None, ) -> xr.DataArray: - if not isinstance(data, xr.DataArray): - data = normalize_to_spectrum(data) + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) data_arr = data.S.transpose_to_front("eV") new_data = data_arr.copy(deep=True) @@ -211,6 +210,7 @@ def _shift_energy_interpolate( shift = shift - stride * n_strides new_axis = new_axis + shift + assert shift is not None weight = float(shift / stride) new_values = new_values + data_arr.values * (1 - weight) @@ -249,8 +249,7 @@ def symmetrize( Returns: The symmetrized data. 
""" - if not isinstance(data, xr.DataArray): - data = normalize_to_spectrum(data) + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) data = data.S.transpose_to_front("eV") if subpixel or full_spectrum: diff --git a/src/arpes/analysis/path.py b/src/arpes/analysis/path.py index 1df113b4..44fa9fd7 100644 --- a/src/arpes/analysis/path.py +++ b/src/arpes/analysis/path.py @@ -1,4 +1,5 @@ """Contains routines used to do path selections and manipulations on a dataset.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -54,7 +55,7 @@ def as_vec(ds: xr.Dataset) -> NDArray[np.float_]: return np.array([ds[k].item() for k in order]) def distance(a: xr.Dataset, b: xr.Dataset) -> float: - return np.linalg.norm((as_vec(a) - as_vec(b)) * scaling) + return float(np.linalg.norm((as_vec(a) - as_vec(b)) * scaling)) length = 0 for idx_low, idx_high in zip(path.index.values, path.index[1:].values, strict=False): diff --git a/src/arpes/analysis/pocket.py b/src/arpes/analysis/pocket.py index 4c0ded11..ae7805f7 100644 --- a/src/arpes/analysis/pocket.py +++ b/src/arpes/analysis/pocket.py @@ -14,11 +14,12 @@ from arpes.utilities.conversion import slice_along_path if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Hashable from _typeshed import Incomplete - from arpes._typing import DataType, XrTypes + from arpes._typing import XrTypes + __all__ = ( "curves_along_pocket", "edcs_along_pocket", @@ -28,7 +29,7 @@ def pocket_parameters( - data: DataType, + data: xr.DataArray, kf_method: Callable[..., float] | None = None, sel: dict[str, slice] | None = None, method_kwargs: Incomplete = None, @@ -89,7 +90,7 @@ def radial_edcs_along_pocket( radii: tuple[float, float] = (0.0, 5.0), n_points: int = 0, select_radius: dict[str, float] | float | None = None, - **kwargs: Incomplete, + **kwargs: float, ) -> xr.Dataset: """Produces EDCs distributed radially along a vector from the pocket center. @@ -107,7 +108,7 @@ def radial_edcs_along_pocket( coordinate. n_points: Number of EDCs, can be automatically inferred. select_radius: The radius used for selections along the radial curve. - kwargs: Used to define the central point. + kwargs: Center point of each dimension. Return: A 2D array which has an angular coordinate around the pocket center. @@ -119,7 +120,7 @@ def radial_edcs_along_pocket( assert "eV" in fermi_surface_dims fermi_surface_dims.remove("eV") - center_point = {k: v for k, v in kwargs.items() if k in data_array.dims} + center_point: dict[Hashable, float] = {k: v for k, v in kwargs.items() if k in data_array.dims} center_as_vector = np.array([center_point.get(d, 0) for d in fermi_surface_dims]) if not n_points: @@ -162,7 +163,7 @@ def curves_along_pocket( n_points: int = 0, inner_radius: float = 0.0, outer_radius: float = 5.0, - **kwargs: Incomplete, + **kwargs: float, ) -> tuple[list[xr.DataArray], list[float]]: """Produces radial slices along a Fermi surface through a pocket. @@ -175,11 +176,10 @@ def curves_along_pocket( Args: data: input data - n_points (int): - inner_radius: - outer_radius: - shape: - kwargs: + n_points: Number of EDCs, can be automatically inferred. + inner_radius: inner radius + outer_radius: outer radius + kwargs: Center point of each dimension. Returns: A tuple of two lists. 
         The first list contains the slices and the second
         list contains the angles.
     """
     data_array = normalize_to_spectrum(data)
     assert isinstance(data_array, xr.DataArray)
     fermi_surface_dims = list(data_array.dims)
@@ -191,9 +191,11 @@ def curves_along_pocket(
     if "eV" in fermi_surface_dims:
         fermi_surface_dims.remove("eV")
 
-    center_point = {str(k): v for k, v in kwargs.items() if k in data_array.dims}
+    center_point: dict[Hashable, float] = {k: v for k, v in kwargs.items() if k in data_array.dims}
 
-    center_as_vector = np.array([center_point.get(dim_name, 0) for dim_name in fermi_surface_dims])
+    center_as_vector = np.array(
+        [center_point.get(dim_name, 0.0) for dim_name in fermi_surface_dims],
+    )
 
     if not n_points:
         # determine N approximately by the granularity
@@ -254,14 +256,12 @@ def find_kf_by_mdc(
     Returns:
         The fitting Fermi momentum.
     """
-    slice_arr = (
+    slice_data = (
         slice_data if isinstance(slice_data, xr.DataArray) else normalize_to_spectrum(slice_data)
     )
-
-    assert isinstance(slice_arr, xr.DataArray)
-
-    if "eV" in slice_arr.dims:
-        slice_arr = slice_arr.sum("eV")
+    assert isinstance(slice_data, xr.DataArray)
+    slice_arr = slice_data.sum("eV") if "eV" in slice_data.dims else slice_data
 
     lor = LorentzianModel()
     bkg = AffineBackgroundModel(prefix="b_")
diff --git a/src/arpes/analysis/sarpes.py b/src/arpes/analysis/sarpes.py
index 58804d10..c82a548f 100644
--- a/src/arpes/analysis/sarpes.py
+++ b/src/arpes/analysis/sarpes.py
@@ -2,17 +2,12 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 import xarray as xr
 
 from arpes.provenance import update_provenance
 from arpes.utilities import normalize_to_dataset
 from arpes.utilities.math import polarization
 
-if TYPE_CHECKING:
-    from arpes._typing import DataType
-
 __all__ = (
     "to_intensity_polarization",
     "to_up_down",
@@ -21,7 +16,7 @@
 
 
 @update_provenance("Normalize SARPES by photocurrent")
-def normalize_sarpes_photocurrent(data: DataType) -> DataType:
+def normalize_sarpes_photocurrent(data: xr.DataArray) -> xr.DataArray:
     """Normalizes the down channel so that it matches the up channel in terms of mean photocurrent.
 
     Destroys the integrity of "count" data because we have scaled individual arrivals.
@@ -38,7 +33,7 @@ def normalize_sarpes_photocurrent(data: DataType) -> DataType:
 
 
 @update_provenance("Convert polarization data to up-down spin channels")
-def to_up_down(data: DataType) -> xr.Dataset:
+def to_up_down(data: xr.Dataset) -> xr.Dataset:
     """Converts from [intensity, polarization] representation to [up, down] representation.
 
     This is the inverse function to `to_intensity_polarization`, neglecting the role of the
diff --git a/src/arpes/analysis/self_energy.py b/src/arpes/analysis/self_energy.py
index ad9cf030..ec4a34ab 100644
--- a/src/arpes/analysis/self_energy.py
+++ b/src/arpes/analysis/self_energy.py
@@ -210,8 +210,10 @@ def to_self_energy(
         The equivalent self energy from the bare band and the measured dispersion.
     """
     if not k_independent:
-        msg = "PyARPES does not currently support self energy analysis"
-        msg += " except in the k-independent formalism."
+        msg = (
+            "PyARPES does not currently support self energy analysis"
+            " except in the k-independent formalism."
+ ) raise NotImplementedError( msg, ) diff --git a/src/arpes/analysis/shirley.py b/src/arpes/analysis/shirley.py index 8c4449f3..8306addf 100644 --- a/src/arpes/analysis/shirley.py +++ b/src/arpes/analysis/shirley.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict, Unpack import numpy as np import xarray as xr @@ -22,8 +22,18 @@ ) +class KwargsCalShirleyBGFunc(TypedDict, total=False): + energy_range: slice | None + eps: float + max_iters: int + n_samples: int + + @update_provenance("Remove Shirley background") -def remove_shirley_background(xps: xr.DataArray, **kwargs: float) -> xr.DataArray: +def remove_shirley_background( + xps: xr.DataArray, + **kwargs: Unpack[KwargsCalShirleyBGFunc], +) -> xr.DataArray: """Calculates and removes a Shirley background from a spectrum. Only the background corrected spectrum is retrieved. @@ -49,7 +59,6 @@ def _calculate_shirley_background_full_range( background = np.copy(xps) cumulative_xps = np.cumsum(xps, axis=0) total_xps = np.sum(xps, axis=0) - rel_error = np.inf i_left = np.mean(xps[:n_samples], axis=0) @@ -78,11 +87,11 @@ def _calculate_shirley_background_full_range( break if (iter_count + 1) == max_iters: - warnings.warn( - "Shirley background calculation did not converge ", - f"after {max_iters} steps with relative error {rel_error}!", - stacklevel=2, + msg = ( + "Shirley background calculation did not converge " + f"after {max_iters} steps with relative error {rel_error}!" ) + warnings.warn(msg, stacklevel=2) return background diff --git a/src/arpes/endstations/plugin/ALG_spin_ToF.py b/src/arpes/endstations/plugin/ALG_spin_ToF.py index 257fe5be..f1ed6c22 100644 --- a/src/arpes/endstations/plugin/ALG_spin_ToF.py +++ b/src/arpes/endstations/plugin/ALG_spin_ToF.py @@ -279,8 +279,6 @@ def load(self, scan_desc: SCANDESC) -> xr.Dataset: ) msg = "Expected a dictionary of scan_desc with the location of the file" raise TypeError(msg) - if kwargs: - warnings.warn("kwargs are not supported.", stacklevel=2) data_loc = scan_desc.get("path", scan_desc.get("file")) scan_desc = { k: v for k, v in scan_desc.items() if not isinstance(v, float) or not np.isnan(v) diff --git a/src/arpes/models/band.py b/src/arpes/models/band.py index 1ec829c1..fb6c85df 100644 --- a/src/arpes/models/band.py +++ b/src/arpes/models/band.py @@ -64,7 +64,7 @@ def velocity(self) -> xr.DataArray: spacing = float(self.coords[self.dims[0]][1] - self.coords[self.dims[0]][0]) def embed_nan(values: NDArray[np.float_], padding: int) -> NDArray[np.float_]: - embedded = np.ndarray((values.shape[0] + 2 * padding,)) + embedded: NDArray[np.float_] = np.ndarray((values.shape[0] + 2 * padding,)) embedded[:] = float("nan") embedded[padding:-padding] = values return embedded diff --git a/src/arpes/preparation/axis_preparation.py b/src/arpes/preparation/axis_preparation.py index af7ba4e6..4a98f35e 100644 --- a/src/arpes/preparation/axis_preparation.py +++ b/src/arpes/preparation/axis_preparation.py @@ -102,7 +102,7 @@ def flip_axis( @lift_dataarray_to_generic def normalize_dim( - arr: XrTypes, + arr: xr.DataArray, dim_or_dims: str | list[str], *, keep_id: bool = False, @@ -161,14 +161,14 @@ def normalize_total(data: XrTypes, *, total_intensity: float = 1000000) -> xr.Da def dim_normalizer( dim_name: str, -) -> Callable[[XrTypes], XrTypes]: +) -> Callable[[xr.DataArray], xr.DataArray]: """Safe partial application of dimension normalization. 
Args: dim_name (str): [TODO:description] """ - def normalize(arr: XrTypes) -> XrTypes: + def normalize(arr: xr.DataArray) -> xr.DataArray: if dim_name not in arr.dims: return arr return normalize_dim(arr, dim_name) diff --git a/src/arpes/utilities/jupyter.py b/src/arpes/utilities/jupyter.py index 33980170..d3f4f14c 100644 --- a/src/arpes/utilities/jupyter.py +++ b/src/arpes/utilities/jupyter.py @@ -125,7 +125,13 @@ def get_recent_history(n_items: int = 10) -> list[str]: ipython = get_ipython() assert isinstance(ipython, InteractiveShell) return [ - _[-1] for _ in list(ipython.history_manager.get_tail(n=n_items, include_latest=True)) + _[-1] + for _ in list( + ipython.history_manager.get_tail( # type: ignore [union-attr] + n=n_items, + include_latest=True, + ), + ) ] except (ImportError, AttributeError, AssertionError): return ["No accessible history."] @@ -153,7 +159,10 @@ def get_recent_logs(n_bytes: int = 1000) -> list[str]: lines = file.readlines() # ensure we get the most recent information - final_cell = ipython.history_manager.get_tail(n=1, include_latest=True)[0][-1] + final_cell = ipython.history_manager.get_tail( # type: ignore [union-attr] + n=1, + include_latest=True, + )[0][-1] return [_.decode() for _ in lines] + [final_cell] except (ImportError, AttributeError, AssertionError): diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index e47f76d0..e03d3e96 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -266,13 +266,8 @@ def polarization(self) -> float | str | tuple[float, float]: def is_subtracted(self) -> bool: """Infers whether a given data is subtracted. - Args: - self ([TODO:type]): [TODO:description] - Returns (bool): Return True if the data is subtracted. - - ToDo: Need test """ assert isinstance(self._obj, xr.DataArray) if self._obj.attrs.get("subtracted"): @@ -358,13 +353,8 @@ def with_values(self, new_values: NDArray[np.float_]) -> xr.DataArray: def logical_offsets(self) -> dict[str, float | xr.DataArray]: """Return logical offsets. - Raises: - ValueError: [TODO:description] - Returns: - [TODO:description] - - ToDo: Test + dict object of long_* + physical_long_* (*: x, y, or z) """ assert isinstance(self._obj, xr.DataArray | xr.Dataset) if "long_x" not in self._obj.coords: @@ -1734,9 +1724,9 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]: if isinstance(v, xr.DataArray): min_hv = float(v.min()) max_hv = float(v.max()) - transformed_dict[k] = ( - f" from {min_hv} to {max_hv} eV" - ) + transformed_dict[ + k + ] = f" from {min_hv} to {max_hv} eV" elif isinstance(v, float) and not np.isnan(v): transformed_dict[k] = f"{v} eV" return transformed_dict @@ -3077,7 +3067,6 @@ def __init__(self, xarray_obj: xr.DataArray) -> None: self._obj = xarray_obj class _PlotParamKwargs(MPLPlotKwargs, total=False): - ax: Axes | None shift: float x_shift: float
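
Usage note (reviewer sketch, not part of the patch): the new `deconvolve_rl`
drops the `axis`, `sigma`, `mode`, and `progress` parameters, requires an
explicit PSF, and delegates the iteration to scikit-image. Below is a minimal
sketch of a call site under the new API on a synthetic 1D spectrum;
`make_psf1d` is the helper whose old call signature is visible in the removed
code above.

    import numpy as np
    import xarray as xr

    from arpes.analysis.deconvolution import deconvolve_rl, make_psf1d

    # Synthetic EDC: a resolution-broadened Fermi edge on an "eV" axis.
    ev = np.linspace(-0.5, 0.1, 601)
    spectrum = xr.DataArray(
        1.0 / (np.exp(ev / 0.01) + 1.0),
        coords={"eV": ev},
        dims=["eV"],
    )

    # The PSF is now a required DataArray; a Gaussian one can still be built
    # with make_psf1d (sigma is in the units of the "eV" coordinate).
    psf = make_psf1d(data=spectrum, dim="eV", sigma=0.005)

    # skimage.restoration.richardson_lucy performs the iterations internally;
    # n_iterations trades sharpening against noise amplification.
    deconvolved = deconvolve_rl(spectrum, psf=psf, n_iterations=10)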
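A similar sketch for the typed keyword arguments on
`remove_shirley_background`; the keyword names come from the new
`KwargsCalShirleyBGFunc` TypedDict, while the spectrum and the numeric values
here are illustrative only.

    import numpy as np
    import xarray as xr

    from arpes.analysis.shirley import remove_shirley_background

    # Illustrative core-level spectrum: a Gaussian peak plus a step standing
    # in for the inelastic background that a Shirley correction removes.
    ev = np.linspace(280.0, 300.0, 401)
    peak = np.exp(-((ev - 290.0) ** 2) / (2 * 0.5**2))
    xps = xr.DataArray(
        peak + 0.3 * (ev < 290.0) + 0.05,
        coords={"eV": ev},
        dims=["eV"],
    )

    # With **kwargs: Unpack[KwargsCalShirleyBGFunc], a type checker can now
    # reject misspelled keywords; the old **kwargs: float accepted anything.
    corrected = remove_shirley_background(xps, max_iters=50, n_samples=5)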