diff --git a/CHANGELOG.md b/CHANGELOG.md index 07cce894..22bf9447 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changes - update cruft and loosen up pyproject dependencies +- harmonize signatures/docs of pydantic core/json schema manipulating methods ### Deprecated @@ -19,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- fix schema tests not starting with diverging model names in common and mex-model +- fix serialization for temporal entity instances within pydantic models + ### Security ## [0.34.0] - 2024-08-12 diff --git a/mex/common/types/email.py b/mex/common/types/email.py index 51fac461..6242f4f6 100644 --- a/mex/common/types/email.py +++ b/mex/common/types/email.py @@ -1,7 +1,6 @@ from typing import Any -from pydantic import GetJsonSchemaHandler -from pydantic.json_schema import JsonSchemaValue +from pydantic import GetJsonSchemaHandler, json_schema from pydantic_core import core_schema EMAIL_PATTERN = r"^[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+$" @@ -11,17 +10,17 @@ class Email(str): """Email address of a person, organization or other entity.""" @classmethod - def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema: - """Get pydantic core schema.""" + def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema: + """Modify the core schema to add the email regex.""" return core_schema.str_schema(pattern=EMAIL_PATTERN) @classmethod def __get_pydantic_json_schema__( cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Add title and format.""" - field_schema = handler(core_schema_) - field_schema["title"] = cls.__name__ - field_schema["format"] = "email" - field_schema["examples"] = ["info@rki.de"] - return field_schema + ) -> json_schema.JsonSchemaValue: + """Modify the json schema to add a title, format and examples.""" + json_schema_ = handler(core_schema_) + json_schema_["title"] = cls.__name__ + json_schema_["format"] = "email" + json_schema_["examples"] = ["info@rki.de"] + return json_schema_ diff --git a/mex/common/types/identifier.py b/mex/common/types/identifier.py index 2c60b513..a2f6b0b2 100644 --- a/mex/common/types/identifier.py +++ b/mex/common/types/identifier.py @@ -3,11 +3,10 @@ from typing import Any, Self from uuid import UUID, uuid4 -from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema, core_schema +from pydantic import GetJsonSchemaHandler, json_schema +from pydantic_core import core_schema -ALPHABET = string.ascii_letters + string.digits +MEX_ID_ALPHABET = string.ascii_letters + string.digits MEX_ID_PATTERN = r"^[a-zA-Z0-9]{14,22}$" UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" @@ -20,14 +19,14 @@ def generate(cls, seed: int | None = None) -> Self: """Generate a new identifier from a seed or random UUID version 4.""" # Inspired by https://pypi.org/project/shortuuid output = "" - alpha_len = len(ALPHABET) + alpha_len = len(MEX_ID_ALPHABET) if seed is None: number = uuid4().int else: number = UUID(int=seed, version=4).int while number: number, digit = divmod(number, alpha_len) - output += ALPHABET[digit] + output += MEX_ID_ALPHABET[digit] return cls(output[::-1]) @classmethod @@ -43,28 +42,21 @@ def validate(cls, value: Any) -> Self: raise ValueError(f"Cannot parse {type(value)} as {cls.__name__}") @classmethod - def __get_pydantic_core_schema__( - cls, source: type[Any], handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - """Modify the schema to add the ID regex.""" - identifier_schema = { - "type": "str", - "pattern": MEX_ID_PATTERN, - } + def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema: + """Modify the core schema to add the ID regex.""" return core_schema.no_info_before_validator_function( - cls.validate, - identifier_schema, + cls.validate, core_schema.str_schema(pattern=MEX_ID_PATTERN) ) @classmethod def __get_pydantic_json_schema__( - cls, core_schema_: CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Modify the schema to add the class name as title.""" - json_schema = handler(core_schema_) - json_schema = handler.resolve_ref_schema(json_schema) - json_schema["title"] = cls.__name__ - return json_schema + cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> json_schema.JsonSchemaValue: + """Modify the json schema to add the class name as title.""" + json_schema_ = handler(core_schema_) + json_schema_ = handler.resolve_ref_schema(json_schema_) + json_schema_["title"] = cls.__name__ + return json_schema_ def __repr__(self) -> str: """Overwrite the default representation.""" diff --git a/mex/common/types/path.py b/mex/common/types/path.py index ee6dfe4e..0d963528 100644 --- a/mex/common/types/path.py +++ b/mex/common/types/path.py @@ -49,7 +49,7 @@ def is_relative(self) -> bool: return not self._path.is_absolute() @classmethod - def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema: """Set schema to str schema.""" from_str_schema = core_schema.chain_schema( [ diff --git a/mex/common/types/temporal_entity.py b/mex/common/types/temporal_entity.py index 5450bc58..28690864 100644 --- a/mex/common/types/temporal_entity.py +++ b/mex/common/types/temporal_entity.py @@ -6,9 +6,8 @@ from typing import Any, Literal, Union, cast, overload from pandas._libs.tslibs import parsing -from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema, core_schema +from pydantic import GetJsonSchemaHandler, json_schema +from pydantic_core import core_schema from pytz import timezone @@ -190,10 +189,8 @@ def _validate_precision(cls, precision: TemporalEntityPrecision) -> None: raise ValueError(error_str) @classmethod - def __get_pydantic_core_schema__( - cls, _source: type[Any], _handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - """Mutate the field schema for temporal entity.""" + def __get_pydantic_core_schema__(cls, source: type[Any]) -> core_schema.CoreSchema: + """Modify the core schema to add validation and serialization rules.""" from_str_schema = core_schema.chain_schema( [ core_schema.str_schema(pattern=cls.STR_SCHEMA_PATTERN), @@ -208,16 +205,20 @@ def __get_pydantic_core_schema__( core_schema.is_instance_schema(cls), ] ) + serialization_schema = core_schema.plain_serializer_function_ser_schema( + lambda instance: str(instance) + ) return core_schema.json_or_python_schema( json_schema=from_str_schema, python_schema=from_anything_schema, + serialization=serialization_schema, ) @classmethod def __get_pydantic_json_schema__( - cls, core_schema_: CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Modify the schema to add the class name as title and examples.""" + cls, core_schema_: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> json_schema.JsonSchemaValue: + """Modify the json schema to add a title, examples and an optional format.""" json_schema = handler(core_schema_) json_schema["title"] = cls.__name__ json_schema.update(cls.JSON_SCHEMA_CONFIG) diff --git a/tests/models/test_model_schemas.py b/tests/models/test_model_schemas.py index 126f394c..5e31b283 100644 --- a/tests/models/test_model_schemas.py +++ b/tests/models/test_model_schemas.py @@ -67,8 +67,8 @@ def test_entity_types_match_spec() -> None: @pytest.mark.parametrize( ("generated", "specified"), - zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()), - ids=GENERATED_SCHEMAS, + zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}), + ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")), ) def test_field_names_match_spec( generated: dict[str, Any], specified: dict[str, Any] @@ -81,8 +81,8 @@ def test_field_names_match_spec( @pytest.mark.parametrize( ("generated", "specified"), - zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()), - ids=GENERATED_SCHEMAS, + zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}), + ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")), ) def test_entity_type_matches_class_name( generated: dict[str, Any], specified: dict[str, Any] @@ -95,8 +95,8 @@ def test_entity_type_matches_class_name( @pytest.mark.parametrize( ("generated", "specified"), - zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values()), - ids=GENERATED_SCHEMAS, + zip_longest(GENERATED_SCHEMAS.values(), SPECIFIED_SCHEMAS.values(), fillvalue={}), + ids=map(str, zip_longest(GENERATED_SCHEMAS, SPECIFIED_SCHEMAS, fillvalue="N/A")), ) def test_required_fields_match_spec( generated: dict[str, Any], specified: dict[str, Any] diff --git a/tests/types/test_temporal_entity.py b/tests/types/test_temporal_entity.py index df08d71b..3f3f18be 100644 --- a/tests/types/test_temporal_entity.py +++ b/tests/types/test_temporal_entity.py @@ -2,6 +2,7 @@ from typing import Any import pytest +from pydantic import BaseModel from pytz import timezone from mex.common.types import ( @@ -268,3 +269,12 @@ def test_temporal_entity_repr() -> None: repr(YearMonthDayTime("2018-03-02T12:00:01Z")) == 'YearMonthDayTime("2018-03-02T12:00:01Z")' ) + + +def test_temporal_entity_serialization() -> None: + class Person(BaseModel): + birthday: YearMonthDay + + person = Person.model_validate({"birthday": "24th July 1999"}) + + assert person.model_dump_json() == '{"birthday":"1999-07-24"}' diff --git a/tests/wikidata/test_connector.py b/tests/wikidata/test_connector.py index 76137ecf..48154114 100644 --- a/tests/wikidata/test_connector.py +++ b/tests/wikidata/test_connector.py @@ -24,6 +24,22 @@ def test_initialization_mocked_server( def test_get_data_by_query() -> None: """Test if items can be searched providing a label.""" expected = [ + { + "item": { + "type": "uri", + "value": "http://www.wikidata.org/entity/Q2875797", + }, + "itemDescription": { + "type": "literal", + "value": "airport in Algeria", + "xml:lang": "en", + }, + "itemLabel": { + "type": "literal", + "value": "Bordj Mokhtar Airport", + "xml:lang": "en", + }, + }, { "item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26678"}, "itemDescription": {