Skip to content

Commit

Permalink
Merge branch 'daredevil' of github.com:arafune/arpes into daredevil
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 9, 2024
2 parents 88b3c1d + d9c9af7 commit 4d59a65
Show file tree
Hide file tree
Showing 56 changed files with 467 additions and 359 deletions.
2 changes: 1 addition & 1 deletion arpes/analysis/band_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


def fit_for_effective_mass(
data: XrTypes,
data: xr.DataArray,
fit_kwargs: dict | None = None,
) -> float:
"""Fits for the effective mass in a piece of data.
Expand Down
8 changes: 4 additions & 4 deletions arpes/analysis/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import xarray as xr
from _typeshed import Incomplete

from arpes._typing import DataType
__all__ = (
"nmf_along",
"pca_along",
Expand All @@ -24,7 +23,7 @@


def decomposition_along(
data: DataType,
data: xr.DataArray,
axes: list[str],
decomposition_cls: type[sklearn.decomposition],
*,
Expand Down Expand Up @@ -69,11 +68,12 @@ def decomposition_along(
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
if len(axes) > 1:
flattened_data: xr.DataArray = normalize_to_spectrum(data).stack(fit_axis=axes)
flattened_data: xr.DataArray = data.stack(fit_axis=axes)
stacked = True
else:
flattened_data = normalize_to_spectrum(data).S.transpose_to_back(axes[0])
flattened_data = data.S.transpose_to_back(axes[0])
stacked = False

if len(flattened_data.dims) != TWO_DIMENSION:
Expand Down
28 changes: 14 additions & 14 deletions arpes/analysis/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

@update_provenance("Approximate Iterative Deconvolution")
def deconvolve_ice(
data: DataType,
data: xr.DataArray,
psf: NDArray[np.float_],
n_iterations: int = 5,
deg: int | None = None,
Expand All @@ -55,8 +55,8 @@ def deconvolve_ice(
Returns:
The deconvoled data in the same format.
"""
arr = normalize_to_spectrum(data).values

data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data).values
arr = data.values
if deg is None:
deg = n_iterations - 3
iteration_steps = list(range(1, n_iterations + 1))
Expand All @@ -73,17 +73,17 @@ def deconvolve_ice(
poly = np.poly1d(coefs)
deconv[t] = poly(0)

if type(data) is np.ndarray:
if isinstance(data, np.ndarray):
result = deconv
else:
result = normalize_to_spectrum(data).copy(deep=True)
result = data.copy(deep=True)
result.values = deconv
return result


@update_provenance("Lucy Richardson Deconvolution")
def deconvolve_rl(
data: DataType,
data: xr.DataArray,
psf: xr.DataArray | None = None,
n_iterations: int = 10,
axis: str = "",
Expand All @@ -106,7 +106,7 @@ def deconvolve_rl(
Returns:
The Richardson-Lucy deconvolved data.
"""
arr = normalize_to_spectrum(data)
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)

if psf is None and axis != "" and sigma != 0:
# if no psf is provided and we have the information to make a 1d one
Expand Down Expand Up @@ -233,25 +233,25 @@ def wrap_progress(

result = u[-1]
else: # data.dims == 1
if type(arr) is not np.ndarray:
if not isinstance(arr, np.ndarray):
arr = arr.values
u = [arr]
for _ in range(n_iterations):
c = scipy.ndimage.convolve(u[-1], psf, mode=mode)
u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, 0), mode=mode))
# not yet tested to ensure flip correct for asymmetric psf
# note: need to explicitly specify axis number in np.flip in lower versions of numpy
if type(data) is np.ndarray:
if isinstance(data, np.ndarray):
result = u[-1].copy()
else:
result = normalize_to_spectrum(data).copy(deep=True)
result = data.copy(deep=True)
result.values = u[-1]
with contextlib.suppress(Exception):
return result.transpose(*arr.dims)


@update_provenance("Make 1D-Point Spread Function")
def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray:
def make_psf1d(data: xr.DataArray, dim: str, sigma: float) -> xr.DataArray:
"""Produces a 1-dimensional gaussian point spread function for use in deconvolve_rl.
Args:
Expand All @@ -262,7 +262,7 @@ def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray:
Returns:
A one dimensional point spread array.
"""
arr = normalize_to_spectrum(data)
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
psf = arr.copy(deep=True) * 0 + 1
other_dims = list(arr.dims)
other_dims.remove(dim)
Expand All @@ -272,7 +272,7 @@ def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray:


@update_provenance("Make Point Spread Function")
def make_psf(data: DataType, sigmas: dict[str, float]) -> xr.DataArray:
def make_psf(data: xr.DataArray, sigmas: dict[str, float]) -> xr.DataArray:
"""Produces an n-dimensional gaussian point spread function for use in deconvolve_rl.
Not yet operational.
Expand All @@ -286,7 +286,7 @@ def make_psf(data: DataType, sigmas: dict[str, float]) -> xr.DataArray:
"""
raise NotImplementedError

arr = normalize_to_spectrum(data)
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
dims = arr.dims

psf = arr.copy(deep=True) * 0 + 1
Expand Down
9 changes: 4 additions & 5 deletions arpes/analysis/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = (
"curvature2d",
Expand Down Expand Up @@ -81,7 +80,7 @@ def _vector_diff(

@update_provenance("Minimum Gradient")
def minimum_gradient(
data: DataType,
data: xr.DataArray,
*,
smooth_fn: Callable[[xr.DataArray], xr.DataArray] | None = None,
delta: DELTA = 1,
Expand All @@ -99,15 +98,15 @@ def warpped_filter(arr: xr.DataArray):
Returns:
The gradient of the original intensity, which enhances the peak position.
"""
arr = normalize_to_spectrum(data)
arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
assert isinstance(arr, xr.DataArray)
smooth_ = _nothing_to_array if smooth_fn is None else smooth_fn
arr = smooth_(arr)
return arr / _gradient_modulus(arr, delta=delta)


@update_provenance("Gradient Modulus")
def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray:
def _gradient_modulus(data: xr.DataArray, *, delta: DELTA = 1) -> xr.DataArray:
"""Helper function for minimum gradient.
Args:
Expand All @@ -117,7 +116,7 @@ def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray:
Returns: xr.DataArray
[TODO:description]
"""
spectrum = normalize_to_spectrum(data)
spectrum = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
assert isinstance(spectrum, xr.DataArray)
values: NDArray[np.float_] = spectrum.values
gradient_vector = np.zeros(shape=(8, *values.shape))
Expand Down
18 changes: 13 additions & 5 deletions arpes/analysis/gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def determine_broadened_fermi_distribution(
"vary": False,
}

reference_data_array = normalize_to_spectrum(reference_data)
reference_data_array = (
reference_data
if isinstance(reference_data, xr.DataArray)
else normalize_to_spectrum(reference_data)
)

sum_dims = list(reference_data_array.dims)
sum_dims.remove("eV")
Expand Down Expand Up @@ -183,10 +187,12 @@ def normalize_by_fermi_dirac(


def _shift_energy_interpolate(
data: DataType,
data: xr.DataArray,
shift: xr.DataArray | None = None,
) -> xr.DataArray:
data_arr = normalize_to_spectrum(data).S.transpose_to_front("eV")
if not isinstance(data, xr.DataArray):
data = normalize_to_spectrum(data)
data_arr = data.S.transpose_to_front("eV")

new_data = data_arr.copy(deep=True)
new_axis = new_data.coords["eV"]
Expand Down Expand Up @@ -221,7 +227,7 @@ def _shift_energy_interpolate(

@update_provenance("Symmetrize")
def symmetrize(
data: DataType,
data: xr.DataArray,
*,
subpixel: bool = False,
full_spectrum: bool = False,
Expand All @@ -243,7 +249,9 @@ def symmetrize(
Returns:
The symmetrized data.
"""
data = normalize_to_spectrum(data).S.transpose_to_front("eV")
if not isinstance(data, xr.DataArray):
data = normalize_to_spectrum(data)
data = data.S.transpose_to_front("eV")

if subpixel or full_spectrum:
data = _shift_energy_interpolate(data)
Expand Down
12 changes: 6 additions & 6 deletions arpes/analysis/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .filters import gaussian_filter_arr

if TYPE_CHECKING:
from arpes._typing import DataType
from arpes._typing import DataType, XrTypes

__all__ = (
"normalize_by_fermi_distribution",
Expand All @@ -32,7 +32,7 @@

@update_provenance("Fit Fermi Edge")
def fit_fermi_edge(
data: DataType,
data: XrTypes,
energy_range: slice | None = None,
) -> xr.Dataset:
"""Fits a Fermi edge.
Expand All @@ -59,7 +59,7 @@ def fit_fermi_edge(

@update_provenance("Normalized by the 1/Fermi Dirac Distribution at sample temp")
def normalize_by_fermi_distribution(
data: DataType,
data: xr.DataArray,
max_gain: float = 0,
rigid_shift: float = 0,
instrumental_broadening: float = 0,
Expand All @@ -86,7 +86,7 @@ def normalize_by_fermi_distribution(
Returns:
Normalized DataArray
"""
data_array = normalize_to_spectrum(data)
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
if not total_broadening:
distrib = fermi_distribution(
data_array.coords["eV"].values - rigid_shift,
Expand All @@ -113,7 +113,7 @@ def normalize_by_fermi_distribution(

@update_provenance("Symmetrize about axis")
def symmetrize_axis(
data: DataType,
data: XrTypes,
axis_name: str,
flip_axes: list[str] | None = None,
) -> xr.DataArray:
Expand Down Expand Up @@ -153,7 +153,7 @@ def symmetrize_axis(


@update_provenance("Condensed array")
def condense(data: xr.DataArray) -> xr.DataArray:
def condense(data: DataType) -> DataType:
"""Clips the data so that only regions where there is substantial weight are included.
In practice this usually means selecting along the ``eV`` axis, although other selections
Expand Down
5 changes: 2 additions & 3 deletions arpes/analysis/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from _typeshed import Incomplete
from numpy.typing import NDArray

from arpes._typing import DataType

__all__ = (
"polys_to_mask",
Expand Down Expand Up @@ -135,7 +134,7 @@ def apply_mask_to_coords(

@update_provenance("Apply boolean mask to data")
def apply_mask(
data: DataType,
data: xr.DataArray,
mask: dict[str, Incomplete],
replace: float = np.nan,
radius=None,
Expand Down Expand Up @@ -169,7 +168,7 @@ def apply_mask(
Returns:
Data with values masked out.
"""
data_array = normalize_to_spectrum(data)
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
fermi = mask.get("fermi")

if isinstance(mask, dict):
Expand Down
15 changes: 8 additions & 7 deletions arpes/analysis/pocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def pocket_parameters(

@update_provenance("Collect EDCs projected at an angle from pocket")
def radial_edcs_along_pocket(
data: XrTypes,
data: xr.DataArray,
angle: float,
radii: tuple[float, float] = (0.0, 5.0),
n_points: int = 0,
Expand Down Expand Up @@ -113,7 +113,7 @@ def radial_edcs_along_pocket(
A 2D array which has an angular coordinate around the pocket center.
"""
inner_radius, outer_radius = radii
data_array = normalize_to_spectrum(data)
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
fermi_surface_dims = list(data_array.dims)

assert "eV" in fermi_surface_dims
Expand Down Expand Up @@ -158,7 +158,7 @@ def radial_edcs_along_pocket(


def curves_along_pocket(
data: XrTypes,
data: xr.DataArray,
n_points: int = 0,
inner_radius: float = 0.0,
outer_radius: float = 5.0,
Expand All @@ -185,7 +185,7 @@ def curves_along_pocket(
A tuple of two lists. The first list contains the slices and the second
the coordinates of each slice around the pocket center.
"""
data_array = normalize_to_spectrum(data)
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
assert isinstance(data_array, xr.DataArray)
fermi_surface_dims = list(data_array.dims)
if "eV" in fermi_surface_dims:
Expand Down Expand Up @@ -237,7 +237,7 @@ def slice_at_angle(theta: float) -> xr.DataArray:


def find_kf_by_mdc(
slice_data: XrTypes,
slice_data: xr.DataArray,
offset: float = 0,
**kwargs: Incomplete,
) -> float:
Expand All @@ -254,8 +254,9 @@ def find_kf_by_mdc(
Returns:
The fitting Fermi momentum.
"""
if isinstance(slice_data, xr.Dataset):
slice_arr = normalize_to_spectrum(slice_data)
slice_arr = (
slice_data if isinstance(slice_data, xr.DataArray) else normalize_to_spectrum(slice_data)
)

assert isinstance(slice_arr, xr.DataArray)

Expand Down
Loading

0 comments on commit 4d59a65

Please sign in to comment.