Skip to content

Commit

Permalink
💬 Update type hints & add test for "F".
Browse files Browse the repository at this point in the history
  • Loading branch information
arafune committed Feb 2, 2024
1 parent a1c6585 commit fa846c5
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 27 deletions.
20 changes: 11 additions & 9 deletions arpes/preparation/coord_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def disambiguate_coordinates(
and so refers to a different energy range.
"""
coords_set = collections.defaultdict(list)
for d in datasets:
assert isinstance(d, xr.DataArray)
for spectrum in datasets:
assert isinstance(spectrum, xr.DataArray)
for c in possibly_clashing_coordinates:
if c in d.coords:
coords_set[c].append(d.coords[c])
if c in spectrum.coords:
coords_set[c].append(spectrum.coords[c])

conflicted = []
for c in possibly_clashing_coordinates:
Expand All @@ -46,10 +46,12 @@ def disambiguate_coordinates(
conflicted.append(c)

after_deconflict = []
for d in datasets:
assert isinstance(d, xr.DataArray)
spectrum_name = next(iter(d.data_vars.keys()))
to_rename = {name: str(name) + "-" + spectrum_name for name in d.dims if name in conflicted}
after_deconflict.append(d.rename(to_rename))
for spectrum in datasets:
assert isinstance(spectrum, xr.DataArray)
spectrum_name = next(iter(spectrum.data_vars.keys()))
to_rename = {
name: str(name) + "-" + spectrum_name for name in spectrum.dims if name in conflicted
}
after_deconflict.append(spectrum.rename(to_rename))

return after_deconflict
7 changes: 4 additions & 3 deletions arpes/preparation/hemisphere_preparation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data prep routines for hemisphere data."""

from __future__ import annotations

from itertools import pairwise
Expand Down Expand Up @@ -40,8 +41,8 @@ def stitch_maps(
for i, (lower, higher) in enumerate(pairwise(coord1)):
if higher > first_repair_coordinate:
break
assert isinstance(i, int)
delta_low, delta_high = lower - first_repair_coordinate, higher - first_repair_coordinate
assert isinstance(i, int)
delta_low, delta_high = lower - first_repair_coordinate, higher - first_repair_coordinate
if abs(delta_low) < abs(delta_high):
delta = delta_low
else:
Expand All @@ -55,7 +56,7 @@ def stitch_maps(
good_data_slice = {}
good_data_slice[dimension] = slice(None, i)

selected = arr.isel(**good_data_slice)
selected = arr.isel(good_data_slice)
selected.attrs.clear()
shifted_repair_map.attrs.clear()
concatted = xr.concat([selected, shifted_repair_map], dim=dimension)
Expand Down
6 changes: 4 additions & 2 deletions arpes/preparation/tof_preparation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data prep routines for time-of-flight data."""

from __future__ import annotations # noqa: I001

import math
Expand All @@ -14,6 +15,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from numpy.typing import NDArray

__all__ = [
"build_KE_coords_to_time_pixel_coords",
"build_KE_coords_to_time_coords",
Expand Down Expand Up @@ -150,8 +153,7 @@ def KE_coords_to_time_pixel_coords(


def build_KE_coords_to_time_coords(
dataset: xr.Dataset,
interpolation_axis: Sequence[float],
dataset: xr.Dataset, interpolation_axis: NDArray[np.float_]
) -> Callable[..., tuple[xr.DataArray]]:
"""Constructs a coordinate conversion function from kinetic energy to time coords.
Expand Down
3 changes: 3 additions & 0 deletions arpes/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class PROVENANCE(TypedDict, total=False):
correction: list[NDArray[np.float_]] # fermi_edge_correction
#
dims: Sequence[str]
#
old_axis: str
new_axis: str


def attach_id(data: xr.DataArray | xr.Dataset) -> None:
Expand Down
11 changes: 0 additions & 11 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,10 +2127,6 @@ def switch_energy_notation(self, nonlinear_order: int = 1) -> None:
elif self.hv is not None and self.energy_notation == "Kinetic":
self._obj.coords["eV"] = self._obj.coords["eV"] - nonlinear_order * self.hv
self._obj.attrs["energy_notation"] = "Binding"
else:
msg = "Cannot determine the current enegy notation.\n"
msg += "You should set attrs['energy_notation'] = 'Kinetic' or 'Binding'"
raise RuntimeError(msg)

def corrected_angle_by(
self,
Expand Down Expand Up @@ -3374,7 +3370,6 @@ def spectrum(self) -> xr.DataArray:
ToDo: Need test
"""
# spectrum = None <== CHECK ME!
if "spectrum" in self._obj.data_vars:
spectrum = self._obj.spectrum
elif "raw" in self._obj.data_vars:
Expand All @@ -3394,8 +3389,6 @@ def spectrum(self) -> xr.DataArray:
else:
msg = "No spectrum found"
raise RuntimeError(msg)
if spectrum is not None and "df" not in spectrum.attrs:
spectrum.attrs["df"] = self._obj.attrs.get("df", None)
return spectrum

@property
Expand Down Expand Up @@ -3588,10 +3581,6 @@ def switch_energy_notation(self, nonlinear_order: int = 1) -> None:
self._obj.attrs["energy_notation"] = "Binding"
for spectrum in self._obj.data_vars.values():
spectrum.attrs["energy_notation"] = "Binding"
else:
msg = "Cannot determine the current enegy notation.\n"
msg += "You should set attrs['energy_notation'] = 'Kinetic' or 'Binding'"
raise RuntimeError(msg)

@property
def angle_unit(self) -> Literal["Degrees", "Radians"]:
Expand Down
14 changes: 13 additions & 1 deletion tests/test_curve_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ def test_broadcast_fitting() -> None:
near_ef = cut.isel(phi=slice(80, 120)).sel(eV=slice(-0.2, 0.1))
near_ef = rebin(near_ef, phi=5)

fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi")
fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi", progress=False)

assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00506558) < TOLERANCE

fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi", progress=True)
assert fit_results.results.F.parameter_names == {
"a_const_bkg",
"a_conv_width",
"a_fd_center",
"a_fd_width",
"a_lin_bkg",
"a_offset",
}
assert fit_results.F.broadcast_dimensions == ["phi"]
assert fit_results.F.fit_dimensions == ["eV"]
18 changes: 17 additions & 1 deletion tests/test_xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,23 @@ def dataarray_cut() -> xr.DataArray:
return example_data.cut.spectrum


@pytest.fixture()
def xps_map() -> xr.Dataset:
"""A fixture for loading example_data.xps."""
return example_data.nano_xps


class TestforProperties:
"""Test class for Array Dataset properties."""

def test_degrees_of_freedom_dims(self, xps_map: xr.Dataset) -> None:
"""Test for degrees_of_freedom."""
assert xps_map.S.spectrum_degrees_of_freedom == {"eV"}
assert xps_map.S.scan_degrees_of_freedom == {"x", "y"}

def test_is_functions(self, xps_map: xr.Dataset) -> None:
assert xps_map.S.is_spatial

def test_find_spectrum_energy_edges(self, dataarray_cut: xr.DataArray) -> None:
"""Test for find_spectrum_energy_edges."""
np.testing.assert_array_almost_equal(
Expand Down Expand Up @@ -218,11 +232,13 @@ def test_switch_energy_notation(
dataset_cut: xr.Dataset,
) -> None:
"""Test for switch energy notation."""
# Test for DataArray
dataarray_cut.S.switch_energy_notation()
assert dataarray_cut.S.energy_notation == "Kinetic"
dataarray_cut.S.switch_energy_notation()
assert dataarray_cut.S.energy_notation == "Binding"
#

# Test for Dataset
dataset_cut.S.switch_energy_notation()
assert dataset_cut.S.energy_notation == "Kinetic"
dataset_cut.S.switch_energy_notation()
Expand Down

0 comments on commit fa846c5

Please sign in to comment.