diff --git a/mlte/spec/model.py b/mlte/spec/model.py index 193cfd09..425f371b 100644 --- a/mlte/spec/model.py +++ b/mlte/spec/model.py @@ -4,6 +4,7 @@ Model implementation for the Spec artifact. """ +import json from typing import Any, Dict, List, Literal, Optional from mlte.artifact.type import ArtifactType @@ -25,6 +26,18 @@ class ConditionModel(BaseModel): value_class: str """A string indicating the full module and class name of the Value used to generate this condition.""" + def args_to_json_str(self) -> str: + """ + Serialize the model arguments field into a string. + :return: The JSON str representation of the model + """ + # First convert whole thing, to see if arguments will trigger error (and if so, just let it bubble up). + self.to_json() + + # Now convert only the actual arguments. + json_args = json.dumps(self.arguments) + return json_args + class PropertyModel(BaseModel): """A description of a property.""" diff --git a/mlte/store/artifact/underlying/rdbs/factory_spec.py b/mlte/store/artifact/underlying/rdbs/factory_spec.py index f5b6e23a..64d9c458 100644 --- a/mlte/store/artifact/underlying/rdbs/factory_spec.py +++ b/mlte/store/artifact/underlying/rdbs/factory_spec.py @@ -47,7 +47,7 @@ def create_spec_db_from_model( condition_obj = DBCondition( name=condition.name, measurement_id=measurement_id, - arguments=json.dumps(condition.arguments), + arguments=condition.args_to_json_str(), callback=condition.callback, value_class=condition.value_class, property=property_obj, diff --git a/test/spec/test_condition.py b/test/spec/test_condition.py index e38d5eb9..deb23c39 100644 --- a/test/spec/test_condition.py +++ b/test/spec/test_condition.py @@ -6,7 +6,13 @@ from __future__ import annotations +from typing import Any + +import pytest + +from mlte._private import serializing from mlte.evidence.metadata import EvidenceMetadata, Identifier +from mlte.model.serialization_error import SerializationError from mlte.spec.condition import Condition from mlte.spec.model import ConditionModel from mlte.validation.result import Failure, Success @@ -16,6 +22,8 @@ class TestValue: """Test value class to test build_condition method.""" + data: Any + @classmethod def in_between(cls, arg1: float, arg2: float) -> Condition: """Checks if the value is in between the arguments.""" @@ -30,6 +38,20 @@ def in_between(cls, arg1: float, arg2: float) -> Condition: ) return condition + @classmethod + def in_between_complex(cls, arg1: float, arg2: TestValue) -> Condition: + """Checks if the value is in between the arguments.""" + condition: Condition = Condition.build_condition( + lambda real: Success( + f"Real magnitude {real.value} between {arg1} and {arg2}" + ) + if real.value > arg1 and real.value < arg2 + else Failure( + f"Real magnitude {real.value} not between {arg1} and {arg2}" + ) + ) + return condition + def test_condition_model() -> None: """A Condition model can be serialized and deserialized.""" @@ -43,7 +65,7 @@ def test_condition_model() -> None: ConditionModel( name="greater_than", arguments=[1, 2], - callback=Condition.encode_callback( + callback=serializing.encode_callable( lambda real: Success("Real magnitude 2 less than threshold 3") if 3 < 4 else Failure("Real magnitude 2 exceeds threshold 1") @@ -87,3 +109,11 @@ def test_call_condition(): result = condition(Real(ev, 11.0)) assert str(result) == "Failure" + + +def test_non_serializable_argument(): + test_value = TestValue() + condition = test_value.in_between_complex(1.0, test_value) + + with pytest.raises(SerializationError): + _ = condition.to_model().to_json()