diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 1ffa1334..2aaf4885 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -108,6 +108,20 @@ ANGLE_VARS = ("alpha", "beta", "chi", "psi", "phi", "theta") +DEFAULT_RADII = { + "kp": 0.02, + "kz": 0.05, + "phi": 0.02, + "beta": 0.02, + "theta": 0.02, + "eV": 0.05, + "delay": 0.2, + "T": 2, + "temperature": 2, +} + +UNSPESIFIED = 0.1 + LOGLEVELS = (DEBUG, INFO) LOGLEVEL = LOGLEVELS[1] logger = getLogger(__name__) @@ -373,10 +387,9 @@ def transpose_to_back(self, dim: str) -> xr.DataArray | xr.Dataset: def select_around_data( self, - points: dict[str, Any] | xr.Dataset, + points: dict[str, xr.DataArray] | xr.Dataset, radius: dict[str, float] | float | None = None, # radius={"phi": 0.005} *, - fast: bool = False, mode: Literal["sum", "mean"] = "sum", **kwargs: Incomplete, ) -> xr.DataArray: @@ -400,8 +413,6 @@ def select_around_data( points: The set of points where the selection should be performed. radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized selection will be made as a compromise. - fast: If true, uses a rectangular rather than a circular region for selectioIf true, - uses a rectangular rather than a circular region for selection. mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum" kwargs: Can be used to pass radii parameters by keyword with `_r` postfix. 
@@ -421,20 +432,6 @@ def select_around_data( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} - default_radii = { - "kp": 0.02, - "kz": 0.05, - "phi": 0.02, - "beta": 0.02, - "theta": 0.02, - "eV": 0.05, - "delay": 0.2, - "T": 2, - "temperature": 2, - } - - UNSPESIFIED = 0.1 - if isinstance(radius, float): radius = {str(d): radius for d in points} else: @@ -443,15 +440,15 @@ def select_around_data( ) if collected_terms: radius = { - str(d): kwargs.get(f"{d}_r", default_radii.get(str(d), UNSPESIFIED)) + str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points } elif radius is None: - radius = {str(d): default_radii.get(str(d), UNSPESIFIED) for d in points} + radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} assert isinstance(radius, dict) radius = { - str(d): radius.get(str(d), default_radii.get(str(d), UNSPESIFIED)) for d in points + str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points } along_dims = next(iter(points.values())).dims @@ -473,18 +470,15 @@ def select_around_data( radius = {d: v for d, v in radius.items() if d not in nearest_sel_params} # -- to heari, but as name said, should be alwayws safe. 
- if fast: - selection_slices = { - d: slice( - points[d].sel(**coord) - radius[d], - points[d].sel(**coord) + radius[d], - ) - for d in points - if d in radius - } - selected = value.sel(**selection_slices) - else: - raise NotImplementedError + selection_slices = { + d: slice( + points[d].sel(**coord) - radius[d], + points[d].sel(**coord) + radius[d], + ) + for d in points + if d in radius + } + selected = value.sel(**selection_slices) if nearest_sel_params: selected = selected.sel(**nearest_sel_params, method="nearest") @@ -502,10 +496,9 @@ def select_around_data( def select_around( self, - point: dict[str, Any] | xr.Dataset, + points: dict[str, float] | xr.Dataset, radius: dict[str, float] | float | None = None, *, - fast: bool = False, mode: Literal["sum", "mean"] = "sum", **kwargs: Incomplete, ) -> xr.DataArray: @@ -522,11 +515,9 @@ def select_around( then we will try to use reasonable default values; buyer beware. Args: - point: The points where the selection should be performed. + points: The points where the selection should be performed. radius: The radius of the selection in each coordinate. If dimensions are omitted, a standard sized selection will be made as a compromise. - fast: If true, uses a rectangular rather than a circular region for selectioIf true, - uses a rectangular rather than a circular region for selection. safe: If true, infills radii with default values. Defaults to `True`. mode: How the reduction should be performed, one of "sum" or "mean". Defaults to "sum" **kwargs: Can be used to pass radii parameters by keyword with `_r` postfix. @@ -538,62 +529,50 @@ def select_around( self._obj, xr.DataArray, ), "Cannot use select_around on Datasets only DataArrays!" - assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean." 
- if isinstance(point, tuple | list): - warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2) - point = dict(zip(point, self._obj.dims, strict=True)) - if isinstance(point, xr.Dataset): - point = {k: point[k].item() for k in point.data_vars} - - default_radii = { - "kp": 0.02, - "kz": 0.05, - "phi": 0.02, - "beta": 0.02, - "theta": 0.02, - "eV": 0.05, - "delay": 0.2, - "T": 2, - } - unspecified = 0.1 + assert mode in {"sum", "mean"}, "mode parameter should be either sum or mean." + if isinstance(points, tuple | list): + warnings.warn("Dangerous iterable point argument to `select_around`", stacklevel=2) + points = dict(zip(self._obj.dims, points, strict=True)) + if isinstance(points, xr.Dataset): + points = {k: points[k].item() for k in points.data_vars} + logger.debug(f"points: {points}") if isinstance(radius, float): - radius = {str(d): radius for d in point} + radius = {str(d): radius for d in points} else: - collected_terms = {f"{k}_r" for k in point}.intersection( + collected_terms = {f"{k}_r" for k in points}.intersection( set(kwargs.keys()), ) if collected_terms: radius = { - str(d): kwargs.get(f"{d}_r", default_radii.get(str(d), unspecified)) - for d in point + str(d): kwargs.get(f"{d}_r", DEFAULT_RADII.get(str(d), UNSPESIFIED)) + for d in points } elif radius is None: - radius = {str(d): default_radii.get(str(d), unspecified) for d in point} + radius = {str(d): DEFAULT_RADII.get(str(d), UNSPESIFIED) for d in points} assert isinstance(radius, dict) - radius = {str(d): radius.get(str(d), default_radii.get(str(d), unspecified)) for d in point} + radius = { + str(d): radius.get(str(d), DEFAULT_RADII.get(str(d), UNSPESIFIED)) for d in points + } - # make sure we are taking at least one pixel along each + logger.debug(f"radius: {radius}") nearest_sel_params = {} # -- originally, if safe == True, the following liens starting from hear stride = self._obj.G.stride(generic_dim_names=False) for d, v in radius.items(): if v < stride[d]: - 
nearest_sel_params[d] = point[d] + nearest_sel_params[d] = points[d] radius = {d: v for d, v in radius.items() if d not in nearest_sel_params} # -- to heari, but as name said, should be alwayws safe. - if fast: - selection_slices = { - d: slice(point[d] - radius[d], point[d] + radius[d]) for d in point if d in radius - } - selected = self._obj.sel(**selection_slices) - else: - raise NotImplementedError + selection_slices = { + d: slice(points[d] - radius[d], points[d] + radius[d]) for d in points if d in radius + } + selected = self._obj.sel(**selection_slices) if nearest_sel_params: selected = selected.sel(**nearest_sel_params, method="nearest") @@ -604,9 +583,7 @@ def select_around( if mode == "sum": return selected.sum(list(radius.keys())) - if mode == "mean": - return selected.mean(list(radius.keys())) - raise RuntimeError + return selected.mean(list(radius.keys())) def short_history(self, key: str = "by") -> list: """Return the short version of history. @@ -2853,7 +2830,19 @@ def iter_coords( for ts in itertools.product(*[self._obj.coords[d].values for d in dim_names]): yield dict(zip(dim_names, ts, strict=True)) - def range(self, *, generic_dim_names: bool = True) -> dict: + def range( + self, + *, + generic_dim_names: bool = True, + ) -> dict[str, tuple[float, float]]: + """Return the (minimum, maximum) coordinate values along each dimension. + + Args: + generic_dim_names (bool): if True, generic dimension names such as 'x' are used. + + Returns: (dict[str, tuple[float, float]]) + The (min, max) range of each dimension. 
+ """ assert isinstance(self._obj, xr.DataArray | xr.Dataset) indexed_coords = [self._obj.coords[d] for d in self._obj.dims] indexed_ranges = [(np.min(coord.values), np.max(coord.values)) for coord in indexed_coords] @@ -2864,7 +2853,24 @@ def range(self, *, generic_dim_names: bool = True) -> dict: return dict(zip(dim_names, indexed_ranges, strict=True)) - def stride(self, *args: Incomplete, generic_dim_names: bool = True) -> dict | list: + def stride( + self, + *args: str | list[str] | tuple[str, ...], + generic_dim_names: bool = True, + ) -> dict[str, float] | list[float] | float: + """Return the stride in each dimension. + + Note that the stride defined here is just the difference between the first two values. + In most cases, this treatment does not cause a problem. However, when the data has been + concatenated, this assumption may not be valid. + + Args: + args: The dimension(s) to return, e.g. ["eV", "phi"] or "eV", "phi". + generic_dim_names (bool): if True, generic dimension names such as 'x' are used. 
+ + Returns: + The stride of each dimension + """ assert isinstance(self._obj, xr.DataArray | xr.Dataset) indexed_coords: list[xr.DataArray] = [self._obj.coords[d] for d in self._obj.dims] indexed_strides: list[float] = [ @@ -2876,14 +2882,11 @@ def stride(self, *args: Incomplete, generic_dim_names: bool = True) -> dict | li dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)] result = dict(zip(dim_names, indexed_strides, strict=True)) - if args: if len(args) == 1: - if not isinstance(args[0], str): - # if passed list of strs as argument + if not isinstance(args[0], str): # suppose args is list / tuple result = [result[selected_names] for selected_names in args[0]] else: - # if passed single name as argument result = result[args[0]] else: # if passed several names as arguments diff --git a/docs/source/notebooks/custom-dot-s-functionality.ipynb b/docs/source/notebooks/custom-dot-s-functionality.ipynb index c9b8b2d2..b7ad26f1 100644 --- a/docs/source/notebooks/custom-dot-s-functionality.ipynb +++ b/docs/source/notebooks/custom-dot-s-functionality.ipynb @@ -787,7 +787,7 @@ "# take the Lorentzian (component `b`) center parameter\n", "phi_values = phis.F.p(\"b_center\")\n", "temp_dep.spectrum.S.select_around_data(\n", - " {\"phi\": phi_values}, mode=\"mean\", fast=True, radius={\"phi\": 0.005}\n", + " {\"phi\": phi_values}, mode=\"mean\", radius={\"phi\": 0.005}\n", ").S.plot()" ] },