Skip to content

Commit

Permalink
Condition: switch call in RDBS factory to convert args to JSON, to in…
Browse files Browse the repository at this point in the history
…ternal ConditionModel function, to better handle errors there.
  • Loading branch information
sebastian-echeverria committed Nov 27, 2024
1 parent 6537876 commit 854da4a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
13 changes: 13 additions & 0 deletions mlte/spec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Model implementation for the Spec artifact.
"""

import json
from typing import Any, Dict, List, Literal, Optional

from mlte.artifact.type import ArtifactType
Expand All @@ -25,6 +26,18 @@ class ConditionModel(BaseModel):
value_class: str
"""A string indicating the full module and class name of the Value used to generate this condition."""

def args_to_json_str(self) -> str:
"""
Serialize the model arguments field into a string.
:return: The JSON str representation of the model
"""
# First convert whole thing, to see if arguments will trigger error (and if so, just let it bubble up).
self.to_json()

# Now convert only the actual arguments.
json_args = json.dumps(self.arguments)
return json_args


class PropertyModel(BaseModel):
"""A description of a property."""
Expand Down
2 changes: 1 addition & 1 deletion mlte/store/artifact/underlying/rdbs/factory_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_spec_db_from_model(
condition_obj = DBCondition(
name=condition.name,
measurement_id=measurement_id,
arguments=json.dumps(condition.arguments),
arguments=condition.args_to_json_str(),
callback=condition.callback,
value_class=condition.value_class,
property=property_obj,
Expand Down
32 changes: 31 additions & 1 deletion test/spec/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

from __future__ import annotations

from typing import Any

import pytest

from mlte._private import serializing
from mlte.evidence.metadata import EvidenceMetadata, Identifier
from mlte.model.serialization_error import SerializationError
from mlte.spec.condition import Condition
from mlte.spec.model import ConditionModel
from mlte.validation.result import Failure, Success
Expand All @@ -16,6 +22,8 @@
class TestValue:
"""Test value class to test build_condition method."""

data: Any

@classmethod
def in_between(cls, arg1: float, arg2: float) -> Condition:
"""Checks if the value is in between the arguments."""
Expand All @@ -30,6 +38,20 @@ def in_between(cls, arg1: float, arg2: float) -> Condition:
)
return condition

@classmethod
def in_between_complex(cls, arg1: float, arg2: TestValue) -> Condition:
"""Checks if the value is in between the arguments."""
condition: Condition = Condition.build_condition(
lambda real: Success(
f"Real magnitude {real.value} between {arg1} and {arg2}"
)
if real.value > arg1 and real.value < arg2
else Failure(
f"Real magnitude {real.value} not between {arg1} and {arg2}"
)
)
return condition


def test_condition_model() -> None:
"""A Condition model can be serialized and deserialized."""
Expand All @@ -43,7 +65,7 @@ def test_condition_model() -> None:
ConditionModel(
name="greater_than",
arguments=[1, 2],
callback=Condition.encode_callback(
callback=serializing.encode_callable(
lambda real: Success("Real magnitude 2 less than threshold 3")
if 3 < 4
else Failure("Real magnitude 2 exceeds threshold 1")
Expand Down Expand Up @@ -87,3 +109,11 @@ def test_call_condition():

result = condition(Real(ev, 11.0))
assert str(result) == "Failure"


def test_non_serializable_argument():
test_value = TestValue()
condition = test_value.in_between_complex(1.0, test_value)

with pytest.raises(SerializationError):
_ = condition.to_model().to_json()

0 comments on commit 854da4a

Please sign in to comment.