From bc82d7bca6d1a5d96c82e244971d69ede7308a37 Mon Sep 17 00:00:00 2001 From: Ian Menezes Date: Wed, 21 Feb 2024 13:56:44 +0100 Subject: [PATCH] Explicitly call a approx function Revert "Expose tolerance for float comparison to the user" This reverts commit d7bda60748596a7c2f2908d484e567fc6677c72e. Revert "Add tolerance to ExpectedOutput for floating point values" This reverts commit 5781c90b387aaa9567aa96d4570bd71a84499750. --- src/andromede/simulation/output_values.py | 70 +++++++++++++++++------ tests/andromede/test_output_values.py | 11 ++-- 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/src/andromede/simulation/output_values.py b/src/andromede/simulation/output_values.py index 9e4da2cf..a569cdb2 100644 --- a/src/andromede/simulation/output_values.py +++ b/src/andromede/simulation/output_values.py @@ -15,7 +15,7 @@ """ import math from dataclasses import dataclass, field -from typing import Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast from andromede.simulation.optimization import SolverAndContext from andromede.study.data import TimeScenarioIndex @@ -41,14 +41,26 @@ class Variable: _name: str _value: Dict[TimeScenarioIndex, float] = field(init=False, default_factory=dict) _size: Tuple[int, int] = field(init=False, default=(0, 0)) - ignore: bool = field(default=False, init=False) - rel_tol: float = field(default=1.0e-9, init=False) - abs_tol: float = field(default=0.0, init=False) def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues.Variable): return NotImplemented + return (self.ignore or other.ignore) or ( + self._name == other._name + and self._size == other._size + and self._value == other._value + ) + + def is_close( + self, + other: "OutputValues.Variable", + *, + rel_tol: float = 1.0e-9, + abs_tol: float = 0.0, + ) -> bool: + # From the docs in https://docs.python.org/3/library/math.html#math.isclose + # math.isclose(a, b) returns abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) return (self.ignore or other.ignore) or ( self._name == other._name and self._size == other._size @@ -57,8 +69,8 @@ def __eq__(self, other: object) -> bool: math.isclose( self._value[key], other._value[key], - rel_tol=self.rel_tol, - abs_tol=self.abs_tol, + rel_tol=rel_tol, + abs_tol=abs_tol, ) for key in self._value ) @@ -129,9 +141,20 @@ class Component: def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues.Component): return NotImplemented + return self.is_close(other, rel_tol=0.0, abs_tol=0.0) + + def is_close( + self, + other: "OutputValues.Component", + *, + rel_tol: float = 1.0e-9, + abs_tol: float = 0.0, + ) -> bool: return (self.ignore or other.ignore) or ( self._id == other._id - and _are_mappings_equal(self._variables, other._variables) + and _are_mappings_close( + self._variables, other._variables, rel_tol, abs_tol + ) ) def __str__(self) -> str: @@ -156,7 +179,14 @@ def __post_init__(self) -> None: def __eq__(self, other: object) -> bool: if not isinstance(other, OutputValues): return NotImplemented - return _are_mappings_equal(self._components, other._components) + return _are_mappings_close(self._components, other._components, 0.0, 0.0) + + def is_close( + self, other: "OutputValues", *, rel_tol: float = 1.0e-9, abs_tol: float = 0.0 + ) -> bool: + return _are_mappings_close( + self._components, other._components, rel_tol, abs_tol + ) def __str__(self) -> str: string = "\n" @@ -185,9 +215,14 @@ def component(self, component_id: str) -> "OutputValues.Component": return self._components[component_id] -def _are_mappings_equal( - lhs: Mapping[str, Union[OutputValues.Component, OutputValues.Variable]], - rhs: Mapping[str, Union[OutputValues.Component, OutputValues.Variable]], +Comparable = TypeVar("Comparable", OutputValues.Component, OutputValues.Variable) + + +def _are_mappings_close( + lhs: Mapping[str, Comparable], + rhs: Mapping[str, Comparable], + rel_tol: float, + abs_tol: float, ) -> bool: lhs_keys = lhs.keys() rhs_keys = rhs.keys() @@ -202,10 +237,13 @@ def _are_mappings_equal( ): return False - elif (intersect_keys := lhs_keys & rhs_keys) and any( - lhs[key] != rhs[key] for key in intersect_keys - ): - return False - + elif intersect_keys := lhs_keys & rhs_keys: + if rel_tol == abs_tol == 0.0: + return all(lhs[key] == rhs[key] for key in intersect_keys) + else: + return all( + lhs[key].is_close(rhs[key], rel_tol=rel_tol, abs_tol=abs_tol) + for key in intersect_keys + ) else: return True diff --git a/tests/andromede/test_output_values.py b/tests/andromede/test_output_values.py index 6ab9de0d..feea3a1e 100644 --- a/tests/andromede/test_output_values.py +++ b/tests/andromede/test_output_values.py @@ -76,18 +76,21 @@ def test_component_and_flow_output_object() -> None: "component_approx_var_name" ).value = 1.000_000_001 - assert ( - output != test_output + assert output != test_output and not output.is_close( + test_output ), f"Output is equal to expected outside tolerance: {output}" test_output.component("component_id_test").var( "component_approx_var_name" ).value = 1.000_000_000_1 - assert ( - output == test_output + assert output != test_output and output.is_close( + test_output ), f"Output differs from the expected inside tolerance: {output}" + test_output.component("component_id_test").var( + "component_approx_var_name" + ).ignore = True test_output.component("component_id_test").var( "wrong_component_var_name" ).value = 1.0