Skip to content

Commit

Permalink
Merge pull request diffpy#214 from bobleesj/test-func-std
Browse files Browse the repository at this point in the history
Remove unused test util functions for comparing two dicts, use `__eq__` to compare DiffractionObjects
  • Loading branch information
sbillinge authored Dec 12, 2024
2 parents 3efd93b + 73b2f46 commit fce1a32
Showing 1 changed file with 3 additions and 45 deletions.
48 changes: 3 additions & 45 deletions tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,6 @@

from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject


def compare_dicts(dict1, dict2):
assert dict1.keys() == dict2.keys(), "Keys mismatch"
for key in dict1:
val1, val2 = dict1[key], dict2[key]
if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray):
assert np.allclose(val1, val2), f"Arrays for key '{key}' differ"
elif isinstance(val1, np.float64) and isinstance(val2, np.float64):
assert np.isclose(val1, val2), f"Float64 values for key '{key}' differ"
else:
assert val1 == val2, f"Values for key '{key}' differ: {val1} != {val2}"


def dicts_equal(dict1, dict2):
equal = True
print("")
print(dict1)
print(dict2)
if not dict1.keys() == dict2.keys():
equal = False
for key in dict1:
val1, val2 = dict1[key], dict2[key]
if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray):
if not np.allclose(val1, val2):
equal = False
elif isinstance(val1, list) and isinstance(val2, list):
if not val1.all() == val2.all():
equal = False
elif isinstance(val1, np.float64) and isinstance(val2, np.float64):
if not np.isclose(val1, val2):
equal = False
else:
if not val1 == val2:
equal = False
return equal


params = [
( # Default
{},
Expand Down Expand Up @@ -193,14 +156,9 @@ def dicts_equal(dict1, dict2):

@pytest.mark.parametrize("inputs1, inputs2, expected", params)
def test_diffraction_objects_equality(inputs1, inputs2, expected):
diffraction_object1 = DiffractionObject(**inputs1)
diffraction_object2 = DiffractionObject(**inputs2)
# diffraction_object1_attributes = [key for key in diffraction_object1.__dict__ if not key.startswith("_")]
# for i, attribute in enumerate(diffraction_object1_attributes):
# setattr(diffraction_object1, attribute, inputs1[i])
# setattr(diffraction_object2, attribute, inputs2[i])
print(dicts_equal(diffraction_object1.__dict__, diffraction_object2.__dict__), expected)
assert dicts_equal(diffraction_object1.__dict__, diffraction_object2.__dict__) == expected
do_1 = DiffractionObject(**inputs1)
do_2 = DiffractionObject(**inputs2)
assert (do_1 == do_2) == expected


def test_on_xtype():
Expand Down

0 comments on commit fce1a32

Please sign in to comment.