From b9e534566264ac9f31800965f2507727b6a0e8a6 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Mon, 29 Apr 2024 10:26:49 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=20residual=20should=20not=20be?= =?UTF-8?q?=20copied=20from=20data.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit params (especially provenance) should not be transfered. --- src/arpes/analysis/deconvolution.py | 9 ++++++--- src/arpes/fits/utilities.py | 24 +++++++++++------------- src/arpes/xarray_extensions.py | 28 ++++++++++++++++++++++++++++ tests/test_curve_fitting.py | 2 +- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/arpes/analysis/deconvolution.py b/src/arpes/analysis/deconvolution.py index 42080798..443ba770 100644 --- a/src/arpes/analysis/deconvolution.py +++ b/src/arpes/analysis/deconvolution.py @@ -104,9 +104,12 @@ def deconvolve_rl( The Richardson-Lucy deconvolved data. """ arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) - data_image = arr.values - psf_ = psf.values - im_deconv = richardson_lucy(data_image, psf_, num_iter=n_iterations, filter_epsilon=None) + im_deconv = richardson_lucy( + arr.values, + psf.values, + num_iter=n_iterations, + filter_epsilon=None, + ) return arr.S.with_values(im_deconv) diff --git a/src/arpes/fits/utilities.py b/src/arpes/fits/utilities.py index b15e5a58..6d935e98 100644 --- a/src/arpes/fits/utilities.py +++ b/src/arpes/fits/utilities.py @@ -164,21 +164,19 @@ def broadcast_model( # noqa: PLR0913 broadcast_dims = [broadcast_dims] logger.debug("Normalizing to spectrum") - data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) cs = {} for dim in broadcast_dims: - cs[dim] = data_array.coords[dim] + cs[dim] = data.coords[dim] - other_axes = set(data_array.dims).difference(set(broadcast_dims)) - template = data_array.sum(list(other_axes)) + other_axes = set(data.dims).difference(set(broadcast_dims)) + template = data.sum(list(other_axes)) template.values = np.ndarray(template.shape, dtype=object) n_fits = np.prod(np.array(list(template.sizes.values()))) if parallelize is None: parallelize = bool(n_fits > 20) # noqa: PLR2004 - residual = data_array.copy(deep=True) - logger.debug("Copying residual") - residual.values = np.zeros(residual.shape) + residual = xr.DataArray(np.zeros_like(data.values), coords=data.coords, dims=data.dims) logger.debug("Parsing model") model = parse_model(model_cls) @@ -190,7 +188,7 @@ def broadcast_model( # noqa: PLR0913 serialize = parallelize assert isinstance(serialize, bool) fitter = mp_fits.MPWorker( - data=data_array, + data=data, uncompiled_model=model, prefixes=prefixes, params=params, @@ -209,7 +207,7 @@ def broadcast_model( # noqa: PLR0913 exe_results = list( wrap_progress( pool.imap(fitter, template.G.iter_coords()), # IMapIterator - total=n_fits, + total=int(n_fits), desc="Fitting on pool...", ), ) @@ -219,7 +217,7 @@ def broadcast_model( # noqa: PLR0913 for _, cut_coords in wrap_progress( template.G.enumerate_iter_coords(), desc="Fitting", - total=n_fits, + total=int(n_fits), ): exe_results.append(fitter(cut_coords)) @@ -241,9 +239,9 @@ def unwrap(result_data: str) -> object: # (Unpickler) return xr.Dataset( { "results": template, - "data": data_array, + "data": data, "residual": residual, - "norm_residual": residual / data_array, + "norm_residual": residual / data, }, residual.coords, ) @@ -254,7 +252,7 @@ def _fake_wqdm(x: Iterable[int], **kwargs: str | float) -> Iterable[int]: Args: x (Iterable[int]): [TODO:description] - kwargs: its a dummy parameter, which is not used. + kwargs: its dummy parameters, not used. Returns: Same iterable. diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index 709b148e..45ba9fdd 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -58,6 +58,7 @@ Unpack, ) +from lmfit.model import propagate_err import matplotlib.pyplot as plt import numpy as np import xarray as xr @@ -2488,6 +2489,7 @@ def iterate_axis( coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True)) yield coords_dict, self._obj.sel(coords_dict, method="nearest") + # --------- def filter_vars( self, f: Callable[[Hashable, xr.DataArray], bool], @@ -2498,6 +2500,7 @@ def filter_vars( attrs=self._obj.attrs, ) + # ---------- def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray """Return dict representing the position for maximum value.""" assert isinstance(self._obj, xr.DataArray) @@ -2908,9 +2911,24 @@ def __init__(self, xarray_obj: xr.Dataset) -> None: self._obj = xarray_obj def eval(self, *args: Incomplete, **kwargs: Incomplete) -> xr.DataArray: + """[TODO:summary]. + + Args: + args: [TODO:description] + kwargs: [TODO:description] + + Returns: + [TODO:description] + + TODO: Need Reivision (It does not work.) + """ return self._obj.results.G.map(lambda x: x.eval(*args, **kwargs)) def show(self) -> None: + """[TODO:summary]. + + TODO: Need Revision (It does not work) + """ from .plotting.fit_tool import fit_tool fit_tool(self._obj) @@ -2968,6 +2986,16 @@ def mean_square_error(self) -> xr.DataArray: assert isinstance(self._obj, xr.Dataset) return self._obj.results.F.mean_square_error() + @property + def parameter_names(self) -> set[str]: + """Alias for `ARPESFitToolsAccessor.parameter_names`. + + Returns: + A set of all the parameter names used in a curve fit. + """ + assert isinstance(self._obj, xr.Dataset) + return self._obj.results.F.parameter_names + def p(self, param_name: str) -> xr.DataArray: """Alias for `ARPESFitToolsAccessor.p`. diff --git a/tests/test_curve_fitting.py b/tests/test_curve_fitting.py index 9ba1a8e8..692323d7 100644 --- a/tests/test_curve_fitting.py +++ b/tests/test_curve_fitting.py @@ -16,7 +16,7 @@ def test_broadcast_fitting() -> None: fit_results = broadcast_model([AffineBroadenedFD], near_ef_rebin, "phi", progress=False) - assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00508) < TOLERANCE + assert np.abs(fit_results.F.p("a_fd_center").mean().item() + 0.00508) < TOLERANCE fit_results = broadcast_model( [AffineBroadenedFD],