Skip to content

Commit

Permalink
Feat: support declaring DTOField via Annotated
Browse files Browse the repository at this point in the history
E.g.:

```py
class A(Struct):
    a: Annotated[int, dto_field("read_only")]
```

For #2351
  • Loading branch information
peterschutt committed Sep 27, 2023
1 parent de5edc6 commit 907ebd5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
5 changes: 3 additions & 2 deletions litestar/contrib/pydantic/pydantic_dto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
from litestar.exceptions import MissingDependencyException, ValidationException
from litestar.types.empty import Empty

Expand Down Expand Up @@ -63,7 +63,8 @@ def generate_field_definitions(

for field_name, field_info in model_fields.items():
field_definition = model_field_definitions[field_name]
dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField())
dto_field = extract_dto_field(field_definition, field_definition.extra)
field_definition.extra.pop(DTO_FIELD_META_KEY, None)

if field_info.default is not PydanticUndefined:
default = field_info.default
Expand Down
4 changes: 2 additions & 2 deletions litestar/dto/dataclass_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import extract_dto_field
from litestar.params import DependencyKwarg, KwargDefinition
from litestar.types.empty import Empty

Expand Down Expand Up @@ -40,7 +40,7 @@ def generate_field_definitions(
DTOFieldDefinition.from_field_definition(
field_definition=field_definition,
default_factory=default_factory,
dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()),
dto_field=extract_dto_field(field_definition, dc_field.metadata),
model_name=model_type.__name__,
),
name=key,
Expand Down
33 changes: 32 additions & 1 deletion litestar/dto/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

from dataclasses import dataclass
from enum import Enum
from typing import Literal
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Literal, Mapping

from litestar.typing import FieldDefinition

__all__ = (
"DTO_FIELD_META_KEY",
"DTOField",
"Mark",
"dto_field",
"extract_dto_field",
)

DTO_FIELD_META_KEY = "__dto__"
Expand Down Expand Up @@ -47,3 +53,28 @@ def dto_field(mark: Literal["read-only", "private"] | Mark) -> dict[str, DTOFiel
Marking a field automates its inclusion/exclusion from DTO field definitions, depending on the DTO's purpose.
"""
return {DTO_FIELD_META_KEY: DTOField(mark=Mark(mark))}


def extract_dto_field(field_definition: FieldDefinition, field_info_mapping: Mapping[str, Any]) -> DTOField:
"""Extract ``DTOField`` instance for a model field.
Supports ``DTOField`` to bet set via ``Annotated`` or via a field info/metadata mapping.
E.g., ``Annotated[str, DTOField(mark="read-only")]`` or ``info=dto_field(mark="read-only")``.
If a value is found in ``field_info_mapping``, it is prioritized over the field definition's metadata.
Args:
field_definition: A field definition.
field_info_mapping: A field metadata/info attribute mapping, e.g., SQLAlchemy's "info" attribute,
or dataclasses "metadata" attribute.
Returns:
DTO field info, if any.
"""
if inst := field_info_mapping.get(DTO_FIELD_META_KEY):
if not isinstance(inst, DTOField):
raise TypeError(f"DTO field info must be an instance of DTOField, got '{inst}'")
return inst

return next((f for f in field_definition.metadata if isinstance(f, DTOField)), DTOField())
5 changes: 3 additions & 2 deletions litestar/dto/msgspec_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
from litestar.types.empty import Empty

if TYPE_CHECKING:
Expand All @@ -33,7 +33,8 @@ def default_or_empty(value: Any) -> Any:

for key, field_definition in cls.get_model_type_hints(model_type).items():
msgspec_field = msgspec_fields[key]
dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField())
dto_field = extract_dto_field(field_definition, field_definition.extra)
field_definition.extra.pop(DTO_FIELD_META_KEY, None)

yield replace(
DTOFieldDefinition.from_field_definition(
Expand Down

0 comments on commit 907ebd5

Please sign in to comment.