From 946d60072c1f7b2977a1e0f360b3feb452a4750b Mon Sep 17 00:00:00 2001 From: Sebastian Echeverria Date: Wed, 27 Nov 2024 16:41:03 -0500 Subject: [PATCH] Base Model: modified a bit to check for lack of serialization, to generate a standard and more informative exception if artifact or other model can't be serialized. --- mlte/model/base_model.py | 15 ++++++++-- mlte/model/serialization_exception.py | 16 ++++++++++ test/model/test_base_model.py | 43 +++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 mlte/model/serialization_exception.py create mode 100644 test/model/test_base_model.py diff --git a/mlte/model/base_model.py b/mlte/model/base_model.py index 0f232a6d..1db317b6 100644 --- a/mlte/model/base_model.py +++ b/mlte/model/base_model.py @@ -6,20 +6,31 @@ from __future__ import annotations +import json from typing import Any, Dict import pydantic +from mlte.model.serialization_exception import SerializationException + class BaseModel(pydantic.BaseModel): """The base model for all MLTE models.""" def to_json(self) -> Dict[str, Any]: """ - Serialize the model. + Serialize the model. Also check if the result is serializable. :return: The JSON representation of the model """ - return self.model_dump() + json_object = self.model_dump() + + # Check if object can't be serialized. + try: + _ = json.dumps(json_object) + except TypeError as e: + raise SerializationException(e, str(type(self))) + + return json_object @classmethod def from_json(cls, data: Dict[str, Any]) -> BaseModel: diff --git a/mlte/model/serialization_exception.py b/mlte/model/serialization_exception.py new file mode 100644 index 00000000..693a2831 --- /dev/null +++ b/mlte/model/serialization_exception.py @@ -0,0 +1,16 @@ +""" +mlte/model/api/serialization_exception.py + +Exception used for serialization issues. +""" +from __future__ import annotations + + +class SerializationException(TypeError): + """Exception used for JSON serialization issues.""" + + def __init__(self, error: TypeError, object: str): + super().__init__( + f"Object {object} cannot be serialized into JSON, ensure all attributes are serializable: " + + str(error) + ) diff --git a/test/model/test_base_model.py b/test/model/test_base_model.py new file mode 100644 index 00000000..8bd7693e --- /dev/null +++ b/test/model/test_base_model.py @@ -0,0 +1,43 @@ +""" +test/model/test_base_model.py + +Unit tests for base model functionality. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from mlte.model.base_model import BaseModel +from mlte.model.serialization_exception import SerializationException + + +class ModelTest(BaseModel): + int_num: int = 1 + float_num: float = 1.2 + str_obj: str = "test" + bool_obj: bool = False + obj: Any = "" + + +class NonSerializable: + attr3: str = "test" + attr1: dict[str, Any] = {"baz": ModelTest()} + + +def test_to_json(): + test_obj = ModelTest() + json_obj = test_obj.to_json() + reconstructed = ModelTest.from_json(json_obj) + + assert test_obj == reconstructed + + +def test_to_json_to_str_not_serializable(): + test_obj = ModelTest() + test_obj.obj = NonSerializable() + + with pytest.raises(SerializationException): + _ = test_obj.to_json()