From fa846c57103f2246c68c1e755775211ee04be933 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 2 Feb 2024 10:04:26 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints=20&=20a?= =?UTF-8?q?dd=20test=20for=20"F".?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/preparation/coord_preparation.py | 20 +++++++++++--------- arpes/preparation/hemisphere_preparation.py | 7 ++++--- arpes/preparation/tof_preparation.py | 6 ++++-- arpes/provenance.py | 3 +++ arpes/xarray_extensions.py | 11 ----------- tests/test_curve_fitting.py | 14 +++++++++++++- tests/test_xarray_extensions.py | 18 +++++++++++++++++- 7 files changed, 52 insertions(+), 27 deletions(-) diff --git a/arpes/preparation/coord_preparation.py b/arpes/preparation/coord_preparation.py index 04adc8c3..ca4d6879 100644 --- a/arpes/preparation/coord_preparation.py +++ b/arpes/preparation/coord_preparation.py @@ -26,11 +26,11 @@ def disambiguate_coordinates( and so refers to a different energy range. """ coords_set = collections.defaultdict(list) - for d in datasets: - assert isinstance(d, xr.DataArray) + for spectrum in datasets: + assert isinstance(spectrum, xr.DataArray) for c in possibly_clashing_coordinates: - if c in d.coords: - coords_set[c].append(d.coords[c]) + if c in spectrum.coords: + coords_set[c].append(spectrum.coords[c]) conflicted = [] for c in possibly_clashing_coordinates: @@ -46,10 +46,12 @@ def disambiguate_coordinates( conflicted.append(c) after_deconflict = [] - for d in datasets: - assert isinstance(d, xr.DataArray) - spectrum_name = next(iter(d.data_vars.keys())) - to_rename = {name: str(name) + "-" + spectrum_name for name in d.dims if name in conflicted} - after_deconflict.append(d.rename(to_rename)) + for spectrum in datasets: + assert isinstance(spectrum, xr.DataArray) + spectrum_name = next(iter(spectrum.data_vars.keys())) + to_rename = { + name: str(name) + "-" + spectrum_name for name in spectrum.dims if name in conflicted + } + after_deconflict.append(spectrum.rename(to_rename)) return after_deconflict diff --git a/arpes/preparation/hemisphere_preparation.py b/arpes/preparation/hemisphere_preparation.py index 1efb5f99..1bcbff37 100644 --- a/arpes/preparation/hemisphere_preparation.py +++ b/arpes/preparation/hemisphere_preparation.py @@ -1,4 +1,5 @@ """Data prep routines for hemisphere data.""" + from __future__ import annotations from itertools import pairwise @@ -40,8 +41,8 @@ def stitch_maps( for i, (lower, higher) in enumerate(pairwise(coord1)): if higher > first_repair_coordinate: break - assert isinstance(i, int) - delta_low, delta_high = lower - first_repair_coordinate, higher - first_repair_coordinate + assert isinstance(i, int) + delta_low, delta_high = lower - first_repair_coordinate, higher - first_repair_coordinate if abs(delta_low) < abs(delta_high): delta = delta_low else: @@ -55,7 +56,7 @@ def stitch_maps( good_data_slice = {} good_data_slice[dimension] = slice(None, i) - selected = arr.isel(**good_data_slice) + selected = arr.isel(good_data_slice) selected.attrs.clear() shifted_repair_map.attrs.clear() concatted = xr.concat([selected, shifted_repair_map], dim=dimension) diff --git a/arpes/preparation/tof_preparation.py b/arpes/preparation/tof_preparation.py index 9b513c09..51f7ea3c 100644 --- a/arpes/preparation/tof_preparation.py +++ b/arpes/preparation/tof_preparation.py @@ -1,4 +1,5 @@ """Data prep routines for time-of-flight data.""" + from __future__ import annotations # noqa: I001 import math @@ -14,6 +15,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from numpy.typing import NDArray + __all__ = [ "build_KE_coords_to_time_pixel_coords", "build_KE_coords_to_time_coords", @@ -150,8 +153,7 @@ def KE_coords_to_time_pixel_coords( def build_KE_coords_to_time_coords( - dataset: xr.Dataset, - interpolation_axis: Sequence[float], + dataset: xr.Dataset, interpolation_axis: NDArray[np.float_] ) -> Callable[..., tuple[xr.DataArray]]: """Constructs a coordinate conversion function from kinetic energy to time coords. diff --git a/arpes/provenance.py b/arpes/provenance.py index 0b6194d8..4511a6be 100644 --- a/arpes/provenance.py +++ b/arpes/provenance.py @@ -75,6 +75,9 @@ class PROVENANCE(TypedDict, total=False): correction: list[NDArray[np.float_]] # fermi_edge_correction # dims: Sequence[str] + # + old_axis: str + new_axis: str def attach_id(data: xr.DataArray | xr.Dataset) -> None: diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 86a47369..b316f565 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -2127,10 +2127,6 @@ def switch_energy_notation(self, nonlinear_order: int = 1) -> None: elif self.hv is not None and self.energy_notation == "Kinetic": self._obj.coords["eV"] = self._obj.coords["eV"] - nonlinear_order * self.hv self._obj.attrs["energy_notation"] = "Binding" - else: - msg = "Cannot determine the current enegy notation.\n" - msg += "You should set attrs['energy_notation'] = 'Kinetic' or 'Binding'" - raise RuntimeError(msg) def corrected_angle_by( self, @@ -3374,7 +3370,6 @@ def spectrum(self) -> xr.DataArray: ToDo: Need test """ - # spectrum = None <== CHECK ME! if "spectrum" in self._obj.data_vars: spectrum = self._obj.spectrum elif "raw" in self._obj.data_vars: @@ -3394,8 +3389,6 @@ def spectrum(self) -> xr.DataArray: else: msg = "No spectrum found" raise RuntimeError(msg) - if spectrum is not None and "df" not in spectrum.attrs: - spectrum.attrs["df"] = self._obj.attrs.get("df", None) return spectrum @property @@ -3588,10 +3581,6 @@ def switch_energy_notation(self, nonlinear_order: int = 1) -> None: self._obj.attrs["energy_notation"] = "Binding" for spectrum in self._obj.data_vars.values(): spectrum.attrs["energy_notation"] = "Binding" - else: - msg = "Cannot determine the current enegy notation.\n" - msg += "You should set attrs['energy_notation'] = 'Kinetic' or 'Binding'" - raise RuntimeError(msg) @property def angle_unit(self) -> Literal["Degrees", "Radians"]: diff --git a/tests/test_curve_fitting.py b/tests/test_curve_fitting.py index d9314bc3..df3b9b11 100644 --- a/tests/test_curve_fitting.py +++ b/tests/test_curve_fitting.py @@ -15,6 +15,18 @@ def test_broadcast_fitting() -> None: near_ef = cut.isel(phi=slice(80, 120)).sel(eV=slice(-0.2, 0.1)) near_ef = rebin(near_ef, phi=5) - fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi") + fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi", progress=False) assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00506558) < TOLERANCE + + fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi", progress=True) + assert fit_results.results.F.parameter_names == { + "a_const_bkg", + "a_conv_width", + "a_fd_center", + "a_fd_width", + "a_lin_bkg", + "a_offset", + } + assert fit_results.F.broadcast_dimensions == ["phi"] + assert fit_results.F.fit_dimensions == ["eV"] diff --git a/tests/test_xarray_extensions.py b/tests/test_xarray_extensions.py index 9cd942e3..2eb29f61 100644 --- a/tests/test_xarray_extensions.py +++ b/tests/test_xarray_extensions.py @@ -19,9 +19,23 @@ def dataarray_cut() -> xr.DataArray: return example_data.cut.spectrum +@pytest.fixture() +def xps_map() -> xr.Dataset: + """A fixture for loading example_data.xps.""" + return example_data.nano_xps + + class TestforProperties: """Test class for Array Dataset properties.""" + def test_degrees_of_freedom_dims(self, xps_map: xr.Dataset) -> None: + """Test for degrees_of_freedom.""" + assert xps_map.S.spectrum_degrees_of_freedom == {"eV"} + assert xps_map.S.scan_degrees_of_freedom == {"x", "y"} + + def test_is_functions(self, xps_map: xr.Dataset) -> None: + assert xps_map.S.is_spatial + def test_find_spectrum_energy_edges(self, dataarray_cut: xr.DataArray) -> None: """Test for find_spectrum_energy_edges.""" np.testing.assert_array_almost_equal( @@ -218,11 +232,13 @@ def test_switch_energy_notation( dataset_cut: xr.Dataset, ) -> None: """Test for switch energy notation.""" + # Test for DataArray dataarray_cut.S.switch_energy_notation() assert dataarray_cut.S.energy_notation == "Kinetic" dataarray_cut.S.switch_energy_notation() assert dataarray_cut.S.energy_notation == "Binding" - # + + # Test for Dataset dataset_cut.S.switch_energy_notation() assert dataset_cut.S.energy_notation == "Kinetic" dataset_cut.S.switch_energy_notation()