Skip to content

Commit

Permalink
Feature/improve output value with tolerance (#14)
Browse files Browse the repository at this point in the history
* Add tolerance to ExpectedOutput for floating point values

* Add possibility to ignore component/variables

* Expose tolerance for float comparison to the user

* Explicitly call a approx function

Revert "Expose tolerance for float comparison to the user"

This reverts commit d7bda60.

Revert "Add tolerance to ExpectedOutput for floating point values"

This reverts commit 5781c90.
  • Loading branch information
ianmnz authored Feb 26, 2024
1 parent f24d4f0 commit 98b8a66
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 16 deletions.
99 changes: 90 additions & 9 deletions src/andromede/simulation/output_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
"""
Util class to obtain solver results
"""
import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast
from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast

from andromede.simulation.optimization import SolverAndContext
from andromede.study.data import TimeScenarioIndex

T = TypeVar("T")
K = TypeVar("K")


@dataclass
class OutputValues:
Expand All @@ -43,18 +41,45 @@ class Variable:
_name: str
_value: Dict[TimeScenarioIndex, float] = field(init=False, default_factory=dict)
_size: Tuple[int, int] = field(init=False, default=(0, 0))
ignore: bool = field(default=False, init=False)

def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues.Variable):
return NotImplemented
return (
return (self.ignore or other.ignore) or (
self._name == other._name
and self._size == other._size
and self._value == other._value
)

def is_close(
self,
other: "OutputValues.Variable",
*,
rel_tol: float = 1.0e-9,
abs_tol: float = 0.0,
) -> bool:
# From the docs in https://docs.python.org/3/library/math.html#math.isclose
# math.isclose(a, b) returns abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
return (self.ignore or other.ignore) or (
self._name == other._name
and self._size == other._size
and self._value.keys() == other._value.keys()
and all(
math.isclose(
self._value[key],
other._value[key],
rel_tol=rel_tol,
abs_tol=abs_tol,
)
for key in self._value
)
)

def __str__(self) -> str:
return f"{self._name} : {str(self.value)}"
return (
f"{self._name} : {str(self.value)} {'(ignored)' if self.ignore else ''}"
)

@property
def value(self) -> Union[None, float, List[float], List[List[float]]]:
Expand Down Expand Up @@ -111,14 +136,29 @@ class Component:
_variables: Dict[str, "OutputValues.Variable"] = field(
init=False, default_factory=dict
)
ignore: bool = field(default=False, init=False)

def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues.Component):
return NotImplemented
return self._id == other._id and self._variables == other._variables
return self.is_close(other, rel_tol=0.0, abs_tol=0.0)

def is_close(
self,
other: "OutputValues.Component",
*,
rel_tol: float = 1.0e-9,
abs_tol: float = 0.0,
) -> bool:
return (self.ignore or other.ignore) or (
self._id == other._id
and _are_mappings_close(
self._variables, other._variables, rel_tol, abs_tol
)
)

def __str__(self) -> str:
string = f"{self._id} :\n"
string = f"{self._id} : {'(ignored)' if self.ignore else ''}\n"
for var in self._variables.values():
string += f" {str(var)}\n"
return string
Expand All @@ -139,7 +179,14 @@ def __post_init__(self) -> None:
def __eq__(self, other: object) -> bool:
if not isinstance(other, OutputValues):
return NotImplemented
return self._components == other._components
return _are_mappings_close(self._components, other._components, 0.0, 0.0)

def is_close(
self, other: "OutputValues", *, rel_tol: float = 1.0e-9, abs_tol: float = 0.0
) -> bool:
return _are_mappings_close(
self._components, other._components, rel_tol, abs_tol
)

def __str__(self) -> str:
string = "\n"
Expand All @@ -166,3 +213,37 @@ def component(self, component_id: str) -> "OutputValues.Component":
if component_id not in self._components:
self._components[component_id] = OutputValues.Component(component_id)
return self._components[component_id]


Comparable = TypeVar("Comparable", OutputValues.Component, OutputValues.Variable)


def _are_mappings_close(
lhs: Mapping[str, Comparable],
rhs: Mapping[str, Comparable],
rel_tol: float,
abs_tol: float,
) -> bool:
lhs_keys = lhs.keys()
rhs_keys = rhs.keys()

if (lhs_only_keys := lhs_keys - rhs_keys) and any(
not lhs[key].ignore for key in lhs_only_keys
):
return False

elif (rhs_only_keys := rhs_keys - lhs_keys) and any(
not rhs[key].ignore for key in rhs_only_keys
):
return False

elif intersect_keys := lhs_keys & rhs_keys:
if rel_tol == abs_tol == 0.0:
return all(lhs[key] == rhs[key] for key in intersect_keys)
else:
return all(
lhs[key].is_close(rhs[key], rel_tol=rel_tol, abs_tol=abs_tol)
for key in intersect_keys
)
else:
return True
60 changes: 53 additions & 7 deletions tests/andromede/test_output_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,70 @@ def test_component_and_flow_output_object() -> None:
variable_name="component_var_name",
block_timestep=0,
scenario=0,
): mock_variable_component
): mock_variable_component,
TimestepComponentVariableKey(
component_id="component_id_test",
variable_name="component_approx_var_name",
block_timestep=0,
scenario=0,
): mock_variable_component,
}

opt_context.block_length.return_value = 1

problem = SolverAndContext(mock_variable_flow, opt_context)
output = OutputValues(problem)

wrong_output = OutputValues()
wrong_output.component("component_id_test").var(
test_output = OutputValues()
assert output != test_output, f"Output is equal to empty output: {output}"

test_output.component("component_id_test").ignore = True
assert (
output == test_output
), f"Output differs from the expected output after 'ignore': {output}"

test_output.component("component_id_test").ignore = False
test_output.component("component_id_test").var("component_var_name").value = 1.0
test_output.component("component_id_test").var(
"component_approx_var_name"
).ignore = True

assert (
output == test_output
), f"Output differs from the expected after 'var_name': {output}"

test_output.component("component_id_test").var(
"component_approx_var_name"
).ignore = False
test_output.component("component_id_test").var(
"component_approx_var_name"
).value = 1.000_000_001

assert output != test_output and not output.is_close(
test_output
), f"Output is equal to expected outside tolerance: {output}"

test_output.component("component_id_test").var(
"component_approx_var_name"
).value = 1.000_000_000_1

assert output != test_output and output.is_close(
test_output
), f"Output differs from the expected inside tolerance: {output}"

test_output.component("component_id_test").var(
"component_approx_var_name"
).ignore = True
test_output.component("component_id_test").var(
"wrong_component_var_name"
).value = 1.0

assert output != wrong_output, f"Output is equal to wrong output: {output}"
assert output != test_output, f"Output is equal to wrong output: {output}"

expected_output = OutputValues()
expected_output.component("component_id_test").var("component_var_name").value = 1.0
test_output.component("component_id_test").var(
"wrong_component_var_name"
).ignore = True

assert output == expected_output, f"Output differs from expected: {output}"
assert output == test_output, f"Output differs from expected: {output}"

print(output)

0 comments on commit 98b8a66

Please sign in to comment.