Skip to content

Commit

Permalink
💬 Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 13, 2024
1 parent c1998ff commit 3175015
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 256 deletions.
11 changes: 4 additions & 7 deletions src/arpes/analysis/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
if we need to. I doubt that this is necessary and don't mind the copied code too much at the
present.
"""

from __future__ import annotations

import numpy as np
Expand Down Expand Up @@ -60,12 +61,9 @@ def align2d(a: xr.DataArray, b: xr.DataArray, *, subpixel: bool = True) -> tuple

y, x = true_y, true_x

y = 1.0 * y - a.values.shape[0] / 2.0
x = 1.0 * x - a.values.shape[1] / 2.0

return (
y * a.G.stride(generic_dim_names=False)[a.dims[0]],
x * a.G.stride(generic_dim_names=False)[a.dims[1]],
(float(y) - a.values.shape[0] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[0]],
(float(x) - a.values.shape[1] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[1]],
)


Expand Down Expand Up @@ -93,8 +91,7 @@ def align1d(a: xr.DataArray, b: xr.DataArray, *, subpixel: bool = True) -> float
mod = QuadraticModel().guess_fit(marg)
x = x + -mod.params["b"].value / (2 * mod.params["a"].value)

x = 1.0 * x - a.values.shape[0] / 2.0
return x * a.G.stride(generic_dim_names=False)[a.dims[0]]
return (float(x) - a.values.shape[0] / 2.0) * a.G.stride(generic_dim_names=False)[a.dims[0]]


def align(a: xr.DataArray, b: xr.DataArray, **kwargs: bool) -> tuple[float, float] | float:
Expand Down
193 changes: 9 additions & 184 deletions src/arpes/analysis/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@

from __future__ import annotations

import contextlib
import warnings
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

import numpy as np
import scipy
import scipy.ndimage
import xarray as xr
from tqdm.notebook import tqdm
from skimage.restoration import richardson_lucy

from arpes.constants import TWO_DIMENSION
import arpes.xarray_extensions # noqa: F401
from arpes.fits.fit_models.functional_forms import gaussian
from arpes.provenance import update_provenance
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
from collections.abc import Iterable

from _typeshed import Incomplete
from numpy.typing import NDArray


Expand Down Expand Up @@ -83,170 +78,24 @@ def deconvolve_ice(
@update_provenance("Lucy Richardson Deconvolution")
def deconvolve_rl(
data: xr.DataArray,
psf: xr.DataArray | None = None,
psf: xr.DataArray,
n_iterations: int = 10,
axis: str = "",
sigma: float = 0,
mode: Literal["reflect", "constant", "nearest", "mirror", "wrap"] = "reflect",
*,
progress: bool = True,
) -> xr.DataArray:
"""Deconvolves data by a given point spread function using the Richardson-Lucy (RL) method.
Args:
data: input data
axis
mode: pass to ndimage.convolve
sigma
progress
psf: for 1d, if not specified, must specify axis and sigma
psf: The point spread function.
n_iterations: the number of convolutions to use for the fit
Returns:
The Richardson-Lucy deconvolved data.
"""
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)

if psf is None and axis != "" and sigma != 0:
# if no psf is provided and we have the information to make a 1d one
# note: this assumes gaussian psf
psf = make_psf1d(data=arr, dim=axis, sigma=sigma)

if len(data.dims) > 1:
if not axis:
# perform one-dimensional deconvolution of multidimensional data

# support for progress bars
def wrap_progress(
x: Iterable[int],
*args: Incomplete,
**kwargs: Incomplete,
) -> Iterable[int]:
if args:
for arg in args:
warnings.warn(
f"unused args is set in deconvolution.py/wrap_progress: {arg}",
stacklevel=2,
)
if kwargs:
for k, v in kwargs.items():
warnings.warn(
f"unused args is set in deconvolution.py/wrap_progress: {k}: {v}",
stacklevel=2,
)
return x

if progress:
wrap_progress = tqdm

# dimensions over which to iterate
other_dim = list(data.dims)
other_dim.remove(axis)

if len(other_dim) == 1:
# two-dimensional data
other_dim = other_dim[0]
result = arr.copy(deep=True).transpose(
other_dim,
axis,
)
# not sure why the dims only seems to work in this order.
# seems like I should be able to swap it to (axis,other_dim)
# and also change the data collection to result[x_ind,y_ind],
# but this gave different results

for i, (_, iteration) in wrap_progress(
enumerate(arr.G.iterate_axis(other_dim)),
desc="Iterating " + other_dim,
total=len(arr[other_dim]),
): # TODO: tidy this gross-looking loop
# indices of data being deconvolved
x_ind = xr.DataArray(list(range(len(arr[axis]))), dims=[axis])
y_ind = xr.DataArray([i] * len(x_ind), dims=[other_dim])
# perform deconvolution on this one-dimensional piece
deconv = deconvolve_rl(
data=iteration,
psf=psf,
n_iterations=n_iterations,
axis="",
mode=mode,
)
# build results out of these pieces
result[y_ind, x_ind] = deconv.values
elif len(other_dim) == TWO_DIMWENSION:
# three-dimensional data
result = arr.copy(deep=True).transpose(*other_dim, axis)
# not sure why the dims only seems to work in this order.
# eems like I should be able to swap it to (axis,*other_dim) and also change the
# data collection to result[x_ind,y_ind,z_ind], but this gave different results
for i, (_od0, iteration0) in wrap_progress(
enumerate(arr.G.iterate_axis(other_dim[0])),
desc="Iterating " + str(other_dim[0]),
total=len(arr[other_dim[0]]),
): # TODO: tidy this gross-looking loop
for j, (_od1, iteration1) in wrap_progress(
enumerate(iteration0.G.iterate_axis(other_dim[1])),
desc="Iterating " + str(other_dim[1]),
total=len(arr[other_dim[1]]),
leave=False,
): # TODO: tidy this gross-looking loop
# indices of data being deconvolved
x_ind = xr.DataArray(list(range(len(arr[axis]))), dims=[axis])
y_ind = xr.DataArray([i] * len(x_ind), dims=[other_dim[0]])
z_ind = xr.DataArray([j] * len(x_ind), dims=[other_dim[1]])
# perform deconvolution on this one-dimensional piece
deconv = deconvolve_rl(
data=iteration1,
psf=psf,
n_iterations=n_iterations,
axis="",
mode=mode,
)
# build results out of these pieces
result[y_ind, z_ind, x_ind] = deconv.values
elif len(other_dim) >= TWO_DIMENSION + 1:
# four- or higher-dimensional data
# TODO: find way to compactify the different dimensionalities rather than having
# separate code
msg = "high-dimensional data not yet supported"
raise NotImplementedError(msg)
elif not axis:
# crude attempt to perform multidimensional deconvolution.
# not clear if this is currently working
# TODO: may be able to do this as a sequence of one-dimensional deconvolutions, assuming
# that the psf is separable (which I think it should be, if we assume it is a
# multivariate gaussian with principle axes aligned with the dimensions)
msg = "multi-dimensional convolutions not yet supported"
raise NotImplementedError(msg)

if not isinstance(arr, np.ndarray):
arr = arr.values

u = [arr]

for i in range(n_iterations):
c = scipy.ndimage.convolve(u[-1], psf, mode=mode)
u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, None), mode=mode))
# careful about which axis (axes) to flip here...!
# need to explicitly specify for some versions of numpy

result = u[-1]
else: # data.dims == 1
if not isinstance(arr, np.ndarray):
arr = arr.values
u = [arr]
for _ in range(n_iterations):
c = scipy.ndimage.convolve(u[-1], psf, mode=mode)
u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, 0), mode=mode))
# not yet tested to ensure flip correct for asymmetric psf
# note: need to explicitly specify axis number in np.flip in lower versions of numpy
if isinstance(data, np.ndarray):
result = u[-1].copy()
else:
result = data.copy(deep=True)
result.values = u[-1]
with contextlib.suppress(Exception):
return result.transpose(*arr.dims)
data_image = arr.values
psf_ = psf.values
im_deconv = richardson_lucy(data_image, psf_, num_iter=n_iterations, filter_epsilon=None)
return arr.S.with_values(im_deconv)


@update_provenance("Make 1D-Point Spread Function")
Expand Down Expand Up @@ -283,27 +132,3 @@ def make_psf(data: xr.DataArray, sigmas: dict[str, float]) -> xr.DataArray:
Returns:
The PSF to use.
"""
raise NotImplementedError

arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
dims = arr.dims

psf = arr.copy(deep=True) * 0 + 1

for dim in dims:
other_dims = list(arr.dims)
other_dims.remove(dim)

psf1d = arr.copy(deep=True) * 0 + 1
for od in other_dims:
psf1d = psf1d[{od: 0}]

if sigmas[dim] == 0:
# TODO: may need to do subpixel correction for when the dimension has an even length
psf1d = psf1d * 0
psf1d[{dim: len(psf1d.coords[dim]) / 2}] = 1
else:
psf1d = psf1d * gaussian(psf1d.coords[dim], np.mean(psf1d.coords[dim]), sigmas[dim])

psf = psf * psf1d
return psf
11 changes: 5 additions & 6 deletions src/arpes/analysis/gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def determine_broadened_fermi_distribution(


@update_provenance("Normalize By Fermi Dirac")
def normalize_by_fermi_dirac(
def normalize_by_fermi_dirac( # noqa: PLR0913
data: DataType,
reference_data: DataType | None = None,
broadening: float = 0,
Expand Down Expand Up @@ -142,7 +142,7 @@ def normalize_by_fermi_dirac(
if (not temperature_axis) and "temp" in data.dims:
temperature_axis = "temp"

transpose_order = list(data.dims)
transpose_order: list[str] = [str(dim) for dim in data.dims]
transpose_order.remove("eV")

if temperature_axis:
Expand Down Expand Up @@ -190,8 +190,7 @@ def _shift_energy_interpolate(
data: xr.DataArray,
shift: xr.DataArray | None = None,
) -> xr.DataArray:
if not isinstance(data, xr.DataArray):
data = normalize_to_spectrum(data)
data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data_arr = data.S.transpose_to_front("eV")

new_data = data_arr.copy(deep=True)
Expand All @@ -211,6 +210,7 @@ def _shift_energy_interpolate(
shift = shift - stride * n_strides

new_axis = new_axis + shift
assert shift is not None

weight = float(shift / stride)
new_values = new_values + data_arr.values * (1 - weight)
Expand Down Expand Up @@ -249,8 +249,7 @@ def symmetrize(
Returns:
The symmetrized data.
"""
if not isinstance(data, xr.DataArray):
data = normalize_to_spectrum(data)
data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
data = data.S.transpose_to_front("eV")

if subpixel or full_spectrum:
Expand Down
3 changes: 2 additions & 1 deletion src/arpes/analysis/path.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains routines used to do path selections and manipulations on a dataset."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -54,7 +55,7 @@ def as_vec(ds: xr.Dataset) -> NDArray[np.float_]:
return np.array([ds[k].item() for k in order])

def distance(a: xr.Dataset, b: xr.Dataset) -> float:
return np.linalg.norm((as_vec(a) - as_vec(b)) * scaling)
return float(np.linalg.norm((as_vec(a) - as_vec(b)) * scaling))

length = 0
for idx_low, idx_high in zip(path.index.values, path.index[1:].values, strict=False):
Expand Down
Loading

0 comments on commit 3175015

Please sign in to comment.