Skip to content

Commit

Permalink
feature/mx-1417 fix temporal serialization (#273)
Browse files Browse the repository at this point in the history
# PR Context
- fixes
robert-koch-institut/mex-backend#106 (review)

# Changes

- harmonize signatures/docs of pydantic core/json schema manipulating
methods

# Fixed

- fix schema tests not starting with diverging model names in common and
mex-model
- fix serialization for temporal entity instances within pydantic models

---------

Signed-off-by: Nicolas Drebenstedt <[email protected]>
  • Loading branch information
cutoffthetop authored Aug 20, 2024
1 parent d33ce89 commit 5234dc6
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 51 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changes

- update cruft and loosen up pyproject dependencies
- harmonize signatures/docs of pydantic core/json schema manipulating methods

### Deprecated

### Removed

### Fixed

- fix schema tests not starting with diverging model names in common and mex-model
- fix serialization for temporal entity instances within pydantic models

### Security

## [0.34.0] - 2024-08-12
Expand Down
21 changes: 10 additions & 11 deletions mex/common/types/email.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any

from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic import GetJsonSchemaHandler, json_schema
from pydantic_core import core_schema

EMAIL_PATTERN = r"^[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+$"
Expand All @@ -11,17 +10,17 @@ class Email(str):
"""Email address of a person, organization or other entity."""

@classmethod
def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema:
"""Get pydantic core schema."""
def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema:
"""Modify the core schema to add the email regex."""
return core_schema.str_schema(pattern=EMAIL_PATTERN)

@classmethod
def __get_pydantic_json_schema__(
cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Add title and format."""
field_schema = handler(core_schema_)
field_schema["title"] = cls.__name__
field_schema["format"] = "email"
field_schema["examples"] = ["[email protected]"]
return field_schema
) -> json_schema.JsonSchemaValue:
"""Modify the json schema to add a title, format and examples."""
json_schema_ = handler(core_schema_)
json_schema_["title"] = cls.__name__
json_schema_["format"] = "email"
json_schema_["examples"] = ["[email protected]"]
return json_schema_
38 changes: 15 additions & 23 deletions mex/common/types/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from typing import Any, Self
from uuid import UUID, uuid4

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema
from pydantic import GetJsonSchemaHandler, json_schema
from pydantic_core import core_schema

ALPHABET = string.ascii_letters + string.digits
MEX_ID_ALPHABET = string.ascii_letters + string.digits
MEX_ID_PATTERN = r"^[a-zA-Z0-9]{14,22}$"
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"

Expand All @@ -20,14 +19,14 @@ def generate(cls, seed: int | None = None) -> Self:
"""Generate a new identifier from a seed or random UUID version 4."""
# Inspired by https://pypi.org/project/shortuuid
output = ""
alpha_len = len(ALPHABET)
alpha_len = len(MEX_ID_ALPHABET)
if seed is None:
number = uuid4().int
else:
number = UUID(int=seed, version=4).int
while number:
number, digit = divmod(number, alpha_len)
output += ALPHABET[digit]
output += MEX_ID_ALPHABET[digit]
return cls(output[::-1])

@classmethod
Expand All @@ -43,28 +42,21 @@ def validate(cls, value: Any) -> Self:
raise ValueError(f"Cannot parse {type(value)} as {cls.__name__}")

@classmethod
def __get_pydantic_core_schema__(
cls, source: type[Any], handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""Modify the schema to add the ID regex."""
identifier_schema = {
"type": "str",
"pattern": MEX_ID_PATTERN,
}
def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema:
"""Modify the core schema to add the ID regex."""
return core_schema.no_info_before_validator_function(
cls.validate,
identifier_schema,
cls.validate, core_schema.str_schema(pattern=MEX_ID_PATTERN)
)

@classmethod
def __get_pydantic_json_schema__(
cls, core_schema_: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Modify the schema to add the class name as title."""
json_schema = handler(core_schema_)
json_schema = handler.resolve_ref_schema(json_schema)
json_schema["title"] = cls.__name__
return json_schema
cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> json_schema.JsonSchemaValue:
"""Modify the json schema to add the class name as title."""
json_schema_ = handler(core_schema_)
json_schema_ = handler.resolve_ref_schema(json_schema_)
json_schema_["title"] = cls.__name__
return json_schema_

def __repr__(self) -> str:
"""Overwrite the default representation."""
Expand Down
2 changes: 1 addition & 1 deletion mex/common/types/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def is_relative(self) -> bool:
return not self._path.is_absolute()

@classmethod
def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema:
def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema:
"""Set schema to str schema."""
from_str_schema = core_schema.chain_schema(
[
Expand Down
21 changes: 11 additions & 10 deletions mex/common/types/temporal_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Any, Literal, Union, cast, overload

from pandas._libs.tslibs import parsing
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema
from pydantic import GetJsonSchemaHandler, json_schema
from pydantic_core import core_schema
from pytz import timezone


Expand Down Expand Up @@ -190,10 +189,8 @@ def _validate_precision(cls, precision: TemporalEntityPrecision) -> None:
raise ValueError(error_str)

@classmethod
def __get_pydantic_core_schema__(
cls, _source: type[Any], _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""Mutate the field schema for temporal entity."""
def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema:
"""Modify the core schema to add validation and serialization rules."""
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(pattern=cls.STR_SCHEMA_PATTERN),
Expand All @@ -208,16 +205,20 @@ def __get_pydantic_core_schema__(
core_schema.is_instance_schema(cls),
]
)
serialization_schema = core_schema.plain_serializer_function_ser_schema(
lambda instance: str(instance)
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=from_anything_schema,
serialization=serialization_schema,
)

@classmethod
def __get_pydantic_json_schema__(
cls, core_schema_: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Modify the schema to add the class name as title and examples."""
cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> json_schema.JsonSchemaValue:
"""Modify the json schema to add a title, examples and an optional format."""
json_schema = handler(core_schema_)
json_schema["title"] = cls.__name__
json_schema.update(cls.JSON_SCHEMA_CONFIG)
Expand Down
12 changes: 6 additions & 6 deletions tests/models/test_model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def test_entity_types_match_spec() -> None:

@pytest.mark.parametrize(
("generated", "specified"),
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()),
ids=GENERATED_SCHEMAS,
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}),
ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")),
)
def test_field_names_match_spec(
generated: dict[str, Any], specified: dict[str, Any]
Expand All @@ -81,8 +81,8 @@ def test_field_names_match_spec(

@pytest.mark.parametrize(
("generated", "specified"),
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()),
ids=GENERATED_SCHEMAS,
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}),
ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")),
)
def test_entity_type_matches_class_name(
generated: dict[str, Any], specified: dict[str, Any]
Expand All @@ -95,8 +95,8 @@ def test_entity_type_matches_class_name(

@pytest.mark.parametrize(
("generated", "specified"),
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()),
ids=GENERATED_SCHEMAS,
zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}),
ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")),
)
def test_required_fields_match_spec(
generated: dict[str, Any], specified: dict[str, Any]
Expand Down
10 changes: 10 additions & 0 deletions tests/types/test_temporal_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import pytest
from pydantic import BaseModel
from pytz import timezone

from mex.common.types import (
Expand Down Expand Up @@ -268,3 +269,12 @@ def test_temporal_entity_repr() -> None:
repr(YearMonthDayTime("2018-03-02T12:00:01Z"))
== 'YearMonthDayTime("2018-03-02T12:00:01Z")'
)


def test_temporal_entity_serialization() -> None:
class Person(BaseModel):
birthday: YearMonthDay

person = Person.model_validate({"birthday": "24th July 1999"})

assert person.model_dump_json() == '{"birthday":"1999-07-24"}'
16 changes: 16 additions & 0 deletions tests/wikidata/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ def test_initialization_mocked_server(
def test_get_data_by_query() -> None:
"""Test if items can be searched providing a label."""
expected = [
{
"item": {
"type": "uri",
"value": "http://www.wikidata.org/entity/Q2875797",
},
"itemDescription": {
"type": "literal",
"value": "airport in Algeria",
"xml:lang": "en",
},
"itemLabel": {
"type": "literal",
"value": "Bordj Mokhtar Airport",
"xml:lang": "en",
},
},
{
"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26678"},
"itemDescription": {
Expand Down

0 comments on commit 5234dc6

Please sign in to comment.