Skip to content

Commit

Permalink
Split v2.rest_framework package tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Jan 2, 2024
1 parent ad15ee1 commit 6a674cb
Show file tree
Hide file tree
Showing 10 changed files with 378 additions and 279 deletions.
63 changes: 31 additions & 32 deletions django_pydantic_field/v2/rest_framework/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ class AutoSchema(openapi.AutoSchema):
def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None:
super().__init__(tags, operation_id_base, component_name)
self.collected_schema_defs: dict[str, ty.Any] = {}
self.adapter_type_to_schema_refs = weakref.WeakKeyDictionary[type, str]()
self.collected_adapter_schema_refs: dict[str, ty.Any] = {}
self.adapter_mode: JsonSchemaMode = "validation"
self.rf = APIRequestFactory()

def get_components(self, path: str, method: str) -> dict[str, ty.Any]:
if method.lower() == "delete":
return {}

super().get_components

request_serializer = self.get_request_serializer(path, method) # type: ignore[attr-defined]
response_serializer = self.get_response_serializer(path, method) # type: ignore[attr-defined]

Expand Down Expand Up @@ -61,9 +59,9 @@ def get_request_body(self, path, method):
schema_content = {}

for parser, ct in zip(self.view.parser_classes, self.request_media_types):
if issubclass(parser, parsers.SchemaParser):
ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[parser])
schema_content[ct] = {"schema": {"$ref": ref_path}}
if isinstance(parser(), parsers.SchemaParser):
parser_schema = self.collected_adapter_schema_refs[repr(parser)]
schema_content[ct] = {"schema": parser_schema}
else:
schema_content[ct] = request_schema

Expand All @@ -76,23 +74,21 @@ def get_responses(self, path, method):
self.response_media_types = self.map_renderers(path, method)
serializer = self.get_response_serializer(path, method)

item_schema = {}
response_schema = {}
if isinstance(serializer, serializers.Serializer):
item_schema = self.get_reference(serializer)
response_schema = self.get_reference(serializer)

if drf_schema_utils.is_list_view(path, method, self.view):
response_schema = {"type": "array", "items": item_schema}
paginator = self.get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
else:
response_schema = item_schema
is_list_view = drf_schema_utils.is_list_view(path, method, self.view)
if is_list_view:
response_schema = self._get_paginated_schema(response_schema)

schema_content = {}
for renderer, ct in zip(self.view.renderer_classes, self.response_media_types):
if issubclass(renderer, renderers.SchemaRenderer):
ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[renderer])
schema_content[ct] = {"schema": {"$ref": ref_path}}
if isinstance(renderer(), renderers.SchemaRenderer):
renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]}
if is_list_view:
renderer_schema = self._get_paginated_schema(renderer_schema)
schema_content[ct] = renderer_schema
else:
schema_content[ct] = response_schema

Expand All @@ -110,14 +106,15 @@ def map_parsers(self, path: str, method: str) -> list[str]:

for parser in self.view.parser_classes:
media_types.append(parser.media_type)
if issubclass(parser, parsers.SchemaParser):
schema_parsers.append(parser())
instance = parser()
if isinstance(instance, parsers.SchemaParser):
schema_parsers.append(parser)

if schema_parsers:
self.adapter_mode = "validation"
request = self.rf.generic(method, path)
schemas = self._collect_adapter_components(schema_parsers, self.view.get_parser_context(request))
self.collected_schema_defs.update(schemas)
self.collected_adapter_schema_refs.update(schemas)

return media_types

Expand All @@ -127,13 +124,14 @@ def map_renderers(self, path: str, method: str) -> list[str]:

for renderer in self.view.renderer_classes:
media_types.append(renderer.media_type)
if issubclass(renderer, renderers.SchemaRenderer):
schema_renderers.append(renderer())
instance = renderer()
if isinstance(instance, renderers.SchemaRenderer):
schema_renderers.append(renderer)

if schema_renderers:
self.adapter_mode = "serialization"
schemas = self._collect_adapter_components(schema_renderers, self.view.get_renderer_context())
self.collected_schema_defs.update(schemas)
self.collected_adapter_schema_refs.update(schemas)

return media_types

Expand All @@ -160,16 +158,13 @@ def _collect_serializer_component(self, serializer: serializers.BaseSerializer |
schema_definition[component_name] = self.map_serializer(serializer)
return schema_definition

def _collect_adapter_components(self, components: Iterable[mixins.AnnotatedAdapterMixin], context: dict):
def _collect_adapter_components(self, components: Iterable[type[mixins.AnnotatedAdapterMixin]], context: dict):
type_adapters = []

for component in components:
schema_adapter = component.get_adapter(context)
schema_adapter = component().get_adapter(context)
if schema_adapter is not None:
schema_name = schema_adapter.prepared_schema.__class__.__name__
self.adapter_type_to_schema_refs[type(component)] = schema_name

type_adapters.append((schema_name, self.adapter_mode, schema_adapter.type_adapter))
type_adapters.append((repr(component), self.adapter_mode, schema_adapter.type_adapter))

if type_adapters:
return self._collect_type_adapter_schemas(type_adapters)
Expand All @@ -186,5 +181,9 @@ def _collect_type_adapter_schemas(self, adapters: Iterable[tuple[str, JsonSchema
self.collected_schema_defs.update(common_schemas.get("$defs", {}))
return inner_schemas

def _get_component_ref(self, model: str):
return self.REF_TEMPLATE_PREFIX.format(model=model)
def _get_paginated_schema(self, schema) -> ty.Any:
response_schema = {"type": "array", "items": schema}
paginator = self.get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema) # type: ignore
return response_schema
Empty file.
26 changes: 26 additions & 0 deletions tests/v2/rest_framework/test_coreapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys

import pytest
from rest_framework import schemas
from rest_framework.request import Request

from .view_fixtures import create_views_urlconf

coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi")

@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12")
@pytest.mark.parametrize(
"method, path",
[
("GET", "/func"),
("POST", "/func"),
("GET", "/class"),
("PUT", "/class"),
],
)
def test_coreapi_schema_generators(request_factory, method, path):
urlconf = create_views_urlconf(coreapi.AutoSchema)
generator = schemas.SchemaGenerator(urlconf=urlconf)
request = Request(request_factory.generic(method, path))
coreapi_schema = generator.get_schema(request)
assert coreapi_schema
56 changes: 56 additions & 0 deletions tests/v2/rest_framework/test_e2e_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from datetime import date

import pytest

from tests.conftest import InnerSchema

from .view_fixtures import (
ClassBasedView,
ClassBasedViewWithModel,
ClassBasedViewWithSchemaContext,
sample_view,
)

rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework")
coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi")


@pytest.mark.parametrize(
"view",
[
sample_view,
ClassBasedView.as_view(),
ClassBasedViewWithSchemaContext.as_view(),
],
)
def test_end_to_end_api_view(view, request_factory):
expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}'

request = request_factory.post("/", existing_encoded, content_type="application/json")
response = view(request)

assert response.data == [expected_instance]
assert response.data[0] is not expected_instance

assert response.rendered_content == b"[%s]" % existing_encoded


@pytest.mark.django_db
def test_end_to_end_list_create_api_view(request_factory):
field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json()
expected_result = {
"sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1},
"sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}],
"sample_seq": [],
}

payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2)
request = request_factory.post("/", payload.encode(), content_type="application/json")
response = ClassBasedViewWithModel.as_view()(request)

assert response.data == expected_result

request = request_factory.get("/", content_type="application/json")
response = ClassBasedViewWithModel.as_view()(request)
assert response.data == [expected_result]
108 changes: 108 additions & 0 deletions tests/v2/rest_framework/test_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import typing as ty
from datetime import date

import pytest
from rest_framework import exceptions, serializers

from tests.conftest import InnerSchema
from tests.test_app.models import SampleModel

rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework")


class SampleSerializer(serializers.Serializer):
field = rest_framework.SchemaField(schema=ty.List[InnerSchema])


class SampleModelSerializer(serializers.ModelSerializer):
sample_field = rest_framework.SchemaField(schema=InnerSchema)
sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema])
sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list)

class Meta:
model = SampleModel
fields = "sample_field", "sample_list", "sample_seq"


def test_schema_field():
field = rest_framework.SchemaField(InnerSchema)
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
expected_encoded = {
"stub_str": "abc",
"stub_int": 1,
"stub_list": [date(2022, 7, 1)],
}

assert field.to_representation(existing_instance) == expected_encoded
assert field.to_internal_value(expected_encoded) == existing_instance

with pytest.raises(serializers.ValidationError):
field.to_internal_value(None)

with pytest.raises(serializers.ValidationError):
field.to_internal_value("null")


def test_field_schema_with_custom_config():
field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"})
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}

assert field.to_representation(existing_instance) == expected_encoded
assert field.to_internal_value(expected_encoded) == existing_instance
assert field.to_internal_value(None) is None
assert field.to_internal_value("null") is None


def test_serializer_marshalling_with_schema_field():
existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]}
expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]}

serializer = SampleSerializer(instance=existing_instance)
assert serializer.data == expected_data

serializer = SampleSerializer(data=expected_data)
serializer.is_valid(raise_exception=True)
assert serializer.validated_data == existing_instance


def test_model_serializer_marshalling_with_schema_field():
instance = SampleModel(
sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]),
sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2,
sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3,
)
serializer = SampleModelSerializer(instance)

expected_data = {
"sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]},
"sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2,
"sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3,
}
assert serializer.data == expected_data


@pytest.mark.parametrize(
"export_kwargs",
[
{"include": {"stub_str", "stub_int"}},
{"exclude": {"stub_list"}},
{"exclude_unset": True},
{"exclude_defaults": True},
{"exclude_none": True},
{"by_alias": True},
],
)
def test_field_export_kwargs(export_kwargs):
field = rest_framework.SchemaField(InnerSchema, **export_kwargs)
assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]))


def test_invalid_data_serialization():
invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]}
serializer = SampleSerializer(data=invalid_data)

with pytest.raises(exceptions.ValidationError) as e:
serializer.is_valid(raise_exception=True)

assert e.match(r".*stub_str.*stub_int.*stub_list.*")
23 changes: 23 additions & 0 deletions tests/v2/rest_framework/test_openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from rest_framework.schemas.openapi import SchemaGenerator
from rest_framework.request import Request

from .view_fixtures import create_views_urlconf

openapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.openapi")

@pytest.mark.parametrize(
"method, path",
[
("GET", "/func"),
("POST", "/func"),
("GET", "/class"),
("PUT", "/class"),
],
)
def test_coreapi_schema_generators(request_factory, method, path):
urlconf = create_views_urlconf(openapi.AutoSchema)
generator = SchemaGenerator(urlconf=urlconf)
request = Request(request_factory.generic(method, path))
openapi_schema = generator.get_schema(request)
assert openapi_schema
23 changes: 23 additions & 0 deletions tests/v2/rest_framework/test_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import io
from datetime import date

import pytest

from tests.conftest import InnerSchema

rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework")


@pytest.mark.parametrize(
"schema_type, existing_encoded, expected_decoded",
[
(
InnerSchema,
'{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}',
InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]),
)
],
)
def test_schema_parser(schema_type, existing_encoded, expected_decoded):
parser = rest_framework.SchemaParser[schema_type]()
assert parser.parse(io.StringIO(existing_encoded)) == expected_decoded
23 changes: 23 additions & 0 deletions tests/v2/rest_framework/test_renderers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import date

import pytest

from tests.conftest import InnerSchema

rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework")


def test_schema_renderer():
renderer = rest_framework.SchemaRenderer()
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}'

assert renderer.render(existing_instance) == expected_encoded


def test_typed_schema_renderer():
renderer = rest_framework.SchemaRenderer[InnerSchema]()
existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}
expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}'

assert renderer.render(existing_data) == expected_encoded
Loading

0 comments on commit 6a674cb

Please sign in to comment.