From 5ff8d78ee91ed71dce703f44f5424b7ef530452f Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sun, 28 Apr 2024 10:24:18 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=20Sort=20the=20methods=20in=20G?= =?UTF-8?q?=20class.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the future they are split into three --- src/arpes/analysis/band_analysis.py | 11 +- src/arpes/models/band.py | 32 ++- src/arpes/xarray_extensions.py | 339 ++++++++++++++-------------- 3 files changed, 196 insertions(+), 186 deletions(-) diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index d5aa8f66..384eaf1c 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -168,15 +168,15 @@ def unpack_bands_from_fit( label = identified_band_results.loc[first_coordinate].values.item()[i] def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.DataArray: - """[TODO:summary]. + """Return DataArray representing the fit results. Args: param_name (str): [TODO:description] i (int): [TODO:description] is_value (bool): [TODO:description] """ - values: NDArray[np.float_] = np.ndarray( - shape=identified_band_results.values.shape, + values: NDArray[np.float_] = np.zeros_like( + identified_band_results.values, dtype=float, ) it = np.nditer(values, flags=["multi_index"], op_flags=[["writeonly"]]) @@ -185,7 +185,6 @@ def dataarray_for_value(param_name: str, i: int = i, *, is_value: bool) -> xr.Da param = band_results.values[it.multi_index].params[prefix + param_name] it[0] = param.value if is_value else param.stderr it.iternext() - return xr.DataArray( values, identified_band_results.coords, @@ -219,8 +218,8 @@ def _identified_band_results_etc( return of broadcast_model().results weights (tuple[float, float, float]): weight values for sigma, amplitude, center - Returns: - [TODO:description] + Returns: tuple[xr.DataArray, dict[Hashable, float], list[str]] + identified_band_results, first_coordinate, prefixes """ band_results = band_results if isinstance(band_results, xr.DataArray) else band_results.results prefixes = [component.prefix for component in band_results.values[0].model.components] diff --git a/src/arpes/models/band.py b/src/arpes/models/band.py index db6e25f6..b1828778 100644 --- a/src/arpes/models/band.py +++ b/src/arpes/models/band.py @@ -42,7 +42,7 @@ class Band: Attribute: label (str): label of the band. - _data (xr.Dataset): consists of several DataArrays representing the fitting results. + _data (xr.Dataset): Dataset consists of several DataArrays representing the fitting results. `data_vars` are "center", "center_stderr", "amplitude", "amplitude_stdrr", "sigma", and "sigma_stderr" """ @@ -65,16 +65,7 @@ def velocity(self) -> xr.DataArray: """ spacing = float(self.coords[self.dims[0]][1] - self.coords[self.dims[0]][0]) - def embed_nan(values: NDArray[np.float_], padding: int) -> NDArray[np.float_]: - embedded: NDArray[np.float_] = np.full( - shape=(values.shape[0] + 2 * padding,), - fill_value=np.nan, - dtype=np.float_, - ) - embedded[padding:-padding] = values - return embedded - - raw_values = embed_nan(self.center.values, 50) + raw_values = self.embed_nan(self.center.values, 50) masked = np.copy(raw_values) masked[raw_values != raw_values] = 0 @@ -178,6 +169,25 @@ def dims(self) -> tuple[str, ...]: assert isinstance(self._data, xr.Dataset) return self._data.center.dims + @staticmethod + def embed_nan(values: NDArray[np.float_], padding: int) -> NDArray[np.float_]: + """Return np.ndarray padding before and after the original NDArray with nan. + + Args: + values: [TODO:description] + padding: the length of the padding + + Returns: NDArray[np.float_] + [TODO:description] + """ + embedded: NDArray[np.float_] = np.full( + shape=(values.shape[0] + 2 * padding,), + fill_value=np.nan, + dtype=np.float_, + ) + embedded[padding:-padding] = values + return embedded + class MultifitBand(Band): """Convenience class that reimplements reading data out of a composite fit result.""" diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index 4d944958..709b148e 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -2201,16 +2201,6 @@ def round_coordinates( return rounded - def argmax_coords(self) -> dict[Hashable, float]: # TODO [RA]: DaraArray - """Return dict representing the position for maximum value.""" - assert isinstance(self._obj, xr.DataArray) - data: xr.DataArray = self._obj - raveled = data.argmax(None) - assert isinstance(raveled, xr.DataArray) - idx = raveled.item() - flat_indices = np.unravel_index(idx, data.values.shape) - return {d: data.coords[d][flat_indices[i]].item() for i, d in enumerate(data.dims)} - def apply_over( self, fn: Callable, @@ -2319,15 +2309,6 @@ def transform_coords( return copied - def filter_vars( - self, f: Callable[[Hashable, xr.DataArray], bool] - ) -> xr.Dataset: # TODO [RA]: Dataset only - assert isinstance(self._obj, xr.Dataset) # ._obj.data_vars - return xr.Dataset( - data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)}, - attrs=self._obj.attrs, - ) - def coordinatize(self, as_coordinate_name: str | None = None) -> XrTypes: """Copies data into a coordinate's data, with an optional renaming. @@ -2357,6 +2338,176 @@ def coordinatize(self, as_coordinate_name: str | None = None) -> XrTypes: return o + def enumerate_iter_coords( + self, + ) -> Generator[tuple[tuple[int, ...], dict[Hashable, float]], None, None]: + """[TODO:summary]. + + Returns: + Generator of the following data + ((0, 0), {'phi': -0.2178031280148764, 'eV': 9.0}) + which shows the relationship between pixel position and physical (like "eV" and "phi"). + """ + assert isinstance(self._obj, xr.DataArray | xr.Dataset) + coords_list = [self._obj.coords[d].values for d in self._obj.dims] + for indices in itertools.product(*[range(len(c)) for c in coords_list]): + cut_coords = [cs[index] for cs, index in zip(coords_list, indices, strict=True)] + yield indices, dict(zip(self._obj.dims, cut_coords, strict=True)) + + def iter_coords( + self, + dim_names: tuple[str | Hashable, ...] = (), + ) -> Iterator[dict[Hashable, float]]: + """[TODO:summary]. + + Args: + dim_names: [TODO:description] + + Returns: + Generator of the physical position like ("eV" and "phi") + {'phi': -0.2178031280148764, 'eV': 9.002} + """ + if not dim_names: + dim_names = tuple(self._obj.dims) + for ts in itertools.product(*[self._obj.coords[d].values for d in dim_names]): + yield dict(zip(dim_names, ts, strict=True)) + + def range( + self, + *, + generic_dim_names: bool = True, + ) -> dict[Hashable, tuple[float, float]]: + """Return the maximum/minimum value in each dimension. + + Args: + generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used. + + Returns: (dict[str, tuple[float, float]]) + The range of each dimension. + """ + indexed_coords = [self._obj.coords[d] for d in self._obj.dims] + indexed_ranges = [(coord.min().item(), coord.max().item()) for coord in indexed_coords] + + dim_names: list[str] | tuple[Hashable, ...] = tuple(self._obj.dims) + if generic_dim_names: + dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)] + + return dict(zip(dim_names, indexed_ranges, strict=True)) + + def stride( + self, + *args: str | list[str] | tuple[str, ...], + generic_dim_names: bool = True, + ) -> dict[Hashable, float] | list[float] | float: + """Return the stride in each dimension. + + Note that the stride defined in this method is just a difference between first two values. + In most case, this treatment does not cause a problem. However, when the data has been + concatenated, this assumption may not be not valid. + + Args: + args: The dimension to return. ["eV", "phi"] or "eV", "phi" + generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used. + + Returns: + The stride of each dimension + """ + indexed_coords: list[xr.DataArray] = [self._obj.coords[d] for d in self._obj.dims] + indexed_strides: list[float] = [ + coord.values[1] - coord.values[0] for coord in indexed_coords + ] + + dim_names: list[str] | tuple[Hashable, ...] = tuple(self._obj.dims) + if generic_dim_names: + dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)] + + result: dict[Hashable, float] = dict(zip(dim_names, indexed_strides, strict=True)) + if args: + if isinstance(args[0], str): + return ( + result[args[0]] + if len(args) == 1 + else [result[str(selected_names)] for selected_names in args] + ) + return [result[selected_names] for selected_names in args[0]] + return result + + def filter_coord( + self, + coordinate_name: str, + sieve: Callable[[Any, XrTypes], bool], + ) -> XrTypes: + """Filters a dataset along a coordinate. + + Sieve should be a function which accepts a coordinate value and the slice + of the data along that dimension. + + Internally, the predicate function `sieve` is applied to the coordinate and slice to + generate a mask. The mask is used to select from the data after iteration. + + An improvement here would support filtering over several coordinates. + + Args: + coordinate_name: The coordinate which should be filtered. + sieve: A predicate to be applied to the coordinate and data at that coordinate. + + Returns: + A subset of the data composed of the slices which make the `sieve` predicate `True`. + """ + mask = np.array( + [ + i + for i, c in enumerate(self._obj.coords[coordinate_name]) + if sieve(c, self._obj.isel({coordinate_name: i})) + ], + ) + return self._obj.isel({coordinate_name: mask}) + + def iterate_axis( + self, + axis_name_or_axes: list[str] | str, + ) -> Generator[tuple[dict[str, float], XrTypes], str, None]: + """Generator to extract data for along the specified axis. + + Args: + axis_name_or_axes (list[str] | str): axis (dime) name for iteration. + + Returns: (tuple[dict[str, float], XrTypes]) + dict object represents the axis(dim) name and it's value. + XrTypes object the corresponding data, the value at the corresponding position. + """ + assert isinstance(self._obj, xr.DataArray | xr.Dataset) + if isinstance(axis_name_or_axes, str): + axis_name_or_axes = [axis_name_or_axes] + + coord_iterators: list[NDArray[np.float_]] = [ + self._obj.coords[d].values for d in axis_name_or_axes + ] + for indices in itertools.product(*[range(len(c)) for c in coord_iterators]): + cut_coords = [cs[index] for cs, index in zip(coord_iterators, indices, strict=True)] + coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True)) + yield coords_dict, self._obj.sel(coords_dict, method="nearest") + + def filter_vars( + self, + f: Callable[[Hashable, xr.DataArray], bool], + ) -> xr.Dataset: # TODO: [RA] Dataset only + assert isinstance(self._obj, xr.Dataset) # ._obj.data_vars + return xr.Dataset( + data_vars={k: v for k, v in self._obj.data_vars.items() if f(k, v)}, + attrs=self._obj.attrs, + ) + + def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray + """Return dict representing the position for maximum value.""" + assert isinstance(self._obj, xr.DataArray) + data: xr.DataArray = self._obj + raveled = data.argmax(None) + assert isinstance(raveled, xr.DataArray) + idx = raveled.item() + flat_indices = np.unravel_index(idx, data.values.shape) + return {d: data.coords[d][flat_indices[i]].item() for i, d in enumerate(data.dims)} + def ravel(self) -> Mapping[Hashable, xr.DataArray | NDArray[np.float_]]: # TODO: [RA] DataArray """Converts to a flat representation where the coordinate values are also present. @@ -2452,62 +2603,6 @@ def as_movie( assert isinstance(out, str) return plot_movie(self._obj, time_dim, out=out, **kwargs) - def filter_coord( - self, - coordinate_name: str, - sieve: Callable[[Any, XrTypes], bool], - ) -> XrTypes: - """Filters a dataset along a coordinate. - - Sieve should be a function which accepts a coordinate value and the slice - of the data along that dimension. - - Internally, the predicate function `sieve` is applied to the coordinate and slice to - generate a mask. The mask is used to select from the data after iteration. - - An improvement here would support filtering over several coordinates. - - Args: - coordinate_name: The coordinate which should be filtered. - sieve: A predicate to be applied to the coordinate and data at that coordinate. - - Returns: - A subset of the data composed of the slices which make the `sieve` predicate `True`. - """ - mask = np.array( - [ - i - for i, c in enumerate(self._obj.coords[coordinate_name]) - if sieve(c, self._obj.isel({coordinate_name: i})) - ], - ) - return self._obj.isel({coordinate_name: mask}) - - def iterate_axis( - self, - axis_name_or_axes: list[str] | str, - ) -> Generator[tuple[dict[str, float], XrTypes], str, None]: - """Generator to extract data for along the specified axis. - - Args: - axis_name_or_axes (list[str] | str): axis (dime) name for iteration. - - Returns: (tuple[dict[str, float], XrTypes]) - dict object represents the axis(dim) name and it's value. - XrTypes object the corresponding data, the value at the corresponding position. - """ - assert isinstance(self._obj, xr.DataArray | xr.Dataset) - if isinstance(axis_name_or_axes, str): - axis_name_or_axes = [axis_name_or_axes] - - coord_iterators: list[NDArray[np.float_]] = [ - self._obj.coords[d].values for d in axis_name_or_axes - ] - for indices in itertools.product(*[range(len(c)) for c in coord_iterators]): - cut_coords = [cs[index] for cs, index in zip(coord_iterators, indices, strict=True)] - coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True)) - yield coords_dict, self._obj.sel(coords_dict, method="nearest") - def map_axes( self, axes: list[str] | str, @@ -2646,100 +2741,6 @@ def map( assert isinstance(self._obj, xr.DataArray) return apply_dataarray(self._obj, np.vectorize(fn, **kwargs)) - def enumerate_iter_coords( - self, - ) -> Generator[tuple[tuple[int, ...], dict[Hashable, float]], None, None]: - """[TODO:summary]. - - Returns: - Generator of the following data - ((0, 0), {'phi': -0.2178031280148764, 'eV': 9.0}) - which shows the relationship between pixel position and physical (like "eV" and "phi"). - """ - assert isinstance(self._obj, xr.DataArray | xr.Dataset) - coords_list = [self._obj.coords[d].values for d in self._obj.dims] - for indices in itertools.product(*[range(len(c)) for c in coords_list]): - cut_coords = [cs[index] for cs, index in zip(coords_list, indices, strict=True)] - yield indices, dict(zip(self._obj.dims, cut_coords, strict=True)) - - def iter_coords( - self, - dim_names: tuple[str | Hashable, ...] = (), - ) -> Iterator[dict[Hashable, float]]: - """[TODO:summary]. - - Args: - dim_names: [TODO:description] - - Returns: - Generator of the physical position like ("eV" and "phi") - {'phi': -0.2178031280148764, 'eV': 9.002} - """ - if not dim_names: - dim_names = tuple(self._obj.dims) - for ts in itertools.product(*[self._obj.coords[d].values for d in dim_names]): - yield dict(zip(dim_names, ts, strict=True)) - - def range( - self, - *, - generic_dim_names: bool = True, - ) -> dict[Hashable, tuple[float, float]]: - """Return the maximum/minimum value in each dimension. - - Args: - generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used. - - Returns: (dict[str, tuple[float, float]]) - The range of each dimension. - """ - indexed_coords = [self._obj.coords[d] for d in self._obj.dims] - indexed_ranges = [(coord.min().item(), coord.max().item()) for coord in indexed_coords] - - dim_names: list[str] | tuple[Hashable, ...] = tuple(self._obj.dims) - if generic_dim_names: - dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)] - - return dict(zip(dim_names, indexed_ranges, strict=True)) - - def stride( - self, - *args: str | list[str] | tuple[str, ...], - generic_dim_names: bool = True, - ) -> dict[Hashable, float] | list[float] | float: - """Return the stride in each dimension. - - Note that the stride defined in this method is just a difference between first two values. - In most case, this treatment does not cause a problem. However, when the data has been - concatenated, this assumption may not be not valid. - - Args: - args: The dimension to return. ["eV", "phi"] or "eV", "phi" - generic_dim_names (bool): if True, use Generic dimension name, such as 'x', is used. - - Returns: - The stride of each dimension - """ - indexed_coords: list[xr.DataArray] = [self._obj.coords[d] for d in self._obj.dims] - indexed_strides: list[float] = [ - coord.values[1] - coord.values[0] for coord in indexed_coords - ] - - dim_names: list[str] | tuple[Hashable, ...] = tuple(self._obj.dims) - if generic_dim_names: - dim_names = NORMALIZED_DIM_NAMES[: len(dim_names)] - - result: dict[Hashable, float] = dict(zip(dim_names, indexed_strides, strict=True)) - if args: - if isinstance(args[0], str): - return ( - result[args[0]] - if len(args) == 1 - else [result[str(selected_names)] for selected_names in args] - ) - return [result[selected_names] for selected_names in args[0]] - return result - def shift_by( # noqa: PLR0913 self, other: xr.DataArray | NDArray[np.float_],