Skip to content

Commit

Permalink
🔨 Refactoring "select_around_data" and "select_around"
Browse files Browse the repository at this point in the history
    * Drop "fast" argument.
        * The circular selection does not make sence in most case, and it has not been implemented.
        * point -> points

📝  add docstring in G.stride and G.range
  • Loading branch information
arafune committed Oct 20, 2023
1 parent f6f9f79 commit 96687f0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 81 deletions.
163 changes: 83 additions & 80 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@

ANGLE_VARS = ("alpha", "beta", "chi", "psi", "phi", "theta")

DEFAULT_RADII = {
"kp": 0.02,
"kz": 0.05,
"phi": 0.02,
"beta": 0.02,
"theta": 0.02,
"eV": 0.05,
"delay": 0.2,
"T": 2,
"temperature": 2,
}

UNSPESIFIED = 0.1

LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
logger = getLogger(__name__)
Expand Down Expand Up @@ -373,10 +387,9 @@ def transpose_to_back(self, dim: str) -> xr.DataArray | xr.Dataset:

def select_around_data(
self,
points: dict[str, Any] | xr.Dataset,
points: dict[str, float] | xr.Dataset,
radius: dict[str, float] | float | None = None, # radius={"phi": 0.005}
*,
fast: bool = False,
mode: Literal["sum", "mean"] = "sum",
**kwargs: Incomplete,
) -> xr.DataArray:
Expand All @@ -400,8 +413,6 @@ def select_around_data(
points: The set of points where the selection should be performed.
radius: The radius of the selection in each coordinate. If dimensions are omitted, a
standard sized selection will be made as a compromise.
fast: If true, uses a rectangular rather than a circular region for selectioIf true,
uses a rectangular rather than a circular region for selection.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum"
kwargs: Can be used to pass radii parameters by keyword with `_r` postfix.
Expand All @@ -421,20 +432,6 @@ def select_around_data(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}

default_radii = {
"kp": 0.02,
"kz": 0.05,
"phi": 0.02,
"beta": 0.02,
"theta": 0.02,
"eV": 0.05,
"delay": 0.2,
"T": 2,
"temperature": 2,
}

UNSPESIFIED = 0.1

if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
Expand All @@ -443,15 +440,15 @@ def select_around_data(
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", default_radii.get(str(d), UNSPESIFIED))
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): default_radii.get(str(d), UNSPESIFIED) for d in points}
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}

assert isinstance(radius, dict)
radius = {
str(d): radius.get(str(d), default_radii.get(str(d), UNSPESIFIED)) for d in points
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}

along_dims = next(iter(points.values())).dims
Expand All @@ -473,18 +470,15 @@ def select_around_data(
radius = {d: v for d, v in radius.items() if d not in nearest_sel_params}
# -- to heari, but as name said, should be alwayws safe.

if fast:
selection_slices = {
d: slice(
points[d].sel(**coord) - radius[d],
points[d].sel(**coord) + radius[d],
)
for d in points
if d in radius
}
selected = value.sel(**selection_slices)
else:
raise NotImplementedError
selection_slices = {
d: slice(
points[d].sel(**coord) - radius[d],
points[d].sel(**coord) + radius[d],
)
for d in points
if d in radius
}
selected = value.sel(**selection_slices)

if nearest_sel_params:
selected = selected.sel(**nearest_sel_params, method="nearest")
Expand All @@ -502,10 +496,9 @@ def select_around_data(

def select_around(
self,
point: dict[str, Any] | xr.Dataset,
points: dict[str, float] | xr.Dataset,
radius: dict[str, float] | float | None = None,
*,
fast: bool = False,
mode: Literal["sum", "mean"] = "sum",
**kwargs: Incomplete,
) -> xr.DataArray:
Expand All @@ -522,11 +515,9 @@ def select_around(
then we will try to use reasonable default values; buyer beware.
Args:
point: The points where the selection should be performed.
points: The points where the selection should be performed.
radius: The radius of the selection in each coordinate. If dimensions are omitted, a
standard sized selection will be made as a compromise.
fast: If true, uses a rectangular rather than a circular region for selectioIf true,
uses a rectangular rather than a circular region for selection.
safe: If true, infills radii with default values. Defaults to `True`.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum"
**kwargs: Can be used to pass radii parameters by keyword with `_r` postfix.
Expand All @@ -538,62 +529,50 @@ def select_around(
self._obj,
xr.DataArray,
), "Cannot use select_around on Datasets only DataArrays!"
assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean."
if isinstance(point, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
point = dict(zip(point, self._obj.dims, strict=True))
if isinstance(point, xr.Dataset):
point = {k: point[k].item() for k in point.data_vars}

default_radii = {
"kp": 0.02,
"kz": 0.05,
"phi": 0.02,
"beta": 0.02,
"theta": 0.02,
"eV": 0.05,
"delay": 0.2,
"T": 2,
}

unspecified = 0.1
assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean."

if isinstance(points, tuple | list):
warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2)
points = dict(zip(points, self._obj.dims, strict=True))
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
if isinstance(radius, float):
radius = {str(d): radius for d in point}
radius = {str(d): radius for d in points}
else:
collected_terms = {f"{k}_r" for k in point}.intersection(
collected_terms = {f"{k}_r" for k in points}.intersection(
set(kwargs.keys()),
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", default_radii.get(str(d), unspecified))
for d in point
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.Get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): default_radii.get(str(d), unspecified) for d in point}
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}

assert isinstance(radius, dict)
radius = {str(d): radius.get(str(d), default_radii.get(str(d), unspecified)) for d in point}
radius = {
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}

# make sure we are taking at least one pixel along each
logger.debug(f"radius: {radius}")
nearest_sel_params = {}

# -- originally, if safe == True, the following liens starting from hear
stride = self._obj.G.stride(generic_dim_names=False)
for d, v in radius.items():
if v < stride[d]:
nearest_sel_params[d] = point[d]
nearest_sel_params[d] = points[d]

radius = {d: v for d, v in radius.items() if d not in nearest_sel_params}
# -- to heari, but as name said, should be alwayws safe.

if fast:
selection_slices = {
d: slice(point[d] - radius[d], point[d] + radius[d]) for d in point if d in radius
}
selected = self._obj.sel(**selection_slices)
else:
raise NotImplementedError
selection_slices = {
d: slice(points[d] - radius[d], points[d] + radius[d]) for d in points if d in radius
}
selected = self._obj.sel(**selection_slices)

if nearest_sel_params:
selected = selected.sel(**nearest_sel_params, method="nearest")
Expand All @@ -604,9 +583,7 @@ def select_around(

if mode == "sum":
return selected.sum(list(radius.keys()))
if mode == "mean":
return selected.mean(list(radius.keys()))
raise RuntimeError
return selected.mean(list(radius.keys()))

def short_history(self, key: str = "by") -> list:
"""Return the short version of history.
Expand Down Expand Up @@ -2853,7 +2830,19 @@ def iter_coords(
for ts in itertools.product(*[self._obj.coords[d].values for d in dim_names]):
yield dict(zip(dim_names, ts, strict=True))

def range(self, *, generic_dim_names: bool = True) -> dict:
def range(
self,
*,
generic_dim_names: bool = True,
) -> dict[str, tuple[float, float]]:
"""Return the maximum/minimum value in each dimension.
Args:
generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used.
Returns: (dict[str, tuple[float, float]])
The range of each dimension.
"""
assert isinstance(self._obj, xr.DataArray | xr.Dataset)
indexed_coords = [self._obj.coords[d] for d in self._obj.dims]
indexed_ranges = [(np.min(coord.values), np.max(coord.values)) for coord in indexed_coords]
Expand All @@ -2864,7 +2853,24 @@ def range(self, *, generic_dim_names: bool = True) -> dict:

return dict(zip(dim_names, indexed_ranges, strict=True))

def stride(self, *args: Incomplete, generic_dim_names: bool = True) -> dict | list:
def stride(
self,
*args: str | list[str] | tuple[str, ...],
generic_dim_names: bool = True,
) -> dict[str, float] | list[float] | float:
"""Return the stride in each dimension.
Note that the stride defined in this method is just a difference between first two values.
In most case, this treatment does not cause a problem. However, when the data has been
concatenated, this assumption may not be not valid.
Args:
args: The dimension to return. ["eV", "phi"] or "eV", "phi"
generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used.
Returns:
The stride of each dimension
"""
assert isinstance(self._obj, xr.DataArray | xr.Dataset)
indexed_coords: list[xr.DataArray] = [self._obj.coords[d] for d in self._obj.dims]
indexed_strides: list[float] = [
Expand All @@ -2876,14 +2882,11 @@ def stride(self, *args: Incomplete, generic_dim_names: bool = True) -> dict | li
dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)]

result = dict(zip(dim_names, indexed_strides, strict=True))

if args:
if len(args) == 1:
if not isinstance(args[0], str):
# if passed list of strs as argument
if not isinstance(args[0], str): # suppose args is list / tuple
result = [result[selected_names] for selected_names in args[0]]
else:
# if passed single name as argument
result = result[args[0]]
else:
# if passed several names as arguments
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/custom-dot-s-functionality.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@
"# take the Lorentzian (component `b`) center parameter\n",
"phi_values = phis.F.p(\"b_center\")\n",
"temp_dep.spectrum.S.select_around_data(\n",
" {\"phi\": phi_values}, mode=\"mean\", fast=True, radius={\"phi\": 0.005}\n",
" {\"phi\": phi_values}, mode=\"mean\", radius={\"phi\": 0.005}\n",
").S.plot()"
]
},
Expand Down

0 comments on commit 96687f0

Please sign in to comment.