🔨 Refactoring: Divide methods in "G" into three classes.
arafune committed Apr 29, 2024
1 parent 5948f65 commit 8667aca
Showing 1 changed file with 51 additions and 34 deletions.
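
The headline change is the split of the old GenericAccessorTools class (previously registered for both Datasets and DataArrays under ".G") into a shared GenericAccessorBase plus two separately registered accessors. A condensed, standalone sketch of the resulting layout follows; the method bodies are simplified, only two of the many helpers are shown, and the XrTypes alias from the real src/arpes/xarray_extensions.py is replaced here by a plain union.

from __future__ import annotations

from collections.abc import Callable, Hashable

import numpy as np
import xarray as xr


class GenericAccessorBase:
    """Functionality shared by the Dataset and DataArray flavours of ``.G``."""

    _obj: xr.DataArray | xr.Dataset


@xr.register_dataset_accessor("G")
class GenericDatasetAccessor(GenericAccessorBase):
    """Dataset-only helpers, e.g. filtering data variables."""

    def __init__(self, xarray_obj: xr.Dataset) -> None:
        self._obj = xarray_obj

    def filter_vars(self, f: Callable[[Hashable, xr.DataArray], bool]) -> xr.Dataset:
        # Keep only the data variables for which the predicate returns True.
        return xr.Dataset(
            data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)},
            attrs=self._obj.attrs,
        )


@xr.register_dataarray_accessor("G")
class GenericDataArrayAccessor(GenericAccessorBase):
    """DataArray-only helpers, e.g. locating the coordinates of the maximum."""

    def __init__(self, xarray_obj: xr.DataArray) -> None:
        self._obj = xarray_obj

    def argmax_coords(self) -> dict[Hashable, float]:
        # Assumes every dimension carries a coordinate, as in the original code.
        data = self._obj
        flat = np.unravel_index(int(np.argmax(data.values)), data.values.shape)
        return {d: data.coords[d][flat[i]].item() for i, d in enumerate(data.dims)}

Note that running this sketch alongside arpes itself would re-register ".G" and shadow the real accessors; it is meant only to illustrate the class structure introduced by the commit.
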
85 changes: 51 additions & 34 deletions src/arpes/xarray_extensions.py
@@ -176,7 +176,7 @@ def along(
self,
directions: list[Hashable | dict[Hashable, float]],
**kwargs: Unpack[_SliceAlongPathKwags],
) -> xr.Dataset:
) -> xr.Dataset: # TODO: [RA] xr.DataArray
"""[TODO:summary].
Args:
@@ -210,7 +210,7 @@ def sherman_function(self) -> float:
Raises: ValueError
When no Sherman function related value is found.
ToDo: Test
ToDo: Test, Consider if it should be in "S"
"""
for option in ["sherman", "sherman_function", "SHERMAN"]:
if option in self._obj.attrs:
@@ -254,7 +254,7 @@ def polarization(self) -> float | str | tuple[float, float]:
return np.nan

@property
def is_subtracted(self) -> bool:
def is_subtracted(self) -> bool: # TODO: [RA] xr.DataArray
"""Infers whether a given data is subtracted.
Returns (bool):
@@ -275,6 +275,7 @@ def is_spatial(self) -> bool:
True if the data is explicitly a "ucut" or "spem" or contains "X", "Y", or "Z"
dimensions. False otherwise.
"""
assert isinstance(self._obj, xr.DataArray | xr.Dataset)
if self.spectrum_type in {"ucut", "spem"}:
return True

@@ -294,7 +295,7 @@ def is_kspace(self) -> bool:
return not any(d in {"phi", "theta", "beta", "angle"} for d in self._obj.dims)

@property
def is_slit_vertical(self) -> bool:
def is_slit_vertical(self) -> bool: # TODO: [RA] Refactoring ?
"""Infers whether the scan is taken on an analyzer with vertical slit.
Caveat emptor: this assumes that the alpha coordinate is not some intermediate value.
@@ -320,7 +321,10 @@ def endstation(self) -> str:
"""
return str(self._obj.attrs["location"])

def with_values(self, new_values: NDArray[np.float_]) -> xr.DataArray:
def with_values(
self,
new_values: NDArray[np.float_],
) -> xr.DataArray: # TODO: [RA] xr.DataArray
"""Copy with new array values.
Easy way of creating a DataArray that has the same shape as the calling object but data
@@ -497,10 +501,8 @@ def select_around_data(
), "Cannot use select_around on Datasets only DataArrays!"

assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean."
assert isinstance(points, dict | xr.Dataset)
radius = radius or {}
if isinstance(points, tuple | list):
warnings.warn("Dangerous iterable points argument to `select_around`", stacklevel=2)
points = dict(zip(self._obj.dims, points, strict=True))
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
assert isinstance(points, dict)
@@ -590,10 +592,7 @@ def select_around(
), "Cannot use select_around on Datasets only DataArrays!"

assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean."

if isinstance(points, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
points = dict(zip(self._obj.dims, points, strict=True))
assert isinstance(points, dict | xr.Dataset)
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
@@ -856,7 +855,11 @@ def inner_potential(self) -> float:
return self._obj.attrs["inner_potential"]
return 10

def find_spectrum_energy_edges(self, *, indices: bool = False) -> NDArray[np.float_]:
def find_spectrum_energy_edges(
self,
*,
indices: bool = False,
) -> NDArray[np.float_]: # TODO: xr.DataArray
"""Return energy position corresponding to the (1D) spectrum edge.
Spectrum edge is the inflection point of the peak.
@@ -962,7 +965,7 @@ def zero_spectrometer_edges(
interp_range: float | None = None,
low: Sequence[float] | NDArray[np.float_] | None = None,
high: Sequence[float] | NDArray[np.float_] | None = None,
) -> xr.DataArray:
) -> xr.DataArray: # TODO: [RA] xr.DataArray
assert isinstance(self._obj, xr.DataArray)
if low is not None:
assert high is not None
@@ -1047,7 +1050,7 @@ def find_spectrum_angular_edges(
*,
angle_name: str = "phi",
indices: bool = False,
) -> NDArray[np.float_] | NDArray[np.int_]:
) -> NDArray[np.float_] | NDArray[np.int_]: # TODO: [RA] xr.DataArray
"""Return angle position corresponding to the (1D) spectrum edge.
Args:
@@ -1877,11 +1880,11 @@ def fs_plot(
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | tuple[Figure | None, Axes]:
"""Provides a reference plot of the approximate Fermi surface."""
assert isinstance(self._obj, xr.DataArray)
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = pattern.format(f"{self.label}_fs")
kwargs["out"] = out
assert isinstance(self._obj, xr.DataArray)
return labeled_fermi_surface(self._obj, **kwargs)

def fermi_edge_reference_plot(
@@ -1900,9 +1903,9 @@ def fermi_edge_reference_plot(
Returns:
[TODO:description]
"""
assert isinstance(self._obj, xr.DataArray)
if out is not None and isinstance(out, bool):
out = pattern.format(f"{self.label}_fermi_edge_reference")
assert isinstance(self._obj, xr.DataArray)
return fermi_edge_reference(self._obj, out=out, **kwargs)

def _referenced_scans_for_spatial_plot(
@@ -2179,9 +2182,7 @@ def correct_angle_by(
NORMALIZED_DIM_NAMES = ["x", "y", "z", "w"]


@xr.register_dataset_accessor("G")
@xr.register_dataarray_accessor("G")
class GenericAccessorTools:
class GenericAccessorBase:
_obj: XrTypes

def round_coordinates(
@@ -2488,19 +2489,33 @@ def iterate_axis(
coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True))
yield coords_dict, self._obj.sel(coords_dict, method="nearest")

# ---------

@xr.register_dataset_accessor("G")
class GenericDatasetAccessor(GenericAccessorBase):
def filter_vars(
self,
f: Callable[[Hashable, xr.DataArray], bool],
) -> xr.Dataset: # TODO: [RA] Dataset only
) -> xr.Dataset:
assert isinstance(self._obj, xr.Dataset) # ._obj.data_vars
return xr.Dataset(
data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)},
attrs=self._obj.attrs,
)

# ----------
def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray
def __init__(self, xarray_obj: xr.Dataset) -> None:
"""Initialization hook for xarray.
This should never need to be called directly.
Args:
xarray_obj: The parent object which this is an accessor for
"""
self._obj = xarray_obj


@xr.register_dataarray_accessor("G")
class GenericDataArrayAccessor(GenericAccessorBase):
def argmax_coords(self) -> dict[Hashable, float]:
"""Return dict representing the position for maximum value."""
assert isinstance(self._obj, xr.DataArray)
data: xr.DataArray = self._obj
@@ -2510,7 +2525,7 @@ def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray
flat_indices = np.unravel_index(idx, data.values.shape)
return {d: data.coords[d][flat_indices[i]].item() for i, d in enumerate(data.dims)}

def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]: # TODO: [RA] DataArray
def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]:
"""Converts to a flat representation where the coordinate values are also present.
Extremely valuable for plotting a dataset with coordinates, X, Y and values Z(X,Y)
@@ -2542,7 +2557,7 @@ def meshgrid(
self,
*,
as_dataset: bool = False,
) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset: # TODO: [RA] DataArray
) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset:
assert isinstance(self._obj, xr.DataArray) # ._obj.values is used.

dims = self._obj.dims
@@ -2564,7 +2579,7 @@

return meshed_coordinates

def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: # TODO: [RA] DataArray
def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]:
"""Converts a (1D) `xr.DataArray` into two plain ``ndarray``s of their coordinate and data.
Useful for rapidly converting into a format that can be `plt.scatter`ed
@@ -2583,7 +2598,7 @@ def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: # TODO: [

return (self._obj.coords[self._obj.dims[0]].values, self._obj.values)

def clean_outliers(self, clip: float = 0.5) -> xr.DataArray: # TODO: [RA] DataArray
def clean_outliers(self, clip: float = 0.5) -> xr.DataArray:
assert isinstance(self._obj, xr.DataArray)
low, high = np.percentile(self._obj.values, [clip, 100 - clip])
copied = self._obj.copy(deep=True)
@@ -2595,9 +2610,10 @@ def as_movie(
self,
time_dim: str = "delay",
pattern: str = "{}.png",
*,
out: str | bool = "",
**kwargs: Unpack[PColorMeshKwargs],
) -> Path | animation.FuncAnimation: # TODO: [RA] DataArray
) -> Path | animation.FuncAnimation:
assert isinstance(self._obj, xr.DataArray)

if isinstance(out, bool) and out is True:
@@ -2610,7 +2626,7 @@ def map_axes(
axes: list[str] | str,
fn: Callable[[XrTypes, dict[str, float]], DataType],
dtype: DTypeLike = None,
) -> xr.DataArray: # TODO: [RA] DataArray
) -> xr.DataArray:
"""[TODO:summary].
Args:
@@ -2651,7 +2667,7 @@ def transform(
dtype: DTypeLike = None,
*args: Incomplete,
**kwargs: Incomplete,
) -> xr.DataArray: # TODO: DataArray
) -> xr.DataArray:
"""Applies a vectorized operation across a subset of array axes.
Transform has similar semantics to matrix multiplication, the dimensions of the
@@ -2730,7 +2746,7 @@ def map(
self,
fn: Callable[[NDArray[np.float_], Any], NDArray[np.float_]],
**kwargs: Incomplete,
) -> xr.DataArray: # TODO: [RA]: DataArray
) -> xr.DataArray:
"""[TODO:summary].
Args:
@@ -2751,7 +2767,7 @@ def shift_by( # noqa: PLR0913
*,
zero_nans: bool = True,
shift_coords: bool = False,
) -> xr.DataArray: # TODO: [RA] DataArray
) -> xr.DataArray:
"""Data shift along the axis.
For now we only support shifting by a one dimensional array
@@ -2773,6 +2789,7 @@ def shift_by( # noqa: PLR0913
raise TypeError(msg)
assert isinstance(self._obj, xr.DataArray)
data = self._obj.copy(deep=True)
mean_shift: np.float_ | float = 0.0

if isinstance(other, xr.DataArray):
assert len(other.dims) == 1
@@ -2820,7 +2837,7 @@ def shift_by( # noqa: PLR0913

return built_data

def __init__(self, xarray_obj: XrTypes) -> None:
def __init__(self, xarray_obj: xr.DataArray) -> None:
self._obj = xarray_obj
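
With the accessor split in place, Dataset.G and DataArray.G resolve to different classes, so Dataset-only helpers such as filter_vars no longer appear on DataArrays and vice versa. A quick check of that behaviour, assuming a version of arpes containing this commit is importable:

import numpy as np
import xarray as xr

import arpes.xarray_extensions  # noqa: F401  -- importing registers the .G accessors

da = xr.DataArray(np.zeros((2, 2)), dims=("x", "y"))
ds = xr.Dataset({"spectrum": da})

print(type(da.G).__name__)              # GenericDataArrayAccessor
print(type(ds.G).__name__)              # GenericDatasetAccessor
print(hasattr(ds.G, "filter_vars"))     # True  (Dataset-only helper)
print(hasattr(da.G, "filter_vars"))     # False
print(hasattr(da.G, "argmax_coords"))   # True  (DataArray-only helper)
print(hasattr(ds.G, "argmax_coords"))   # False
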

