diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index acde5e3b..e71784b3 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -176,7 +176,7 @@ def along( self, directions: list[Hashable | dict[Hashable, float]], **kwargs: Unpack[_SliceAlongPathKwags], - ) -> xr.Dataset: + ) -> xr.Dataset: # TODO: [RA] xr.DataArray """[TODO:summary]. Args: @@ -210,7 +210,7 @@ def sherman_function(self) -> float: Raises: ValueError When no Sherman function related value is found. - ToDo: Test + ToDo: Test, Consider if it should be in "S" """ for option in ["sherman", "sherman_function", "SHERMAN"]: if option in self._obj.attrs: @@ -254,7 +254,7 @@ def polarization(self) -> float | str | tuple[float, float]: return np.nan @property - def is_subtracted(self) -> bool: + def is_subtracted(self) -> bool: # TODO: [RA] xr.DataArray """Infers whether a given data is subtracted. Returns (bool): @@ -275,6 +275,7 @@ def is_spatial(self) -> bool: True if the data is explicltly a "ucut" or "spem" or contains "X", "Y", or "Z" dimensions. False otherwise. """ + assert isinstance(self._obj, xr.DataArray | xr.Dataset) if self.spectrum_type in {"ucut", "spem"}: return True @@ -294,7 +295,7 @@ def is_kspace(self) -> bool: return not any(d in {"phi", "theta", "beta", "angle"} for d in self._obj.dims) @property - def is_slit_vertical(self) -> bool: + def is_slit_vertical(self) -> bool: # TODO: [RA] Refactoring ? """Infers whether the scan is taken on an analyzer with vertical slit. Caveat emptor: this assumes that the alpha coordinate is not some intermediate value. @@ -320,7 +321,10 @@ def endstation(self) -> str: """ return str(self._obj.attrs["location"]) - def with_values(self, new_values: NDArray[np.float_]) -> xr.DataArray: + def with_values( + self, + new_values: NDArray[np.float_], + ) -> xr.DataArray: # TODO: [RA] xr.DataArray """Copy with new array values. Easy way of creating a DataArray that has the same shape as the calling object but data @@ -497,10 +501,8 @@ def select_around_data( ), "Cannot use select_around on Datasets only DataArrays!" assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean." + assert isinstance(points, dict | xr.Dataset) radius = radius or {} - if isinstance(points, tuple | list): - warnings.warn("Dangerous iterable points argument to `select_around`", stacklevel=2) - points = dict(zip(self._obj.dims, points, strict=True)) if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} assert isinstance(points, dict) @@ -590,10 +592,7 @@ def select_around( ), "Cannot use select_around on Datasets only DataArrays!" assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean." - - if isinstance(points, tuple | list): - warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2) - points = dict(zip(self._obj.dims, points, strict=True)) + assert isinstance(points, dict | xr.Dataset) if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} logger.debug(f"points: {points}") @@ -856,7 +855,11 @@ def inner_potential(self) -> float: return self._obj.attrs["inner_potential"] return 10 - def find_spectrum_energy_edges(self, *, indices: bool = False) -> NDArray[np.float_]: + def find_spectrum_energy_edges( + self, + *, + indices: bool = False, + ) -> NDArray[np.float_]: # TODO: xr.DataArray """Return energy position corresponding to the (1D) spectrum edge. Spectrum edge is infection point of the peak. @@ -962,7 +965,7 @@ def zero_spectrometer_edges( interp_range: float | None = None, low: Sequence[float] | NDArray[np.float_] | None = None, high: Sequence[float] | NDArray[np.float_] | None = None, - ) -> xr.DataArray: + ) -> xr.DataArray: # TODO: [RA] xr.DataArray assert isinstance(self._obj, xr.DataArray) if low is not None: assert high is not None @@ -1047,7 +1050,7 @@ def find_spectrum_angular_edges( *, angle_name: str = "phi", indices: bool = False, - ) -> NDArray[np.float_] | NDArray[np.int_]: + ) -> NDArray[np.float_] | NDArray[np.int_]: # TODO: [RA] xr.DataArray """Return angle position corresponding to the (1D) spectrum edge. Args: @@ -1877,11 +1880,11 @@ def fs_plot( **kwargs: Unpack[LabeledFermiSurfaceParam], ) -> Path | tuple[Figure | None, Axes]: """Provides a reference plot of the approximate Fermi surface.""" + assert isinstance(self._obj, xr.DataArray) out = kwargs.get("out") if out is not None and isinstance(out, bool): out = pattern.format(f"{self.label}_fs") kwargs["out"] = out - assert isinstance(self._obj, xr.DataArray) return labeled_fermi_surface(self._obj, **kwargs) def fermi_edge_reference_plot( @@ -1900,9 +1903,9 @@ def fermi_edge_reference_plot( Returns: [TODO:description] """ + assert isinstance(self._obj, xr.DataArray) if out is not None and isinstance(out, bool): out = pattern.format(f"{self.label}_fermi_edge_reference") - assert isinstance(self._obj, xr.DataArray) return fermi_edge_reference(self._obj, out=out, **kwargs) def _referenced_scans_for_spatial_plot( @@ -2179,9 +2182,7 @@ def correct_angle_by( NORMALIZED_DIM_NAMES = ["x", "y", "z", "w"] -@xr.register_dataset_accessor("G") -@xr.register_dataarray_accessor("G") -class GenericAccessorTools: +class GenericAccessorBase: _obj: XrTypes def round_coordinates( @@ -2488,19 +2489,33 @@ def iterate_axis( coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True)) yield coords_dict, self._obj.sel(coords_dict, method="nearest") - # --------- + +@xr.register_dataset_accessor("G") +class GenericDatasetAccessor(GenericAccessorBase): def filter_vars( self, f: Callable[[Hashable, xr.DataArray], bool], - ) -> xr.Dataset: # TODO: [RA] Dataset only + ) -> xr.Dataset: assert isinstance(self._obj, xr.Dataset) # ._obj.data_vars return xr.Dataset( data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)}, attrs=self._obj.attrs, ) - # ---------- - def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray + def __init__(self, xarray_obj: xr.Dataset) -> None: + """Initialization hook for xarray. + + This should never need to be called directly. + + Args: + xarray_obj: The parent object which this is an accessor for + """ + self._obj = xarray_obj + + +@xr.register_dataarray_accessor("G") +class GenericDataArrayAccessor(GenericAccessorBase): + def argmax_coords(self) -> dict[Hashable, float]: """Return dict representing the position for maximum value.""" assert isinstance(self._obj, xr.DataArray) data: xr.DataArray = self._obj @@ -2510,7 +2525,7 @@ def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray flat_indices = np.unravel_index(idx, data.values.shape) return {d: data.coords[d][flat_indices[i]].item() for i, d in enumerate(data.dims)} - def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]: # TODO: [RA] DataArray + def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]: """Converts to a flat representation where the coordinate values are also present. Extremely valuable for plotting a dataset with coordinates, X, Y and values Z(X,Y) @@ -2542,7 +2557,7 @@ def meshgrid( self, *, as_dataset: bool = False, - ) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset: # TODO: [RA] DataArray + ) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset: assert isinstance(self._obj, xr.DataArray) # ._obj.values is used. dims = self._obj.dims @@ -2564,7 +2579,7 @@ def meshgrid( return meshed_coordinates - def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: # TODO: [RA] DataArray + def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: """Converts a (1D) `xr.DataArray` into two plain ``ndarray``s of their coordinate and data. Useful for rapidly converting into a format than can be `plt.scatter`ed @@ -2583,7 +2598,7 @@ def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: # TODO: [ return (self._obj.coords[self._obj.dims[0]].values, self._obj.values) - def clean_outliers(self, clip: float = 0.5) -> xr.DataArray: # TODO: [RA] DataArray + def clean_outliers(self, clip: float = 0.5) -> xr.DataArray: assert isinstance(self._obj, xr.DataArray) low, high = np.percentile(self._obj.values, [clip, 100 - clip]) copied = self._obj.copy(deep=True) @@ -2595,9 +2610,10 @@ def as_movie( self, time_dim: str = "delay", pattern: str = "{}.png", + *, out: str | bool = "", **kwargs: Unpack[PColorMeshKwargs], - ) -> Path | animation.FuncAnimation: # TODO: [RA] DataArray + ) -> Path | animation.FuncAnimation: assert isinstance(self._obj, xr.DataArray) if isinstance(out, bool) and out is True: @@ -2610,7 +2626,7 @@ def map_axes( axes: list[str] | str, fn: Callable[[XrTypes, dict[str, float]], DataType], dtype: DTypeLike = None, - ) -> xr.DataArray: # TODO: [RA] DataArray + ) -> xr.DataArray: """[TODO:summary]. Args: @@ -2651,7 +2667,7 @@ def transform( dtype: DTypeLike = None, *args: Incomplete, **kwargs: Incomplete, - ) -> xr.DataArray: # TODO: DataArray + ) -> xr.DataArray: """Applies a vectorized operation across a subset of array axes. Transform has similar semantics to matrix multiplication, the dimensions of the @@ -2730,7 +2746,7 @@ def map( self, fn: Callable[[NDArray[np.float_], Any], NDArray[np.float_]], **kwargs: Incomplete, - ) -> xr.DataArray: # TODO: [RA]: DataArray + ) -> xr.DataArray: """[TODO:summary]. Args: @@ -2751,7 +2767,7 @@ def shift_by( # noqa: PLR0913 *, zero_nans: bool = True, shift_coords: bool = False, - ) -> xr.DataArray: # TODO: [RA] DataArray + ) -> xr.DataArray: """Data shift along the axis. For now we only support shifting by a one dimensional array @@ -2773,6 +2789,7 @@ def shift_by( # noqa: PLR0913 raise TypeError(msg) assert isinstance(self._obj, xr.DataArray) data = self._obj.copy(deep=True) + mean_shift: np.float_ | float = 0.0 if isinstance(other, xr.DataArray): assert len(other.dims) == 1 @@ -2820,7 +2837,7 @@ def shift_by( # noqa: PLR0913 return built_data - def __init__(self, xarray_obj: XrTypes) -> None: + def __init__(self, xarray_obj: xr.DataArray) -> None: self._obj = xarray_obj