Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 support for MSONable types #548

Merged
merged 12 commits into from
Sep 5, 2023
87 changes: 65 additions & 22 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
except ImportError:
pydantic = None # type: ignore

try:
from pydantic_core import core_schema
except ImportError:
core_schema = None # type: ignore

try:
import bson
except ImportError:
Expand Down Expand Up @@ -227,41 +232,79 @@ def flatten(obj, separator="."):
return sha1(json.dumps(OrderedDict(ordered_keys)).encode("utf-8"))

@classmethod
def __get_validators__(cls):
"""Return validators for use in pydantic"""
yield cls.validate_monty

@classmethod
def validate_monty(cls, v):
def _validate_monty(cls, __input_value):
"""
pydantic Validator for MSONable pattern
"""
if isinstance(v, cls):
return v
if isinstance(v, dict):
new_obj = MontyDecoder().process_decoded(v)
if isinstance(__input_value, cls):
return __input_value
if isinstance(__input_value, dict):
new_obj = MontyDecoder().process_decoded(__input_value)
if isinstance(new_obj, cls):
return new_obj

new_obj = cls(**v)
new_obj = cls(**__input_value)
return new_obj

raise ValueError(f"Must provide {cls.__name__}, the as_dict form, or the proper")

@classmethod
def validate_monty_v1(cls, __input_value):
"""
Pydantic validator with correct signature for pydantic v1.x
"""
return cls._validate_monty(__input_value)

@classmethod
def validate_monty_v2(cls, __input_value, _):
"""
Pydantic validator with correct signature for pydantic v2.x
"""
return cls._validate_monty(__input_value)

@classmethod
def __get_validators__(cls):
"""Return validators for use in pydantic"""
yield cls.validate_monty_v1

@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
"""
pydantic v2 core schema definition
"""
if core_schema is None:
raise RuntimeError("Pydantic >= 2.0 is required for validation")

s = core_schema.general_plain_validator_function(cls.validate_monty_v2)

return core_schema.json_or_python_schema(
json_schema=s,
python_schema=s,
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: instance.as_dict()),
)

@classmethod
def _generic_json_schema(cls):
return {
"type": "object",
"properties": {
"@class": {"enum": [cls.__name__], "type": "string"},
"@module": {"enum": [cls.__module__], "type": "string"},
"@version": {"type": "string"},
},
"required": ["@class", "@module"],
}

@classmethod
def __get_pydantic_json_schema__(cls, core_schema, handler):
"""JSON schema for MSONable pattern"""
return cls._generic_json_schema()

@classmethod
def __modify_schema__(cls, field_schema):
"""JSON schema for MSONable pattern"""
field_schema.update(
{
"type": "object",
"properties": {
"@class": {"enum": [cls.__name__], "type": "string"},
"@module": {"enum": [cls.__module__], "type": "string"},
"@version": {"type": "string"},
},
"required": ["@class", "@module"],
}
)
custom_schema = cls._generic_json_schema()
field_schema.update(custom_schema)


class MontyEncoder(json.JSONEncoder):
Expand Down
1 change: 1 addition & 0 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ coveralls
pycodestyle
mypy
pydocstyle
pydantic
flake8
black
pylint
Expand Down