Skip to content

Commit

Permalink
Add validate_strict Pydantic export flag
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Jun 15, 2024
1 parent 65f01ea commit fcc76ab
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
30 changes: 22 additions & 8 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from uuid import UUID

Expand Down Expand Up @@ -46,9 +47,9 @@ def _dec_pydantic_v1(model_type: type[pydantic_v1.BaseModel], value: Any) -> pyd
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e


def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any) -> pydantic_v2.BaseModel:
def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any, strict: bool) -> pydantic_v2.BaseModel:
try:
return model_type.model_validate(value, strict=False)
return model_type.model_validate(value, strict=strict)
except pydantic_v2.ValidationError as e:
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e

Expand Down Expand Up @@ -123,10 +124,20 @@ def extract(annotation: Any, default: Any) -> Any:


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
__slots__ = ("prefer_alias", "validate_strict")

def __init__(
self,
prefer_alias: bool = False,
validate_strict: bool = False,
) -> None:
"""Pydantic Plugin to support serialization / validation of Pydantic types / models
:param prefer_alias: Whether to use the ``by_alias=True`` flag when serializing models
:param validate_strict: Whether to use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
"""
self.prefer_alias = prefer_alias
self.validate_strict = validate_strict

@classmethod
def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
Expand All @@ -136,13 +147,13 @@ def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]
return encoders

@classmethod
def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
def decoders(cls, validate_strict: bool = False) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [
(is_pydantic_v1_model_class, _dec_pydantic_v1)
]

if pydantic_v2 is not None: # pragma: no cover
decoders.append((is_pydantic_v2_model_class, _dec_pydantic_v2))
decoders.append((is_pydantic_v2_model_class, partial(_dec_pydantic_v2, strict=validate_strict)))

decoders.append((_is_pydantic_v1_uuid, _dec_pydantic_uuid))

Expand Down Expand Up @@ -180,7 +191,10 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})}
app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])]
app_config.type_decoders = [
*self.decoders(validate_strict=self.validate_strict),
*(app_config.type_decoders or []),
]

_KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor)
return app_config
32 changes: 32 additions & 0 deletions tests/unit/test_contrib/test_pydantic/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import pydantic as pydantic_v2
import pytest
from pydantic import BaseModel, StrictBool
from pydantic import v1 as pydantic_v1
from typing_extensions import Annotated

from litestar import post
from litestar.contrib.pydantic import PydanticInitPlugin
from litestar.contrib.pydantic.pydantic_dto_factory import PydanticDTO
from litestar.enums import RequestEncodingType
from litestar.params import Body, Parameter
Expand Down Expand Up @@ -305,3 +307,33 @@ async def handler(data: Model) -> Model:
res = client.post("/", json={"foo": in_})
assert res.status_code == 201
assert res.json() == {"foo": in_}


@pytest.mark.parametrize(
"validate_strict,expect_error",
[
(False, False),
(None, False),
(True, True),
],
)
def test_v2_strict_validate(
validate_strict: bool,
expect_error: bool,
) -> None:
# https://github.com/litestar-org/litestar/issues/3572

class Model(BaseModel):
test_bool: StrictBool

@post("/")
async def handler(data: Model) -> None:
return None

plugins = []
if validate_strict is not None:
plugins.append(PydanticInitPlugin(validate_strict=validate_strict))

with create_test_client([handler], plugins=plugins) as client:
res = client.post("/", json={"test_bool": "YES"})
assert res.status_code == 400 if expect_error else 201

0 comments on commit fcc76ab

Please sign in to comment.