Skip to content

Commit

Permalink
Merge pull request diffpy#191 from yucongalicechen/angle-test
Browse files Browse the repository at this point in the history
test function for get_angle_index
  • Loading branch information
sbillinge authored Dec 7, 2024
2 parents e6827a6 + 3dac06a commit 3aef608
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
23 changes: 23 additions & 0 deletions news/array_index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* function to return the index of the closest value to the specified value in an array.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
32 changes: 23 additions & 9 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,29 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
array = np.linspace(begin, end, n_steps)
return array

def get_angle_index(self, angle):
count = 0
for i, target in enumerate(self.angles):
if angle == target:
return i
else:
count += 1
if count >= len(self.angles):
raise IndexError(f"WARNING: no angle {angle} found in angles list")
def get_array_index(self, value, xtype=None):
"""
returns the index of the closest value in the array associated with the specified xtype
Parameters
----------
xtype str
the xtype used to access the array
value float
the target value to search for
Returns
-------
the index of the value in the array
"""

if xtype is None:
xtype = self.input_xtype
array = self.on_xtype(xtype)[0]
if len(array) == 0:
raise ValueError(f"The '{xtype}' array is empty. Please ensure it is initialized.")
i = (np.abs(array - value)).argmin()
return i

def _set_xarrays(self, xarray, xtype):
self._all_arrays = np.empty(shape=(len(xarray), 4))
Expand Down
26 changes: 26 additions & 0 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -212,6 +213,31 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
return np.allclose(actual_array, expected_array)


params_index = [
# UC1: exact match
([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]),
# UC2: target value lies in the array, returns the (first) closest index
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 45], [0]),
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "q", 0.25], [0]),
# UC3: target value out of the range, returns the closest index
([4 * np.pi, np.array([0.25, 0.5, 0.71]), np.array([1, 2, 3]), "q", "q", 0.1], [0]),
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 63], [1]),
]


@pytest.mark.parametrize("inputs, expected", params_index)
def test_get_array_index(inputs, expected):
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
actual = test.get_array_index(value=inputs[5], xtype=inputs[4])
assert actual == expected[0]


def test_get_array_index_bad():
test = DiffractionObject(wavelength=2 * np.pi, xarray=np.array([]), yarray=np.array([]), xtype="tth")
with pytest.raises(ValueError, match=re.escape("The 'tth' array is empty. Please ensure it is initialized.")):
test.get_array_index(value=30)


def test_dump(tmp_path, mocker):
x, y = np.linspace(0, 5, 6), np.linspace(0, 5, 6)
directory = Path(tmp_path)
Expand Down

0 comments on commit 3aef608

Please sign in to comment.