Skip to content

Commit

Permalink
🎨 residual should not be copied from data.
Browse files Browse the repository at this point in the history
params (especially provenance) should not be transfered.
  • Loading branch information
arafune committed Apr 29, 2024
1 parent 5ff8d78 commit b9e5345
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
9 changes: 6 additions & 3 deletions src/arpes/analysis/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ def deconvolve_rl(
The Richardson-Lucy deconvolved data.
"""
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data_image = arr.values
psf_ = psf.values
im_deconv = richardson_lucy(data_image, psf_, num_iter=n_iterations, filter_epsilon=None)
im_deconv = richardson_lucy(
arr.values,
psf.values,
num_iter=n_iterations,
filter_epsilon=None,
)
return arr.S.with_values(im_deconv)


Expand Down
24 changes: 11 additions & 13 deletions src/arpes/fits/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,19 @@ def broadcast_model( # noqa: PLR0913
broadcast_dims = [broadcast_dims]

logger.debug("Normalizing to spectrum")
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
cs = {}
for dim in broadcast_dims:
cs[dim] = data_array.coords[dim]
cs[dim] = data.coords[dim]

other_axes = set(data_array.dims).difference(set(broadcast_dims))
template = data_array.sum(list(other_axes))
other_axes = set(data.dims).difference(set(broadcast_dims))
template = data.sum(list(other_axes))
template.values = np.ndarray(template.shape, dtype=object)
n_fits = np.prod(np.array(list(template.sizes.values())))
if parallelize is None:
parallelize = bool(n_fits > 20) # noqa: PLR2004

residual = data_array.copy(deep=True)
logger.debug("Copying residual")
residual.values = np.zeros(residual.shape)
residual = xr.DataArray(np.zeros_like(data.values), coords=data.coords, dims=data.dims)

logger.debug("Parsing model")
model = parse_model(model_cls)
Expand All @@ -190,7 +188,7 @@ def broadcast_model( # noqa: PLR0913
serialize = parallelize
assert isinstance(serialize, bool)
fitter = mp_fits.MPWorker(
data=data_array,
data=data,
uncompiled_model=model,
prefixes=prefixes,
params=params,
Expand All @@ -209,7 +207,7 @@ def broadcast_model( # noqa: PLR0913
exe_results = list(
wrap_progress(
pool.imap(fitter, template.G.iter_coords()), # IMapIterator
total=n_fits,
total=int(n_fits),
desc="Fitting on pool...",
),
)
Expand All @@ -219,7 +217,7 @@ def broadcast_model( # noqa: PLR0913
for _, cut_coords in wrap_progress(
template.G.enumerate_iter_coords(),
desc="Fitting",
total=n_fits,
total=int(n_fits),
):
exe_results.append(fitter(cut_coords))

Expand All @@ -241,9 +239,9 @@ def unwrap(result_data: str) -> object: # (Unpickler)
return xr.Dataset(
{
"results": template,
"data": data_array,
"data": data,
"residual": residual,
"norm_residual": residual / data_array,
"norm_residual": residual / data,
},
residual.coords,
)
Expand All @@ -254,7 +252,7 @@ def _fake_wqdm(x: Iterable[int], **kwargs: str | float) -> Iterable[int]:
Args:
x (Iterable[int]): [TODO:description]
kwargs: its a dummy parameter, which is not used.
kwargs: its dummy parameters, not used.
Returns:
Same iterable.
Expand Down
28 changes: 28 additions & 0 deletions src/arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
Unpack,
)

from lmfit.model import propagate_err
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
Expand Down Expand Up @@ -2488,6 +2489,7 @@ def iterate_axis(
coords_dict = dict(zip(axis_name_or_axes, cut_coords, strict=True))
yield coords_dict, self._obj.sel(coords_dict, method="nearest")

# ---------
def filter_vars(
self,
f: Callable[[Hashable, xr.DataArray], bool],
Expand All @@ -2498,6 +2500,7 @@ def filter_vars(
attrs=self._obj.attrs,
)

# ----------
def argmax_coords(self) -> dict[Hashable, float]: # TODO: [RA] DataArray
"""Return dict representing the position for maximum value."""
assert isinstance(self._obj, xr.DataArray)
Expand Down Expand Up @@ -2908,9 +2911,24 @@ def __init__(self, xarray_obj: xr.Dataset) -> None:
self._obj = xarray_obj

def eval(self, *args: Incomplete, **kwargs: Incomplete) -> xr.DataArray:
"""[TODO:summary].
Args:
args: [TODO:description]
kwargs: [TODO:description]
Returns:
[TODO:description]
TODO: Need Reivision (It does not work.)
"""
return self._obj.results.G.map(lambda x: x.eval(*args, **kwargs))

def show(self) -> None:
"""[TODO:summary].
TODO: Need Revision (It does not work)
"""
from .plotting.fit_tool import fit_tool

fit_tool(self._obj)
Expand Down Expand Up @@ -2968,6 +2986,16 @@ def mean_square_error(self) -> xr.DataArray:
assert isinstance(self._obj, xr.Dataset)
return self._obj.results.F.mean_square_error()

@property
def parameter_names(self) -> set[str]:
"""Alias for `ARPESFitToolsAccessor.parameter_names`.
Returns:
A set of all the parameter names used in a curve fit.
"""
assert isinstance(self._obj, xr.Dataset)
return self._obj.results.F.parameter_names

def p(self, param_name: str) -> xr.DataArray:
"""Alias for `ARPESFitToolsAccessor.p`.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_curve_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_broadcast_fitting() -> None:

fit_results = broadcast_model([AffineBroadenedFD], near_ef_rebin, "phi", progress=False)

assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00508) < TOLERANCE
assert np.abs(fit_results.F.p("a_fd_center").mean().item() + 0.00508) < TOLERANCE

fit_results = broadcast_model(
[AffineBroadenedFD],
Expand Down

0 comments on commit b9e5345

Please sign in to comment.