Skip to content

Commit

Permalink
Explicitly call a approx function
Browse files Browse the repository at this point in the history
Revert "Expose tolerance for float comparison to the user"

This reverts commit d7bda60.

Revert "Add tolerance to ExpectedOutput for floating point values"

This reverts commit 5781c90.
  • Loading branch information
ianmnz committed Feb 21, 2024
1 parent d7bda60 commit bc82d7b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 20 deletions.
70 changes: 54 additions & 16 deletions src/andromede/simulation/output_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
import math
from dataclasses import dataclass, field
from typing import Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast

from andromede.simulation.optimization import SolverAndContext
from andromede.study.data import TimeScenarioIndex
Expand All @@ -41,14 +41,26 @@ class Variable:
_name: str
_value: Dict[TimeScenarioIndex, float] = field(init=False, default_factory=dict)
_size: Tuple[int, int] = field(init=False, default=(0, 0))

ignore: bool = field(default=False, init=False)
rel_tol: float = field(default=1.0e-9, init=False)
abs_tol: float = field(default=0.0, init=False)

def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues.Variable):
return NotImplemented
return (self.ignore or other.ignore) or (
self._name == other._name
and self._size == other._size
and self._value == other._value
)

def is_close(
self,
other: "OutputValues.Variable",
*,
rel_tol: float = 1.0e-9,
abs_tol: float = 0.0,
) -> bool:
# From the docs in https://docs.python.org/3/library/math.html#math.isclose
# math.isclose(a, b) returns abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
return (self.ignore or other.ignore) or (
self._name == other._name
and self._size == other._size
Expand All @@ -57,8 +69,8 @@ def __eq__(self, other: object) -> bool:
math.isclose(
self._value[key],
other._value[key],
rel_tol=self.rel_tol,
abs_tol=self.abs_tol,
rel_tol=rel_tol,
abs_tol=abs_tol,
)
for key in self._value
)
Expand Down Expand Up @@ -129,9 +141,20 @@ class Component:
def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues.Component):
return NotImplemented
return self.is_close(other, rel_tol=0.0, abs_tol=0.0)

def is_close(
self,
other: "OutputValues.Component",
*,
rel_tol: float = 1.0e-9,
abs_tol: float = 0.0,
) -> bool:
return (self.ignore or other.ignore) or (
self._id == other._id
and _are_mappings_equal(self._variables, other._variables)
and _are_mappings_close(
self._variables, other._variables, rel_tol, abs_tol
)
)

def __str__(self) -> str:
Expand All @@ -156,7 +179,14 @@ def __post_init__(self) -> None:
def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues):
return NotImplemented
return _are_mappings_equal(self._components, other._components)
return _are_mappings_close(self._components, other._components, 0.0, 0.0)

def is_close(
self, other: "OutputValues", *, rel_tol: float = 1.0e-9, abs_tol: float = 0.0
) -> bool:
return _are_mappings_close(
self._components, other._components, rel_tol, abs_tol
)

def __str__(self) -> str:
string = "\n"
Expand Down Expand Up @@ -185,9 +215,14 @@ def component(self, component_id: str) -> "OutputValues.Component":
return self._components[component_id]


def _are_mappings_equal(
lhs: Mapping[str, Union[OutputValues.Component, OutputValues.Variable]],
rhs: Mapping[str, Union[OutputValues.Component, OutputValues.Variable]],
Comparable = TypeVar("Comparable", OutputValues.Component, OutputValues.Variable)


def _are_mappings_close(
lhs: Mapping[str, Comparable],
rhs: Mapping[str, Comparable],
rel_tol: float,
abs_tol: float,
) -> bool:
lhs_keys = lhs.keys()
rhs_keys = rhs.keys()
Expand All @@ -202,10 +237,13 @@ def _are_mappings_equal(
):
return False

elif (intersect_keys := lhs_keys & rhs_keys) and any(
lhs[key] != rhs[key] for key in intersect_keys
):
return False

elif intersect_keys := lhs_keys & rhs_keys:
if rel_tol == abs_tol == 0.0:
return all(lhs[key] == rhs[key] for key in intersect_keys)
else:
return all(
lhs[key].is_close(rhs[key], rel_tol=rel_tol, abs_tol=abs_tol)
for key in intersect_keys
)
else:
return True
11 changes: 7 additions & 4 deletions tests/andromede/test_output_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,21 @@ def test_component_and_flow_output_object() -> None:
"component_approx_var_name"
).value = 1.000_000_001

assert (
output != test_output
assert output != test_output and not output.is_close(
test_output
), f"Output is equal to expected outside tolerance: {output}"

test_output.component("component_id_test").var(
"component_approx_var_name"
).value = 1.000_000_000_1

assert (
output == test_output
assert output != test_output and output.is_close(
test_output
), f"Output differs from the expected inside tolerance: {output}"

test_output.component("component_id_test").var(
"component_approx_var_name"
).ignore = True
test_output.component("component_id_test").var(
"wrong_component_var_name"
).value = 1.0
Expand Down

0 comments on commit bc82d7b

Please sign in to comment.