From 744482deb2ce03772f0efb024409d8fc903c221e Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 20 Oct 2023 16:08:18 +0900 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=94=A8=20=20add=20unit=20test=20for?= =?UTF-8?q?=20select=5Faround?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_xarray_extensions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_xarray_extensions.py b/tests/test_xarray_extensions.py index faf64969..f0afd69c 100644 --- a/tests/test_xarray_extensions.py +++ b/tests/test_xarray_extensions.py @@ -113,6 +113,25 @@ def test_short_history(self, dataarray_cut: xr.DataArray) -> None: assert history[1] == "filesystem" +def test_select_around(dataarray_cut: xr.DataArray) -> None: + """Test for select_around.""" + data_1 = dataarray_cut.S.select_around(points={"phi": 0.30}, radius={"phi": 0.05}).values + data_2 = dataarray_cut.sel(phi=slice(0.25, 0.35)).sum("phi").values + np.testing.assert_almost_equal(data_1, data_2) + # + data_1 = dataarray_cut.S.select_around( + points={"phi": 0.30}, + radius={"phi": 0.05}, + mode="mean", + ).values + data_2 = dataarray_cut.sel(phi=slice(0.25, 0.35)).mean("phi").values + np.testing.assert_almost_equal(data_1, data_2) + # + data_1 = dataarray_cut.S.select_around(points={"phi": 0.30}, radius={"phi": 0.000001}).values + data_2 = dataarray_cut.sel(phi=0.3, method="nearest").values + np.testing.assert_almost_equal(data_1, data_2) + + def test_find(dataarray_cut: xr.DataArray) -> None: """Test for S.find.""" assert sorted(dataarray_cut.S.find("offset")) == sorted( From 1d5b593cf9782a309e5835a4c840a383d45c2d1f Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 20 Oct 2023 16:39:49 +0900 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=94=A8=20=20add=20helper=20function:?= =?UTF-8?q?=20=5Fradius?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/xarray_extensions.py | 41 
+++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 2aaf4885..e8ea0aec 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -387,7 +387,7 @@ def transpose_to_back(self, dim: str) -> xr.DataArray | xr.Dataset: def select_around_data( self, - points: dict[str, float] | xr.Dataset, + points: dict[str, xr.DataArray] | xr.Dataset, radius: dict[str, float] | float | None = None, # radius={"phi": 0.005} *, mode: Literal["sum", "mean"] = "sum", @@ -410,7 +410,8 @@ def select_around_data( around the Fermi momentum. Args: - points: The set of points where the selection should be performed. + points: The set of points where the selection should be performed. If points provided + as xr.Dataset, the Dataset is converted to {"data_vars": values} radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized selection will be made as a compromise. mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum" @@ -451,6 +452,8 @@ def select_around_data( str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points } + logger.debug(f"iter(points.values()): {iter(points.values())}") + along_dims = next(iter(points.values())).dims selected_dims = list(points.keys()) @@ -546,7 +549,7 @@ def select_around( ) if collected_terms: radius = { - str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.Get(str(d), UNSPESIFIED)) + str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points } elif radius is None: @@ -585,6 +588,38 @@ def select_around( return selected.sum(list(radius.keys())) return selected.mean(list(radius.keys())) + def _radius( + self, + points: dict[str, float], + radius: float | dict[str, float] | None, + **kwargs: float, + ) -> dict[str, float]: + """Helper function. Generate radius dict. 
+ + When radius is a dict and no ``{dim}_r`` kwargs are given, only missing dims get defaults. + + Args: + points (dict[str, float]): Selection point + radius (dict[str, float] | float | None): radius + kwargs (float): Per-dimension radii given as ``{dim}_r`` keywords, e.g. ``phi_r`` + + Returns: dict[str, float] + radius for selection. + """ + if isinstance(radius, int | float): + radius = {str(d): radius for d in points} + else: + collected_terms = {f"{k}_r" for k in points}.intersection(set(kwargs.keys())) + if collected_terms: + radius = { + str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) + for d in points + } + elif radius is None: + radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} + assert isinstance(radius, dict) + return {str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points} + def short_history(self, key: str = "by") -> list: """Return the short version of history. From 857292c7592f188b6ba8c357422a88553e06827d Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sun, 22 Oct 2023 16:55:53 +0900 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=94=A8=20=20Add=20a=20helper=20functi?= =?UTF-8?q?on=20for=20select=5Faround=5Fdata=20to=20refactor=20it?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/analysis/path.py | 22 +++++++++---------- arpes/models/band.py | 2 +- arpes/plotting/path_tool.py | 4 ++-- arpes/xarray_extensions.py | 42 +++---------------------------------- 4 files changed, 16 insertions(+), 54 deletions(-) diff --git a/arpes/analysis/path.py b/arpes/analysis/path.py index baad16af..5d86c8bc 100644 --- a/arpes/analysis/path.py +++ b/arpes/analysis/path.py @@ -12,8 +12,6 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import DataType - __all__ = ( "discretize_path", "select_along_path", @@ -51,10 +49,10 @@ def discretize_path( if isinstance(scaling, dict): scaling = np.array(scaling[d] for d in order) - def as_vec(ds): + def as_vec(ds: xr.Dataset) -> NDArray[np.float_]:
return np.array([ds[k].item() for k in order]) - def distance(a, b): + def distance(a: xr.Dataset, b: xr.Dataset) -> float: return np.linalg.norm((as_vec(a) - as_vec(b)) * scaling) length = 0 @@ -98,14 +96,13 @@ def to_dataarray(name: str) -> xr.DataArray: @update_provenance("Select from data along a path") def select_along_path( path: xr.Dataset, - data: DataType, + data: xr.DataArray, radius: float = 0, n_points: int = 0, *, - fast: bool = True, scaling: float | xr.Dataset | dict[str, NDArray[np.float_]] | None = None, **kwargs: Incomplete, -) -> DataType: +) -> xr.DataArray: """Performs integration along a path. This functionally allows for performing a finite width @@ -117,11 +114,12 @@ def select_along_path( path: The path to select along. data: The data to select/interpolate from. radius: A number or dictionary of radii to use for the selection along different dimensions, - if none is provided reasonable values will be chosen. Alternatively, you can pass radii - via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r' + if none is provided reasonable values will be chosen. Alternatively, you can pass + radii via `{dim}_r` kwargs as well, i.e. 'eV_r' or 'kp_r' n_points: The number of points to interpolate along the path, by default we will infer a - reasonable number from the radius parameter, if provided or inferred - fast: If fast is true, will use rectangular selections rather than ellipsoid ones + reasonable number from the radius parameter, if provided or inferred + scaling: + kwargs: Returns: The data selected along the path. 
@@ -130,6 +128,6 @@ def select_along_path( selections = [] for _, view in new_path.G.iterate_axis("index"): - selections.append(data.S.select_around(view, radius=radius, fast=fast, **kwargs)) + selections.append(data.S.select_around(view, radius=radius, **kwargs)) return xr.concat(selections, new_path.index) diff --git a/arpes/models/band.py b/arpes/models/band.py index de969415..57769212 100644 --- a/arpes/models/band.py +++ b/arpes/models/band.py @@ -1,4 +1,4 @@ -"""Rudimentary band analyis code.""" +"""Rudimentary band analysis code.""" from __future__ import annotations from typing import TYPE_CHECKING diff --git a/arpes/plotting/path_tool.py b/arpes/plotting/path_tool.py index 3bec2a24..bf234603 100644 --- a/arpes/plotting/path_tool.py +++ b/arpes/plotting/path_tool.py @@ -134,14 +134,14 @@ def click_main_image(event): ("Path", "path"), ] - def convert_to_xarray(): + def convert_to_xarray() -> dict[str, xr.Dataset]: """Creates a Dataset consisting of one array for each path. For each of the paths, we will create a dataset which has an index dimension, and datavariables for each of the coordinate dimensions """ - def convert_single_path_to_xarray(points): + def convert_single_path_to_xarray(points) -> xr.Dataset: vars = {d: np.array([p[i] for p in points]) for i, d in enumerate(self.arr.dims)} coords = { "index": np.array(range(len(points))), diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index e8ea0aec..bcb783e9 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -410,8 +410,7 @@ def select_around_data( around the Fermi momentum. Args: - points: The set of points where the selection should be performed. If points provided - as xr.Dataset, the Dataset is converted to {"data_vars": values} + points: The set of points where the selection should be performed. radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized selection will be made as a compromise.
mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum" @@ -433,24 +432,7 @@ def select_around_data( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} - if isinstance(radius, float): - radius = {str(d): radius for d in points} - else: - collected_terms = {f"{k}_r" for k in points}.intersection( - set(kwargs.keys()), - ) - if collected_terms: - radius = { - str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) - for d in points - } - elif radius is None: - radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} - - assert isinstance(radius, dict) - radius = { - str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points - } + radius = self._radius(points, radius, **kwargs) logger.debug(f"iter(points.values()): {iter(points.values())}") @@ -541,25 +523,7 @@ def select_around( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} logger.debug(f"points: {points}") - if isinstance(radius, float): - radius = {str(d): radius for d in points} - else: - collected_terms = {f"{k}_r" for k in points}.intersection( - set(kwargs.keys()), - ) - if collected_terms: - radius = { - str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) - for d in points - } - elif radius is None: - radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} - - assert isinstance(radius, dict) - radius = { - str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points - } - + radius = self._radius(points, radius, **kwargs) logger.debug(f"radius: {radius}") nearest_sel_params = {} From d5c62dd573d5a7d439f05673327d4ca87e2a2836 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Mon, 23 Oct 2023 09:12:30 +0900 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=94=A8=20=20A=20refactoriing=20?= =?UTF-8?q?=F0=9F=92=AC=20=20update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- arpes/bootstrap.py | 9 ++++++-- arpes/io.py | 30 +++++++++++++++++++------- arpes/preparation/axis_preparation.py | 8 +++---- arpes/preparation/coord_preparation.py | 6 ++++-- arpes/preparation/tof_preparation.py | 6 ++---- arpes/utilities/geometry.py | 13 ++++++++--- arpes/widgets.py | 25 ++++++--------------- 7 files changed, 55 insertions(+), 42 deletions(-) diff --git a/arpes/bootstrap.py b/arpes/bootstrap.py index 9c04498e..5238e4d2 100644 --- a/arpes/bootstrap.py +++ b/arpes/bootstrap.py @@ -242,7 +242,7 @@ def from_param(cls: type, model_param: lf.Model.Parameter): return cls(center=model_param.value, stderr=model_param.stderr) -def propagate_errors(f) -> Callable: +def propagate_errors(f: Callable) -> Callable: """A decorator which provides transparent propagation of statistical errors. The way that this is accommodated is that the inner function is turned into one which @@ -332,7 +332,12 @@ def bootstrap( elif resample_method == "cycle": resample_fn = resample_cycle - def bootstrapped(*args, n: int = 20, prior_adjustment=1, **kwargs: Incomplete): + def bootstrapped( + *args, + n: int = 20, + prior_adjustment: int = 1, + **kwargs: Incomplete, + ): # examine args to determine which to resample resample_indices = [ i diff --git a/arpes/io.py b/arpes/io.py index b5bd5e9c..57bed714 100644 --- a/arpes/io.py +++ b/arpes/io.py @@ -172,20 +172,14 @@ def stitch( Returns: The concatenated data. 
""" - list_of_files = None - if isinstance(df_or_list, pd.DataFrame): - list_of_files = list(df_or_list.index) - else: - if not isinstance(df_or_list, list | tuple): - msg = "Expected an interable for a list of the scans to stitch together" - raise TypeError(msg) - list_of_files = list(df_or_list) + list_of_files = _df_or_list_to_files(df_or_list) if not built_axis_name: assert isinstance(attr_or_axis, str) built_axis_name = attr_or_axis if not list_of_files: msg = "Must supply at least one file to stitch" raise ValueError(msg) + # loaded = [ f if isinstance(f, xr.DataArray | xr.Dataset) else load_data(f) for f in list_of_files ] @@ -221,6 +215,26 @@ def stitch( return concatenated +def _df_or_list_to_files( + df_or_list: list[str] | pd.DataFrame, +) -> list[str]: + """Helper function for stitch. + + Args: + df_or_list(pd.DataFrame, list): input data file + + Returns: (list[str]) + list of files to stitch. + """ + if isinstance(df_or_list, pd.DataFrame): + return list(df_or_list.index) + assert isinstance( + df_or_list, + list | tuple, + ), "Expected an iterable for a list of the scans to stitch together" + return list(df_or_list) + + def file_for_pickle(name: str) -> Path | str: here = Path() from arpes.config import CONFIG diff --git a/arpes/preparation/axis_preparation.py b/arpes/preparation/axis_preparation.py index 4550744a..71b99993 100644 --- a/arpes/preparation/axis_preparation.py +++ b/arpes/preparation/axis_preparation.py @@ -182,8 +182,7 @@ def normalize(arr: xr.Dataset | xr.DataArray) -> xr.DataArray: def transform_dataarray_axis( func: Callable[..., ...], - old_axis_name: str, - new_axis_name: str, + old_and_new_axis_names: tuple[str, str], new_axis: NDArray[np.float_] | xr.DataArray, dataset: xr.Dataset, prep_name: Callable[[str], str], @@ -195,14 +194,15 @@ def transform_dataarray_axis( Args: func ([TODO:type]): [TODO:description] - old_axis_name(str): [TODO:description] - new_axis_name(str): [TODO:description] + old_and_new_axis_names
(tuple[str, str]) : old and new axis names as the tuple form new_axis ([TODO:type]): [TODO:description] dataset(xr.Dataset): [TODO:description] prep_name ([TODO:type]): [TODO:description] transform_spectra ([TODO:type]): [TODO:description] remove_old ([TODO:type]): [TODO:description] """ + old_axis_name, new_axis_name = old_and_new_axis_names + ds = dataset.copy() if transform_spectra is None: # transform *all* DataArrays in the dataset that have old_axis_name in their dimensions. diff --git a/arpes/preparation/coord_preparation.py b/arpes/preparation/coord_preparation.py index 259e8725..36a234ed 100644 --- a/arpes/preparation/coord_preparation.py +++ b/arpes/preparation/coord_preparation.py @@ -8,6 +8,8 @@ import numpy as np if TYPE_CHECKING: + from collections.abc import Sequence + import xarray as xr __all__ = ["disambiguate_coordinates"] @@ -15,8 +17,8 @@ def disambiguate_coordinates( datasets: xr.Dataset, - possibly_clashing_coordinates, -): + possibly_clashing_coordinates: Sequence[str], +) -> list[xr.DataArray]: """Finds and unifies duplicated coordinates or ambiguous coordinates. 
This is useful if two regions claim to have an energy axis, but one is a core level diff --git a/arpes/preparation/tof_preparation.py b/arpes/preparation/tof_preparation.py index 0a119c25..9b513c09 100644 --- a/arpes/preparation/tof_preparation.py +++ b/arpes/preparation/tof_preparation.py @@ -245,8 +245,7 @@ def process_SToF(dataset: xr.Dataset) -> xr.Dataset: dataset = transform_dataarray_axis( build_KE_coords_to_time_coords(dataset, ke_axis), - "time", - "eV", + ("time", "eV"), ke_axis, dataset, lambda x: x, @@ -276,8 +275,7 @@ def process_DLD(dataset: xr.Dataset) -> xr.Dataset: ) return transform_dataarray_axis( build_KE_coords_to_time_pixel_coords(dataset, ke_axis), - "t_pixels", - "kinetic", + ("t_pixels", "kinetic"), ke_axis, dataset, lambda: "kinetic_spectrum", diff --git a/arpes/utilities/geometry.py b/arpes/utilities/geometry.py index c7b7ecd0..9ec7f8d3 100644 --- a/arpes/utilities/geometry.py +++ b/arpes/utilities/geometry.py @@ -73,7 +73,12 @@ def segment_contains_point( return 0 - epsilon < delta.dot(delta_p) / delta.dot(delta) < 1 + epsilon -def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon: float = 1e-6): +def polyhedron_intersect_plane( + poly_faces: list[NDArray[np.float_]], + plane_normal: NDArray[np.float_], + plane_point: NDArray[np.float_], + epsilon: float = 1e-6, +) -> ConvexHull: """Determines the intersection of a convex polyhedron intersecting a plane. 
The polyhedron faces should be given by a list of np.arrays, where each np.array at @@ -94,7 +99,7 @@ def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon: f """ collected_points = [] - def add_point(c): + def add_point(c: NDArray[np.float_]) -> None: already_collected = False for other in collected_points: delta = c - other @@ -106,7 +111,9 @@ def add_point(c): collected_points.append(c) for poly_face in poly_faces: - segments = list(zip(poly_face, np.concatenate([poly_face[1:], [poly_face[0]]]))) + segments = list( + zip(poly_face, np.concatenate([poly_face[1:], [poly_face[0]]]), strict=True), + ) for a, b in segments: intersection = point_plane_intersection( plane_normal, diff --git a/arpes/widgets.py b/arpes/widgets.py index 3e568942..00d9e0a2 100644 --- a/arpes/widgets.py +++ b/arpes/widgets.py @@ -29,6 +29,7 @@ import itertools import pathlib +import pprint import warnings from collections.abc import Sequence from functools import wraps @@ -38,6 +39,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +import pyperclip import xarray as xr from matplotlib import gridspec from matplotlib.axes import Axes @@ -340,7 +342,7 @@ def mask_cmap(self) -> Colormap: return self._mask_cmap @property - def mask(self): + def mask(self): # noqa: ANN202 return self._mask @mask.setter @@ -468,10 +470,6 @@ def on_add_new_peak(selection) -> None: ctx["data"] = data def on_copy_settings(event: MouseEvent) -> None: - import pprint - - import pyperclip - pyperclip.copy(pprint.pformat(compute_parameters())) copy_settings_button = Button(ax_test, "Copy Settings") @@ -489,7 +487,7 @@ def pca_explorer( *, transpose_mask: bool = False, ) -> CURRENTCONTEXT: - """A tool providing PCA decomposition exploration of a dataset. + """A tool providing PCA (Principal component analysis) decomposition exploration of a dataset. Args: pca: The decomposition of the data, the output of an sklearn PCA decomp. 
@@ -517,7 +515,7 @@ def pca_explorer( } arpes.config.CONFIG["CURRENT_CONTEXT"] = context - def compute_for_scatter(): + def compute_for_scatter() -> tuple[xr.DataArray | xr.Dataset, int]: for_scatter = pca.copy(deep=True).isel( **dict([[component_dim, context["selected_components"]]]), ) @@ -716,18 +714,7 @@ def compute_offsets() -> dict[str, float]: return {k: v.val for k, v in sliders.items()} def on_copy_settings(event: MouseEvent) -> None: - try: - import pprint - - import pyperclip - - pyperclip.copy(pprint.pformat(compute_offsets())) - except ImportError: - pass - finally: - import pprint - - print(pprint.pformat(compute_offsets())) + pyperclip.copy(pprint.pformat(compute_offsets())) def apply_offsets(event: MouseEvent) -> None: for name, offset in compute_offsets().items():