From 857292c7592f188b6ba8c357422a88553e06827d Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sun, 22 Oct 2023 16:55:53 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20=20Add=20a=20helper=20function?= =?UTF-8?q?=20for=20select=5Faround=5Fdata=20to=20refactor=20it?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/analysis/path.py | 22 +++++++++---------- arpes/models/band.py | 2 +- arpes/plotting/path_tool.py | 4 ++-- arpes/xarray_extensions.py | 42 +++---------------------------------- 4 files changed, 16 insertions(+), 54 deletions(-) diff --git a/arpes/analysis/path.py b/arpes/analysis/path.py index baad16af..5d86c8bc 100644 --- a/arpes/analysis/path.py +++ b/arpes/analysis/path.py @@ -12,8 +12,6 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import DataType - __all__ = ( "discretize_path", "select_along_path", @@ -51,10 +49,10 @@ def discretize_path( if isinstance(scaling, dict): scaling = np.array(scaling[d] for d in order) - def as_vec(ds): + def as_vec(ds: xr.Dataset) -> NDArray[np.float_]: return np.array([ds[k].item() for k in order]) - def distance(a, b): + def distance(a: xr.Dataset, b: xr.Dataset) -> float: return np.linalg.norm((as_vec(a) - as_vec(b)) * scaling) length = 0 @@ -98,14 +96,13 @@ def to_dataarray(name: str) -> xr.DataArray: @update_provenance("Select from data along a path") def select_along_path( path: xr.Dataset, - data: DataType, + data: xr.DataArray, radius: float = 0, n_points: int = 0, *, - fast: bool = True, scaling: float | xr.Dataset | dict[str, NDArray[np.float_]] | None = None, **kwargs: Incomplete, -) -> DataType: +) -> xr.DataArray: """Performs integration along a path. This functionally allows for performing a finite width @@ -117,11 +114,12 @@ def select_along_path( path: The path to select along. data: The data to select/interpolate from. radius: A number or dictionary of radii to use for the selection along different dimensions, - if none is provided reasonable values will be chosen. Alternatively, you can pass radii - via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r' + if none is provided reasonable values will be chosen. Alternatively, you can pass + radii via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r' n_points: The number of points to interpolate along the path, by default we will infer a - reasonable number from the radius parameter, if provided or inferred - fast: If fast is true, will use rectangular selections rather than ellipsoid ones + reasonable number from the radius parameter, if provided or inferred + scaling: + kwargs: Returns: The data selected along the path. @@ -130,6 +128,6 @@ def select_along_path( selections = [] for _, view in new_path.G.iterate_axis("index"): - selections.append(data.S.select_around(view, radius=radius, fast=fast, **kwargs)) + selections.append(data.S.select_around(view, radius=radius, **kwargs)) return xr.concat(selections, new_path.index) diff --git a/arpes/models/band.py b/arpes/models/band.py index de969415..57769212 100644 --- a/arpes/models/band.py +++ b/arpes/models/band.py @@ -1,4 +1,4 @@ -"""Rudimentary band analyis code.""" +"""Rudimentary band analysis code.""" from __future__ import annotations from typing import TYPE_CHECKING diff --git a/arpes/plotting/path_tool.py b/arpes/plotting/path_tool.py index 3bec2a24..bf234603 100644 --- a/arpes/plotting/path_tool.py +++ b/arpes/plotting/path_tool.py @@ -134,14 +134,14 @@ def click_main_image(event): ("Path", "path"), ] - def convert_to_xarray(): + def convert_to_xarray() -> dict[str, xr.dataset]: """Creates a Dataset consisting of one array for each path. For each of the paths, we will create a dataset which has an index dimension, and datavariables for each of the coordinate dimensions """ - def convert_single_path_to_xarray(points): + def convert_single_path_to_xarray(points) -> xr.Dataset: vars = {d: np.array([p[i] for p in points]) for i, d in enumerate(self.arr.dims)} coords = { "index": np.array(range(len(points))), diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index e8ea0aec..bcb783e9 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -410,8 +410,7 @@ def select_around_data( around the Fermi momentum. Args: - points: The set of points where the selection should be performed. If points provided - as xr.Dataset, the Dataset is converted to {"data_vars": values} + points: The set of points where the selection should be performed. radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized selection will be made as a compromise. mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum" @@ -433,24 +432,7 @@ def select_around_data( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} - if isinstance(radius, float): - radius = {str(d): radius for d in points} - else: - collected_terms = {f"{k}_r" for k in points}.intersection( - set(kwargs.keys()), - ) - if collected_terms: - radius = { - str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) - for d in points - } - elif radius is None: - radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} - - assert isinstance(radius, dict) - radius = { - str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points - } + radius = self._radius(points, radius, **kwargs) logger.debug(f"iter(points.values()): {iter(points.values())}") @@ -541,25 +523,7 @@ def select_around( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} logger.debug(f"points: {points}") - if isinstance(radius, float): - radius = {str(d): radius for d in points} - else: - collected_terms = {f"{k}_r" for k in points}.intersection( - set(kwargs.keys()), - ) - if collected_terms: - radius = { - str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) - for d in points - } - elif radius is None: - radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} - - assert isinstance(radius, dict) - radius = { - str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points - } - + radius = self._radius(points, radius, **kwargs) logger.debug(f"radius: {radius}") nearest_sel_params = {}