Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for G.*_meshgrid #60

Merged
merged 24 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/arpes/plotting/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ def plot_data_to_bz2d( # noqa: PLR0913
c, s = np.cos(rotate), np.sin(rotate)
rotation = np.array([(c, -s), (s, c)])

raveled = raveled.G.transform_coords(dims, rotation)
raveled = raveled.G.transform_meshgrid(dims, rotation)

if scale is not None:
raveled = raveled.G.scale_coords(dims, scale)
raveled = raveled.G.scale_meshgrid(dims, scale)

if shift is not None:
raveled = raveled.G.shift_coords(dims, shift)
raveled = raveled.G.shift_meshgrid(dims, shift)

copied = data_array.values.copy()

Expand Down
2 changes: 1 addition & 1 deletion src/arpes/plotting/movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def plot_movie( # noqa: PLR0913
"""
figsize = figsize or (7.0, 7.0)
data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
fig, ax = fig_ax if fig_ax else plt.subplots(figsize=figsize)
fig, ax = fig_ax or plt.subplots(figsize=figsize)

assert isinstance(ax, Axes)
assert isinstance(fig, Figure)
Expand Down
25 changes: 12 additions & 13 deletions src/arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@
from xarray.core.coordinates import DataArrayCoordinates, DatasetCoordinates

import arpes
import arpes.constants
import arpes.utilities.math
from arpes.constants import TWO_DIMENSION

from ._typing import HighSymmetryPoints, MPLPlotKwargs
Expand Down Expand Up @@ -2661,12 +2659,12 @@ def filter_vars(
attrs=self._obj.attrs,
)

def shift_coords(
def shift_meshgrid(
self,
dims: tuple[str, ...],
shift: NDArray[np.float64] | float,
) -> xr.Dataset:
"""Shifts the coordinates and returns a new dataset with the shifted coordinates.
"""Shifts the meshgrid and returns a new dataset with the shifted meshgrid.

Args:
dims (tuple[str, ...]): The list of dimensions whose coordinates will be shifted.
Expand All @@ -2683,7 +2681,7 @@ def shift_coords(
- Add tests.
"""
if not isinstance(shift, np.ndarray):
shift = np.ones((len(dims),)) * shift
shift: NDArray[np.float64] = np.ones((len(dims),)) * shift

def transform(data: NDArray[np.float64]) -> NDArray[np.float64]:
new_shift: NDArray[np.float64] = shift
Expand All @@ -2692,14 +2690,14 @@ def transform(data: NDArray[np.float64]) -> NDArray[np.float64]:

return data + new_shift

return self.transform_coords(dims, transform)
return self.transform_meshgrid(dims, transform)

def scale_coords(
def scale_meshgrid(
self,
dims: tuple[str, ...],
scale: float | NDArray[np.float64],
) -> xr.Dataset:
"""Scales the coordinates and returns a new dataset with the scaled coordinates.
"""Scales the meshgrid and returns a new dataset with the scaled meshgrid.

Args:
dims (tuple[str, ...]): The list of dimensions whose coordinates will be scaled.
Expand All @@ -2718,14 +2716,17 @@ def scale_coords(
elif len(scale.shape) == 1:
scale = np.diag(scale)

return self.transform_coords(dims, scale)
return self.transform_meshgrid(dims, scale)

def transform_coords(
def transform_meshgrid(
self,
dims: Collection[str],
transform: NDArray[np.float64] | Callable,
) -> xr.Dataset:
"""Transforms the given coordinate values according to an arbitrary function.
"""Transforms the given coordinate values in **meshgrid** by an arbitrary function.

This method is applicable to a specific Dataset (assuming the return value of G.meshgrid)
and is not very versatile.

The transformation should either be a function from a len(dims) x size of raveled coordinate
array to len(dims) x size of raveled_coordinate array or a linear transformation as a matrix
Expand All @@ -2738,8 +2739,6 @@ def transform_coords(
Returns:
An identical valued array over new coordinates.

Todo:
Test
"""
assert isinstance(self._obj, xr.Dataset)
as_ndarray = np.stack([self._obj.data_vars[d].values for d in dims], axis=-1)
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import arpes.config
import arpes.endstations
from arpes.example_data.mock import build_mock_tarpes
from arpes.io import example_data
from tests.utils import cache_loader

Expand Down
50 changes: 49 additions & 1 deletion tests/test_xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_G_shift(
)

def test_G_meshgrid(self, dataarray_cut: xr.DataArray) -> None:
"""Test for G.meshgrid."""
"""Test for G.meshgrid, G.scale_meshgrid, G.shift_meshgrid."""
small_region = dataarray_cut.sel({"eV": slice(-0.01, 0.0), "phi": slice(0.40, 0.42)})
meshgrid_results = small_region.G.meshgrid()
np.testing.assert_allclose(
Expand Down Expand Up @@ -466,6 +466,54 @@ def test_G_ravel(self, dataarray_cut: xr.DataArray) -> None:
np.testing.assert_allclose(ravel_["data"], np.array([467, 472, 464, 458, 438]))


class TestGeneralforDataset:
"""Test class for GenericDatasetAccessor."""

def test_G_meshgrid_operation(self, dataarray_cut: xr.DataArray):
"""Test G.scale_meshgrid and G.shift_meshgrid, and transform_meshgrid."""
small_region = dataarray_cut.sel({"eV": slice(-0.01, 0.0), "phi": slice(0.40, 0.42)})
meshgrid_set = small_region.G.meshgrid(as_dataset=True)
shifted_meshgrid = meshgrid_set.G.shift_meshgrid(("phi",), -0.2)
np.testing.assert_allclose(
shifted_meshgrid["phi"][1].values,
np.array(
[
0.20142573,
0.20317106,
0.20491639,
0.20666172,
0.20840704,
0.21015237,
0.2118977,
0.21364303,
0.21538836,
0.21713369,
0.21887902,
],
),
)

scaled_meshgrid = meshgrid_set.G.scale_meshgrid(("eV",), 1.5)
np.testing.assert_allclose(
scaled_meshgrid["eV"][-1].values,
np.array(
[
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
-1.155e-07,
],
),
)


class TestAngleUnitforDataArray:
"""Test class for angle_unit for DataArray."""

Expand Down
Loading