Commit

🎨 Tiny refactoring and fixing bug
arafune committed Apr 27, 2024
1 parent 3dbcc2f commit 1774aae
Showing 1 changed file with 26 additions and 30 deletions.
56 changes: 26 additions & 30 deletions src/arpes/xarray_extensions.py
@@ -500,7 +500,7 @@ def select_around_data(
radius = radius or {}
if isinstance(points, tuple | list):
warnings.warn("Dangerous iterable points argument to `select_around`", stacklevel=2)
- points = dict(zip(points, self._obj.dims, strict=True))
+ points = dict(zip(self._obj.dims, points, strict=True))
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
assert isinstance(points, dict)
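
The reordered `zip` above is the bug fix named in the commit message: `dict(zip(...))` must map dimension names to the supplied coordinate values, not the other way around (the same fix is applied again in `select_around` below). A minimal sketch, with illustrative dimension names and values that are not taken from the repository:

    import numpy as np
    import xarray as xr

    # Toy 2-D DataArray standing in for self._obj, with dims ("kp", "eV").
    arr = xr.DataArray(
        np.zeros((3, 4)),
        coords={"kp": [0.0, 0.1, 0.2], "eV": [-0.3, -0.2, -0.1, 0.0]},
        dims=("kp", "eV"),
    )
    points = (0.1, -0.2)  # one value per dimension, in dimension order

    # Old (buggy) order: keys are the point values, values are the dim names.
    dict(zip(points, arr.dims, strict=True))   # {0.1: 'kp', -0.2: 'eV'}

    # New order: keys are the dim names, values are the point values,
    # which is the mapping the later selection logic expects.
    dict(zip(arr.dims, points, strict=True))   # {'kp': 0.1, 'eV': -0.2}
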
@@ -555,7 +555,7 @@ def select_around_data(

def select_around(
self,
- points: dict[Hashable, float] | xr.DataArray,
+ points: dict[Hashable, float] | xr.Dataset,
radius: dict[Hashable, float] | float,
*,
mode: Literal["sum", "mean"] = "sum",
@@ -593,7 +593,7 @@ def select_around(

if isinstance(points, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
- points = dict(zip(points, self._obj.dims, strict=True))
+ points = dict(zip(self._obj.dims, points, strict=True))
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
@@ -627,8 +627,8 @@ def select_around(
return selected.sum(list(radius.keys()))
return selected.mean(list(radius.keys()))

+ @staticmethod
def _radius(
- self,
points: dict[Hashable, xr.DataArray] | dict[Hashable, float],
radius: float | dict[Hashable, float],
**kwargs: float,
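
For context, a hypothetical call to `select_around` based only on the signatures shown in this diff; the `.S` accessor name, the dimension names, and the assumption that importing `arpes.xarray_extensions` registers the accessor are mine, not taken from this commit:

    import numpy as np
    import xarray as xr

    import arpes.xarray_extensions  # noqa: F401  # assumed to register the accessors

    arr = xr.DataArray(
        np.random.rand(3, 4),
        coords={"kp": [0.0, 0.1, 0.2], "eV": [-0.3, -0.2, -0.1, 0.0]},
        dims=("kp", "eV"),
    )

    # points maps dimension names to coordinate values; radius may be a single
    # float or a per-dimension dict, as the signatures above suggest.
    region = arr.S.select_around(
        points={"kp": 0.1, "eV": -0.2},
        radius={"kp": 0.05, "eV": 0.1},
        mode="mean",
    )
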
@@ -1627,8 +1627,8 @@ def dict_to_html(d: Mapping[str, float | str]) -> str:
rows="".join([f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in d.items()]),
)

+ @staticmethod
def _repr_html_full_coords(
- self,
coords: xr.Coordinates,
) -> str:
significant_coords = {}
@@ -2201,7 +2201,7 @@ def round_coordinates(

return rounded

- def argmax_coords(self) -> dict[Hashable, float]:
+ def argmax_coords(self) -> dict[Hashable, float]: # TODO [RA]: DataArray
"""Return dict representing the position for maximum value."""
assert isinstance(self._obj, xr.DataArray)
data: xr.DataArray = self._obj
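
The docstring above is the whole contract of `argmax_coords`; a rough sketch of the idea (not the accessor's implementation) using plain numpy and xarray:

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(
        np.random.rand(3, 4),
        coords={"kp": [0.0, 0.1, 0.2], "eV": [-0.3, -0.2, -0.1, 0.0]},
        dims=("kp", "eV"),
    )

    # Unravel the flat argmax into per-dimension indices, then read off the
    # corresponding coordinate values as a dict.
    idx = np.unravel_index(arr.values.argmax(), arr.shape)
    position = {dim: arr.coords[dim].values[i] for dim, i in zip(arr.dims, idx)}
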
@@ -2235,7 +2235,7 @@ def apply_over(
data.loc[selections] = transformed
return data

- def to_unit_range(self, percentile: float | None = None) -> XrTypes:
+ def to_unit_range(self, percentile: float | None = None) -> XrTypes: # TODO [RA]: DataArray
assert isinstance(self._obj, xr.DataArray) # to work with np.percentile
if percentile is None:
norm = self._obj - self._obj.min()
@@ -2246,7 +2246,7 @@ def to_unit_range(self, percentile: float | None = None) -> XrTypes:
norm = self._obj - low
return norm / (high - low)

- def drop_nan(self) -> xr.DataArray:
+ def drop_nan(self) -> xr.DataArray: # TODO [RA]: DataArray
assert isinstance(self._obj, xr.DataArray) # ._obj.values
assert len(self._obj.dims) == 1

@@ -2319,7 +2319,9 @@ def transform_coords(

return copied

- def filter_vars(self, f: Callable[[Hashable, xr.DataArray], bool]) -> xr.Dataset:
+ def filter_vars(
+     self, f: Callable[[Hashable, xr.DataArray], bool]
+ ) -> xr.Dataset: # TODO [RA]: Dataset only
assert isinstance(self._obj, xr.Dataset) # ._obj.data_vars
return xr.Dataset(
data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)},
@@ -2355,7 +2357,7 @@ def coordinatize(self, as_coordinate_name: str | None = None) -> XrTypes:

return o

- def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]:
+ def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]: # TODO: [RA] DataArray
"""Converts to a flat representation where the coordinate values are also present.
Extremely valuable for plotting a dataset with coordinates, X, Y and values Z(X,Y)
@@ -2387,7 +2389,7 @@ def meshgrid(
self,
*,
as_dataset: bool = False,
- ) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset:
+ ) -> dict[Hashable, NDArray[np.float_]] | xr.Dataset: # TODO: [RA] DataArray
assert isinstance(self._obj, xr.DataArray) # ._obj.values is used.

dims = self._obj.dims
@@ -2409,7 +2411,7 @@

return meshed_coordinates

- def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]:
+ def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]: # TODO: [RA] DataArray
"""Converts a (1D) `xr.DataArray` into two plain ``ndarray``s of their coordinate and data.
Useful for rapidly converting into a format that can be `plt.scatter`ed
@@ -2428,7 +2430,7 @@ def to_arrays(self) -> tuple[NDArray[np.float_], NDArray[np.float_]]:

return (self._obj.coords[self._obj.dims[0]].values, self._obj.values)

- def clean_outliers(self, clip: float = 0.5) -> xr.DataArray:
+ def clean_outliers(self, clip: float = 0.5) -> xr.DataArray: # TODO: [RA] DataArray
assert isinstance(self._obj, xr.DataArray)
low, high = np.percentile(self._obj.values, [clip, 100 - clip])
copied = self._obj.copy(deep=True)
@@ -2442,7 +2444,7 @@ def as_movie(
pattern: str = "{}.png",
out: str | bool = "",
**kwargs: Unpack[PColorMeshKwargs],
- ) -> Path | animation.FuncAnimation:
+ ) -> Path | animation.FuncAnimation: # TODO: [RA] DataArray
assert isinstance(self._obj, xr.DataArray)

if isinstance(out, bool) and out is True:
@@ -2511,7 +2513,7 @@ def map_axes(
axes: list[str] | str,
fn: Callable[[XrTypes, dict[str, float]], DataType],
dtype: DTypeLike = None,
- ) -> xr.DataArray:
+ ) -> xr.DataArray: # TODO: [RA] DataArray
"""[TODO:summary].
Args:
@@ -2522,13 +2524,9 @@ def map_axes(
Raises:
TypeError: [TODO:description]
"""
- if isinstance(self._obj, xr.Dataset):
-     msg = "map_axes can only work on xr.DataArrays for now because of how the type"
-     msg += " inference works"
-     raise TypeError(
-         msg,
-     )
- assert isinstance(self._obj, xr.DataArray)
+ msg = "map_axes can only work on xr.DataArrays for now because of how the type"
+ msg += " inference works"
+ assert isinstance(self._obj, xr.DataArray), msg
obj = self._obj.copy(deep=True)

if dtype is not None:
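
The change above folds the isinstance check and its error message into a single assert. A minimal sketch of the idiom with a hypothetical helper name; note that, unlike an explicit raise TypeError, asserts are stripped when Python runs with -O:

    import xarray as xr

    def _ensure_dataarray(obj: xr.DataArray | xr.Dataset) -> xr.DataArray:
        # Same idiom as the refactored code: the message travels with the assert.
        msg = "this helper only works on xr.DataArray"
        assert isinstance(obj, xr.DataArray), msg
        return obj
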
@@ -2556,7 +2554,7 @@ def transform(
dtype: DTypeLike = None,
*args: Incomplete,
**kwargs: Incomplete,
- ) -> xr.DataArray:
+ ) -> xr.DataArray: # TODO: DataArray
"""Applies a vectorized operation across a subset of array axes.
Transform has similar semantics to matrix multiplication, the dimensions of the
Expand Down Expand Up @@ -2599,12 +2597,10 @@ def transform(
The data consisting of applying `transform_fn` across the specified axes.
"""
- if isinstance(self._obj, xr.Dataset):
-     msg = "transform can only work on xr.DataArrays for"
-     msg += " now because of how the type inference works"
-     raise TypeError(msg)
+ msg = "transform can only work on xr.DataArrays for"
+ msg += " now because of how the type inference works"

- assert isinstance(self._obj, xr.DataArray)
+ assert isinstance(self._obj, xr.DataArray), msg
dest = None
for coord, value in self.iterate_axis(axes):
new_value = transform_fn(value, coord, *args, **kwargs)
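
From the loop visible above, `transform_fn` is called once per coordinate along the iterated axes as `transform_fn(value, coord, *args, **kwargs)`. A minimal callback sketch under that assumption (how the per-slice results are recombined lies outside this hunk):

    import xarray as xr

    def my_transform(value: xr.DataArray, coord: dict, *args, **kwargs) -> xr.DataArray:
        # `value` is the slice of the data at `coord` along the iterated axes;
        # return a transformed slice for the caller to accumulate.
        return value - value.mean()
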
@@ -2637,7 +2633,7 @@ def map(
self,
fn: Callable[[NDArray[np.float_], Any], NDArray[np.float_]],
**kwargs: Incomplete,
- ) -> xr.DataArray:
+ ) -> xr.DataArray: # TODO: [RA]: DataArray
"""[TODO:summary].
Args:
@@ -2752,7 +2748,7 @@ def shift_by( # noqa: PLR0913
*,
zero_nans: bool = True,
shift_coords: bool = False,
- ) -> xr.DataArray:
+ ) -> xr.DataArray: # TODO: [RA] DataArray
"""Data shift along the axis.
For now we only support shifting by a one dimensional array
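
The docstring above says `shift_by` shifts data along an axis by a one dimensional array of offsets. A conceptual sketch of that operation (not the accessor's implementation) using scipy, with the `zero_nans` flag from the signature mimicked at the end:

    import numpy as np
    from scipy import ndimage

    data = np.random.rand(3, 50)
    shifts = np.array([0.0, 1.5, -2.0])  # one (pixel) shift per row

    # Shift each 1-D slice along the last axis by its own amount, padding with
    # NaN, then zero the NaNs as zero_nans=True suggests.
    shifted = np.stack(
        [ndimage.shift(row, s, order=1, cval=np.nan) for row, s in zip(data, shifts)]
    )
    shifted = np.nan_to_num(shifted)
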
