Skip to content

Commit

Permalink
Fix equality check for floats (#2507)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2507

`same_elements()` wasn't working for `float('nan')`, and it wasn't treating floats with `np.is_close()` like in `object_attribute_dicts_find_unequal_fields()`.

Also, `same_elements` was generally broken.  Example:
{F1676730235}

Reviewed By: saitcakmak

Differential Revision: D58289519

fbshipit-source-id: 6fd2a253968763a8a52956c87cd2a97975a9755f
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jun 7, 2024
1 parent 40ae984 commit 932746d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
66 changes: 33 additions & 33 deletions ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,47 @@ def same_elements(list1: List[Any], list2: List[Any]) -> bool:
-- The lists do not contain duplicates
Checking equality is then the same as checking that the lists are the same
length, and that one is a subset of the other.
length, and that both are subsets of the other.
"""

if len(list1) != len(list2):
return False

matched = [False for _ in list2]
for item1 in list1:
found = False
for item2 in list2:
if isinstance(item1, np.ndarray) or isinstance(item2, np.ndarray):
if (
isinstance(item1, np.ndarray)
and isinstance(item2, np.ndarray)
and np.array_equal(item1, item2)
):
found = True
break
elif item1 == item2:
found = True
matched_this_item = False
for i, item2 in enumerate(list2):
if not matched[i] and is_ax_equal(item1, item2):
matched[i] = True
matched_this_item = True
break
if not found:
if not matched_this_item:
return False
return all(matched)

return True

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def is_ax_equal(one_val: Any, other_val: Any) -> bool:
"""Check for equality of two values, handling lists, dicts, dfs, floats,
dates, and numpy arrays. This method and ``same_elements`` function
as a recursive unit.
"""
if isinstance(one_val, list) and isinstance(other_val, list):
return same_elements(one_val, other_val)
elif isinstance(one_val, dict) and isinstance(other_val, dict):
return sorted(one_val.keys()) == sorted(other_val.keys()) and same_elements(
list(one_val.values()), list(other_val.values())
)
elif isinstance(one_val, np.ndarray) and isinstance(other_val, np.ndarray):
return np.array_equal(one_val, other_val, equal_nan=True)
elif isinstance(one_val, datetime):
return datetime_equals(one_val, other_val)
elif isinstance(one_val, float) and isinstance(other_val, float):
return np.isclose(one_val, other_val, equal_nan=True)
elif isinstance(one_val, pd.DataFrame) and isinstance(other_val, pd.DataFrame):
return dataframe_equals(one_val, other_val)
else:
return one_val == other_val


def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
Expand Down Expand Up @@ -198,25 +215,8 @@ def object_attribute_dicts_find_unequal_fields(
and isinstance(one_val.model, type(other_val.model))
)

elif isinstance(one_val, list):
equal = isinstance(other_val, list) and same_elements(one_val, other_val)
elif isinstance(one_val, dict):
equal = isinstance(other_val, dict) and sorted(one_val.keys()) == sorted(
other_val.keys()
)
equal = equal and same_elements(
list(one_val.values()), list(other_val.values())
)
elif isinstance(one_val, np.ndarray):
equal = np.array_equal(one_val, other_val, equal_nan=True)
elif isinstance(one_val, datetime):
equal = datetime_equals(one_val, other_val)
elif isinstance(one_val, float):
equal = np.isclose(one_val, other_val)
elif isinstance(one_val, pd.DataFrame):
equal = dataframe_equals(one_val, other_val)
else:
equal = one_val == other_val
equal = is_ax_equal(one_val, other_val)

if not equal:
unequal_value[field] = (one_val, other_val)
Expand Down
7 changes: 7 additions & 0 deletions ax/utils/common/tests/test_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@ def eq(x, y):
def test_ListsEquals(self) -> None:
self.assertFalse(same_elements([0], [0, 1]))
self.assertFalse(same_elements([1, 0], [0, 2]))
self.assertFalse(same_elements([1, 1], [1, 2]))
self.assertFalse(same_elements([1, 2], [1, 1]))
self.assertFalse(same_elements([1, 1, 2], [1, 2, 2]))
self.assertTrue(same_elements([1, 0], [0, 1]))

def test_ListsEquals_floats(self) -> None:
self.assertTrue(same_elements([0.0], [0.000000000000001]))
self.assertTrue(same_elements([float("nan")], [float("nan")]))

def test_DatetimeEquals(self) -> None:
now = datetime.now()
self.assertTrue(datetime_equals(None, None))
Expand Down

0 comments on commit 932746d

Please sign in to comment.