Skip to content

Commit

Permalink
feat(DTO): Support extra="forbid" model config for PydanticDTO (#…
Browse files Browse the repository at this point in the history
…3691)

* feat(DTO): Support extra="forbid" config for PydanticDTO
  • Loading branch information
provinzkraut authored Aug 24, 2024
1 parent 9cff0a4 commit 60e17cd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 2 deletions.
15 changes: 14 additions & 1 deletion litestar/contrib/pydantic/pydantic_dto_factory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import dataclasses
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Collection, Generic, TypeVar
from warnings import warn

from typing_extensions import Annotated, TypeAlias, override

from litestar.contrib.pydantic.utils import is_pydantic_undefined, is_pydantic_v2
from litestar.contrib.pydantic.utils import is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2
from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
Expand All @@ -17,6 +18,8 @@
if TYPE_CHECKING:
from typing import Generator

from litestar.dto import DTOConfig

try:
import pydantic as _ # noqa: F401
except ImportError as e:
Expand Down Expand Up @@ -160,3 +163,13 @@ def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:
if pydantic_v2 is not Empty: # type: ignore[comparison-overlap]
return field_definition.is_subclass_of((pydantic_v1.BaseModel, pydantic_v2.BaseModel))
return field_definition.is_subclass_of(pydantic_v1.BaseModel) # type: ignore[unreachable]

@classmethod
def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig:
if is_pydantic_2_model(model_type) and (model_config := getattr(model_type, "model_config", None)):
if model_config.get("extra") == "forbid":
config = dataclasses.replace(config, forbid_unknown_fields=True)
elif issubclass(model_type, pydantic_v1.BaseModel) and (model_config := getattr(model_type, "Config", None)): # noqa: SIM102
if getattr(model_config, "extra", None) == "forbid":
config = dataclasses.replace(config, forbid_unknown_fields=True)
return config
18 changes: 18 additions & 0 deletions litestar/dto/base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ def __class_getitem__(cls, annotation: Any) -> type[Self]:

return type(f"{cls.__name__}[{annotation}]", (cls,), cls_dict) # pyright: ignore

def __init_subclass__(cls, **kwargs: Any) -> None:
if (config := getattr(cls, "config", None)) and (model_type := getattr(cls, "model_type", None)):
# it's a concrete class
cls.config = cls.get_config_for_model_type(config, model_type)

@classmethod
def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig:
"""Create a new configuration for this specific ``model_type``, during the
creation of the factory.
The returned config object will be set as the ``config`` attribute on the newly
defined factory class.
.. versionadded: 2.11
"""
return config

def decode_builtins(self, value: dict[str, Any]) -> Any:
"""Decode a dictionary of Python values into an the DTO's datatype."""

Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_contrib/test_pydantic/test_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from typing import TYPE_CHECKING, Optional, cast

import pydantic as pydantic_v2
import pytest
from pydantic import v1 as pydantic_v1
from typing_extensions import Annotated, Literal

from litestar import Request, post
from litestar.contrib.pydantic import PydanticDTO, _model_dump_json
Expand Down Expand Up @@ -100,3 +102,71 @@ def get_user() -> User:
component_schema = schema.components.schemas["GetUserUserResponseBody"]
assert component_schema.properties is not None
assert component_schema.properties["id"].description == "This is a test (id description)."


@pytest.mark.parametrize(
"model_config_option, forbid_unknown_fields_default, expected_dto_config_option",
[
("forbid", False, True),
("forbid", True, True),
("allow", False, False),
("allow", True, True),
("ignore", True, True),
("ignore", False, False),
],
)
def test_forbid_unknown_fields_if_forbid_extra_is_set_v1(
use_experimental_dto_backend: bool,
forbid_unknown_fields_default: bool,
model_config_option: Literal["forbid", "allow", "ignore"],
expected_dto_config_option: bool,
) -> None:
class Model(pydantic_v1.BaseModel):
class Config:
extra = model_config_option

a: str

dto_config = DTOConfig(
experimental_codegen_backend=use_experimental_dto_backend,
# config set on the pydantic model should take precedence
forbid_unknown_fields=forbid_unknown_fields_default,
)
dto = PydanticDTO[Annotated[Model, dto_config]]

assert dto.config.forbid_unknown_fields is expected_dto_config_option
# ensure the config is merged
assert dto.config.experimental_codegen_backend is use_experimental_dto_backend


@pytest.mark.parametrize(
"model_config_option, forbid_unknown_fields_default, expected_dto_config_option",
[
("forbid", False, True),
("forbid", True, True),
("allow", False, False),
("allow", True, True),
("ignore", True, True),
("ignore", False, False),
],
)
def test_forbid_unknown_fields_if_forbid_extra_is_set_v2(
use_experimental_dto_backend: bool,
forbid_unknown_fields_default: bool,
model_config_option: Literal["forbid", "allow", "ignore"],
expected_dto_config_option: bool,
) -> None:
class Model(pydantic_v2.BaseModel):
a: str
model_config = pydantic_v2.ConfigDict(extra=model_config_option)

dto_config = DTOConfig(
experimental_codegen_backend=use_experimental_dto_backend,
# config set on the pydantic model should take precedence
forbid_unknown_fields=forbid_unknown_fields_default,
)
dto = PydanticDTO[Annotated[Model, dto_config]]

assert dto.config.forbid_unknown_fields is expected_dto_config_option
# ensure the config is merged
assert dto.config.experimental_codegen_backend is use_experimental_dto_backend
31 changes: 30 additions & 1 deletion tests/unit/test_dto/test_factory/test_base_dto.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# ruff: noqa: UP006
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union

import pytest
from typing_extensions import Annotated
Expand Down Expand Up @@ -160,3 +161,31 @@ class SubType(Model):
assert (
dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore
)


def test_get_config_for_model_type() -> None:
base_config = DTOConfig(rename_strategy="camel")

class CustomDTO(DataclassDTO[T], Generic[T]):
@classmethod
def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig:
return dataclasses.replace(config, exclude={"foo"})

annotated_dto = CustomDTO[Model]
annotated_dto_with_config = CustomDTO[Annotated[Model, base_config]]

class SubclassDTO(CustomDTO[Model]):
pass

class SubclassDTOWithConfig(CustomDTO[Model]):
config = base_config

assert annotated_dto.config.exclude == {"foo"}
assert SubclassDTO.config.exclude == {"foo"}

# we expect existing configs to have been merged
assert annotated_dto_with_config.config.exclude == {"foo"}
assert annotated_dto_with_config.config.rename_strategy == "camel"

assert SubclassDTOWithConfig.config.exclude == {"foo"}
assert SubclassDTOWithConfig.config.rename_strategy == "camel"

0 comments on commit 60e17cd

Please sign in to comment.