Skip to content

Commit

Permalink
Base Model: modified a bit to check for lack of serialization, to gen…
Browse files Browse the repository at this point in the history
…erate a standard and more informative exception if artifact or other model can't be serialized.
  • Loading branch information
sebastian-echeverria committed Nov 27, 2024
1 parent c73a9af commit 946d600
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
15 changes: 13 additions & 2 deletions mlte/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,31 @@

from __future__ import annotations

import json
from typing import Any, Dict

import pydantic

from mlte.model.serialization_exception import SerializationException


class BaseModel(pydantic.BaseModel):
"""The base model for all MLTE models."""

def to_json(self) -> Dict[str, Any]:
"""
Serialize the model.
Serialize the model. Also check if the result is serializable.
:return: The JSON representation of the model
"""
return self.model_dump()
json_object = self.model_dump()

# Check if object can't be serialized.
try:
_ = json.dumps(json_object)
except TypeError as e:
raise SerializationException(e, str(type(self)))

return json_object

@classmethod
def from_json(cls, data: Dict[str, Any]) -> BaseModel:
Expand Down
16 changes: 16 additions & 0 deletions mlte/model/serialization_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
mlte/model/api/serialization_exception.py
Exception used for serialization issues.
"""
from __future__ import annotations


class SerializationException(TypeError):
"""Exception used for JSON serialization issues."""

def __init__(self, error: TypeError, object: str):
super().__init__(
f"Object {object} cannot be serialized into JSON, ensure all attributes are serializable: "
+ str(error)
)
43 changes: 43 additions & 0 deletions test/model/test_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
test/model/test_base_model.py
Unit tests for base model functionality.
"""

from __future__ import annotations

from typing import Any

import pytest

from mlte.model.base_model import BaseModel
from mlte.model.serialization_exception import SerializationException


class ModelTest(BaseModel):
int_num: int = 1
float_num: float = 1.2
str_obj: str = "test"
bool_obj: bool = False
obj: Any = ""


class NonSerializable:
attr3: str = "test"
attr1: dict[str, Any] = {"baz": ModelTest()}


def test_to_json():
test_obj = ModelTest()
json_obj = test_obj.to_json()
reconstructed = ModelTest.from_json(json_obj)

assert test_obj == reconstructed


def test_to_json_to_str_not_serializable():
test_obj = ModelTest()
test_obj.obj = NonSerializable()

with pytest.raises(SerializationException):
_ = test_obj.to_json()

0 comments on commit 946d600

Please sign in to comment.