Skip to content

Commit

Permalink
🔨 Add a helper function for select_around_data to refactor it
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Oct 22, 2023
1 parent 1d5b593 commit 857292c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 54 deletions.
22 changes: 10 additions & 12 deletions arpes/analysis/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = (
"discretize_path",
"select_along_path",
Expand Down Expand Up @@ -51,10 +49,10 @@ def discretize_path(
if isinstance(scaling, dict):
scaling = np.array(scaling[d] for d in order)

def as_vec(ds):
def as_vec(ds: xr.Dataset) -> NDArray[np.float_]:
return np.array([ds[k].item() for k in order])

def distance(a, b):
def distance(a: xr.Dataset, b: xr.Dataset) -> float:
return np.linalg.norm((as_vec(a) - as_vec(b)) * scaling)

length = 0
Expand Down Expand Up @@ -98,14 +96,13 @@ def to_dataarray(name: str) -> xr.DataArray:
@update_provenance("Select from data along a path")
def select_along_path(
path: xr.Dataset,
data: DataType,
data: xr.DataArray,
radius: float = 0,
n_points: int = 0,
*,
fast: bool = True,
scaling: float | xr.Dataset | dict[str, NDArray[np.float_]] | None = None,
**kwargs: Incomplete,
) -> DataType:
) -> xr.DataArray:
"""Performs integration along a path.
This functionally allows for performing a finite width
Expand All @@ -117,11 +114,12 @@ def select_along_path(
path: The path to select along.
data: The data to select/interpolate from.
radius: A number or dictionary of radii to use for the selection along different dimensions,
if none is provided reasonable values will be chosen. Alternatively, you can pass radii
via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r'
if none is provided reasonable values will be chosen. Alternatively, you can pass
radii via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r'
n_points: The number of points to interpolate along the path, by default we will infer a
reasonable number from the radius parameter, if provided or inferred
fast: If fast is true, will use rectangular selections rather than ellipsoid ones
reasonable number from the radius parameter, if provided or inferred
scaling:
kwargs:
Returns:
The data selected along the path.
Expand All @@ -130,6 +128,6 @@ def select_along_path(

selections = []
for _, view in new_path.G.iterate_axis("index"):
selections.append(data.S.select_around(view, radius=radius, fast=fast, **kwargs))
selections.append(data.S.select_around(view, radius=radius, **kwargs))

return xr.concat(selections, new_path.index)
2 changes: 1 addition & 1 deletion arpes/models/band.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Rudimentary band analyis code."""
"""Rudimentary band analysis code."""
from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
4 changes: 2 additions & 2 deletions arpes/plotting/path_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def click_main_image(event):
("Path", "path"),
]

def convert_to_xarray():
def convert_to_xarray() -> dict[str, xr.dataset]:
"""Creates a Dataset consisting of one array for each path.
For each of the paths, we will create a dataset which has an index dimension,
and datavariables for each of the coordinate dimensions
"""

def convert_single_path_to_xarray(points):
def convert_single_path_to_xarray(points) -> xr.Dataset:
vars = {d: np.array([p[i] for p in points]) for i, d in enumerate(self.arr.dims)}
coords = {
"index": np.array(range(len(points))),
Expand Down
42 changes: 3 additions & 39 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,7 @@ def select_around_data(
around the Fermi momentum.
Args:
points: The set of points where the selection should be performed. If points provided
as xr.Dataset, the Dataset is converted to {"data_vars": values}
points: The set of points where the selection should be performed.
radius: The radius of the selection in each coordinate. If dimensions are omitted, a
standard sized selection will be made as a compromise.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum"
Expand All @@ -433,24 +432,7 @@ def select_around_data(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}

if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
collected_terms = {f"{k}_r" for k in points}.intersection(
set(kwargs.keys()),
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}

assert isinstance(radius, dict)
radius = {
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}
radius = self._radius(points, radius, **kwargs)

logger.debug(f"iter(points.values()): {iter(points.values())}")

Expand Down Expand Up @@ -541,25 +523,7 @@ def select_around(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
if isinstance(radius, float):
radius = {str(d): radius for d in points}
else:
collected_terms = {f"{k}_r" for k in points}.intersection(
set(kwargs.keys()),
)
if collected_terms:
radius = {
str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED))
for d in points
}
elif radius is None:
radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points}

assert isinstance(radius, dict)
radius = {
str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points
}

radius = self._radius(points, radius, **kwargs)
logger.debug(f"radius: {radius}")
nearest_sel_params = {}

Expand Down

0 comments on commit 857292c

Please sign in to comment.