From 33de27775ed829255f614a63d9051146138bcd22 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sat, 30 Mar 2024 20:16:01 +1000 Subject: [PATCH] feat: support declaring `DTOField` via `Annotated` (#3289) --- .../contrib/pydantic/pydantic_dto_factory.py | 19 ++++++++- litestar/dto/dataclass_dto.py | 4 +- litestar/dto/field.py | 35 ++++++++++++++++- litestar/dto/msgspec_dto.py | 5 ++- tests/unit/test_contrib/test_msgspec.py | 16 +++++++- .../test_pydantic_dto_factory.py | 39 +++++++++++++++++-- .../test_factory/test_dataclass_dto.py | 18 +++++++++ .../unit/test_dto/test_factory/test_field.py | 12 ++++++ 8 files changed, 137 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_dto/test_factory/test_field.py diff --git a/litestar/contrib/pydantic/pydantic_dto_factory.py b/litestar/contrib/pydantic/pydantic_dto_factory.py index d61f95d671..af7d3e6830 100644 --- a/litestar/contrib/pydantic/pydantic_dto_factory.py +++ b/litestar/contrib/pydantic/pydantic_dto_factory.py @@ -2,13 +2,14 @@ from dataclasses import replace from typing import TYPE_CHECKING, Collection, Generic, TypeVar +from warnings import warn from typing_extensions import TypeAlias, override from litestar.contrib.pydantic.utils import is_pydantic_undefined from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition -from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field from litestar.exceptions import MissingDependencyException, ValidationException from litestar.types.empty import Empty @@ -81,7 +82,21 @@ def generate_field_definitions( for field_name, field_info in model_fields.items(): field_definition = model_field_definitions[field_name] - dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField()) + dto_field = extract_dto_field(field_definition, field_definition.extra) + + try: + extra = field_info.extra # type: ignore[union-attr] + except AttributeError: + extra = field_info.json_schema_extra # type: ignore[union-attr] + + if extra is not None and extra.pop(DTO_FIELD_META_KEY, None): + warn( + message="Declaring 'DTOField' via Pydantic's 'Field.extra' is deprecated. " + "Use 'Annotated', e.g., 'Annotated[str, DTOField(mark='read-only')]' instead. " + "Support for 'DTOField' in 'Field.extra' will be removed in v3.", + category=DeprecationWarning, + stacklevel=2, + ) if not is_pydantic_undefined(field_info.default): default = field_info.default diff --git a/litestar/dto/dataclass_dto.py b/litestar/dto/dataclass_dto.py index 554b0f3343..5301abac83 100644 --- a/litestar/dto/dataclass_dto.py +++ b/litestar/dto/dataclass_dto.py @@ -5,7 +5,7 @@ from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition -from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.dto.field import extract_dto_field from litestar.params import DependencyKwarg, KwargDefinition from litestar.types.empty import Empty @@ -40,7 +40,7 @@ def generate_field_definitions( DTOFieldDefinition.from_field_definition( field_definition=field_definition, default_factory=default_factory, - dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()), + dto_field=extract_dto_field(field_definition, dc_field.metadata), model_name=model_type.__name__, ), name=key, diff --git a/litestar/dto/field.py b/litestar/dto/field.py index 7ef8a390e2..bdaf125913 100644 --- a/litestar/dto/field.py +++ b/litestar/dto/field.py @@ -4,13 +4,21 @@ from dataclasses import dataclass from enum import Enum -from typing import Literal +from typing import TYPE_CHECKING + +from litestar.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from typing import Any, Literal, Mapping + + from litestar.typing import FieldDefinition __all__ = ( "DTO_FIELD_META_KEY", "DTOField", "Mark", "dto_field", + "extract_dto_field", ) DTO_FIELD_META_KEY = "__dto__" @@ -48,3 +56,28 @@ def dto_field(mark: Literal["read-only", "write-only", "private"] | Mark) -> dic Marking a field automates its inclusion/exclusion from DTO field definitions, depending on the DTO's purpose. """ return {DTO_FIELD_META_KEY: DTOField(mark=Mark(mark))} + + +def extract_dto_field(field_definition: FieldDefinition, field_info_mapping: Mapping[str, Any]) -> DTOField: + """Extract ``DTOField`` instance for a model field. + + Supports ``DTOField`` to bet set via ``Annotated`` or via a field info/metadata mapping. + + E.g., ``Annotated[str, DTOField(mark="read-only")]`` or ``info=dto_field(mark="read-only")``. + + If a value is found in ``field_info_mapping``, it is prioritized over the field definition's metadata. + + Args: + field_definition: A field definition. + field_info_mapping: A field metadata/info attribute mapping, e.g., SQLAlchemy's ``info`` attribute, + or dataclasses ``metadata`` attribute. + + Returns: + DTO field info, if any. + """ + if inst := field_info_mapping.get(DTO_FIELD_META_KEY): + if not isinstance(inst, DTOField): + raise ImproperlyConfiguredException(f"DTO field info must be an instance of DTOField, got '{inst}'") + return inst + + return next((f for f in field_definition.metadata if isinstance(f, DTOField)), DTOField()) diff --git a/litestar/dto/msgspec_dto.py b/litestar/dto/msgspec_dto.py index 826a1d274f..9996319747 100644 --- a/litestar/dto/msgspec_dto.py +++ b/litestar/dto/msgspec_dto.py @@ -7,7 +7,7 @@ from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition -from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field from litestar.types.empty import Empty if TYPE_CHECKING: @@ -36,7 +36,8 @@ def default_or_none(value: Any) -> Any: for key, field_definition in cls.get_model_type_hints(model_type).items(): msgspec_field = msgspec_fields[key] - dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField()) + dto_field = extract_dto_field(field_definition, field_definition.extra) + field_definition.extra.pop(DTO_FIELD_META_KEY, None) yield replace( DTOFieldDefinition.from_field_definition( diff --git a/tests/unit/test_contrib/test_msgspec.py b/tests/unit/test_contrib/test_msgspec.py index 783f78f80f..9c28ed4ba9 100644 --- a/tests/unit/test_contrib/test_msgspec.py +++ b/tests/unit/test_contrib/test_msgspec.py @@ -5,7 +5,7 @@ from msgspec import Meta, Struct, field from typing_extensions import Annotated -from litestar.dto import MsgspecDTO, dto_field +from litestar.dto import DTOField, MsgspecDTO, dto_field from litestar.dto.data_structures import DTOFieldDefinition from litestar.typing import FieldDefinition @@ -38,3 +38,17 @@ class NotStruct: assert MsgspecDTO.detect_nested_field(FieldDefinition.from_annotation(TestStruct)) is True assert MsgspecDTO.detect_nested_field(FieldDefinition.from_annotation(NotStruct)) is False + + +ReadOnlyInt = Annotated[int, DTOField("read-only")] + + +def test_msgspec_dto_annotated_dto_field() -> None: + class Model(Struct): + a: Annotated[int, DTOField("read-only")] + b: ReadOnlyInt + + dto_type = MsgspecDTO[Model] + fields = list(dto_type.generate_field_definitions(Model)) + assert fields[0].dto_field == DTOField("read-only") + assert fields[1].dto_field == DTOField("read-only") diff --git a/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py b/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py index 562537f909..ee2ecdc467 100644 --- a/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py +++ b/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py @@ -3,14 +3,17 @@ from typing import TYPE_CHECKING import pydantic as pydantic_v2 +import pytest from pydantic import v1 as pydantic_v1 from typing_extensions import Annotated from litestar.contrib.pydantic import PydanticDTO -from litestar.dto import dto_field +from litestar.dto import DTOField, dto_field from litestar.dto.data_structures import DTOFieldDefinition from litestar.typing import FieldDefinition +from . import PydanticVersion + if TYPE_CHECKING: from typing import Callable @@ -21,7 +24,7 @@ def test_field_definition_generation_v1( ) -> None: class TestModel(pydantic_v1.BaseModel): a: int - b: Annotated[int, pydantic_v1.Field(**dto_field("read-only"))] # pyright: ignore + b: Annotated[int, DTOField("read-only")] c: Annotated[int, pydantic_v1.Field(gt=1)] d: int = pydantic_v1.Field(default=1) e: int = pydantic_v1.Field(default_factory=int_factory) @@ -38,7 +41,7 @@ def test_field_definition_generation_v2( ) -> None: class TestModel(pydantic_v2.BaseModel): a: int - b: Annotated[int, pydantic_v2.Field(**dto_field("read-only"))] # pyright: ignore + b: Annotated[int, DTOField("read-only")] c: Annotated[int, pydantic_v2.Field(gt=1)] d: int = pydantic_v2.Field(default=1) e: int = pydantic_v2.Field(default_factory=int_factory) @@ -58,3 +61,33 @@ class NotModel: assert PydanticDTO.detect_nested_field(FieldDefinition.from_annotation(TestModel)) is True assert PydanticDTO.detect_nested_field(FieldDefinition.from_annotation(NotModel)) is False + + +ReadOnlyInt = Annotated[int, DTOField("read-only")] + + +def test_pydantic_dto_annotated_dto_field(base_model: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel]) -> None: + class Model(base_model): # type: ignore[misc, valid-type] + a: Annotated[int, DTOField("read-only")] + b: ReadOnlyInt + + dto_type = PydanticDTO[Model] + fields = list(dto_type.generate_field_definitions(Model)) + assert fields[0].dto_field == DTOField("read-only") + assert fields[1].dto_field == DTOField("read-only") + + +def test_dto_field_via_pydantic_field_extra_deprecation( + pydantic_version: PydanticVersion, +) -> None: + if pydantic_version == "v1": + + class Model(pydantic_v1.BaseModel): # pyright: ignore + a: int = pydantic_v1.Field(**dto_field("read-only")) # type: ignore[arg-type, misc] + else: + + class Model(pydantic_v2.BaseModel): # type: ignore[no-redef] + a: int = pydantic_v2.Field(**dto_field("read-only")) # type: ignore[arg-type, pydantic-field] + + with pytest.warns(DeprecationWarning): + next(PydanticDTO.generate_field_definitions(Model)) diff --git a/tests/unit/test_dto/test_factory/test_dataclass_dto.py b/tests/unit/test_dto/test_factory/test_dataclass_dto.py index 42aea8cc91..7a334a2c39 100644 --- a/tests/unit/test_dto/test_factory/test_dataclass_dto.py +++ b/tests/unit/test_dto/test_factory/test_dataclass_dto.py @@ -6,6 +6,7 @@ from unittest.mock import ANY import pytest +from typing_extensions import Annotated from litestar.dto import DataclassDTO, DTOField from litestar.dto.data_structures import DTOFieldDefinition @@ -121,3 +122,20 @@ def test_dataclass_field_definitions(dto_type: type[DataclassDTO[Model]]) -> Non def test_dataclass_detect_nested(dto_type: type[DataclassDTO[Model]]) -> None: assert dto_type.detect_nested_field(FieldDefinition.from_annotation(Model)) is True assert dto_type.detect_nested_field(FieldDefinition.from_annotation(int)) is False + + +ReadOnlyInt = Annotated[int, DTOField("read-only")] + + +def test_dataclass_dto_annotated_dto_field() -> None: + Annotated[int, DTOField("read-only")] + + @dataclass + class Model: + a: Annotated[int, DTOField("read-only")] + b: ReadOnlyInt + + dto_type = DataclassDTO[Model] + fields = list(dto_type.generate_field_definitions(Model)) + assert fields[0].dto_field == DTOField("read-only") + assert fields[1].dto_field == DTOField("read-only") diff --git a/tests/unit/test_dto/test_factory/test_field.py b/tests/unit/test_dto/test_factory/test_field.py new file mode 100644 index 0000000000..2e107c49d5 --- /dev/null +++ b/tests/unit/test_dto/test_factory/test_field.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import pytest + +from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field +from litestar.exceptions import ImproperlyConfiguredException +from litestar.typing import FieldDefinition + + +def test_extract_dto_field_unexpected_type() -> None: + with pytest.raises(ImproperlyConfiguredException): + extract_dto_field(FieldDefinition.from_annotation(int), {DTO_FIELD_META_KEY: object()})