Skip to content

Commit

Permalink
Ensure import resolution for compatibility layer
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Feb 2, 2024
1 parent 6ab69ba commit 233f0a3
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 13 deletions.
10 changes: 9 additions & 1 deletion django_pydantic_field/compat/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ def compat_getattr(module_name: str):

def compat_dir(module_name: str):
compat_module = _import_compat_module(module_name)
return dir(compat_module)
module_ns = vars(compat_module)

if "__dir__" in module_ns:
return module_ns["__dir__"]

if "__all__" in module_ns:
return functools.partial(list, module_ns["__all__"])

return functools.partial(dir, compat_module)


def _import_compat_module(module_name: str) -> types.ModuleType:
Expand Down
16 changes: 10 additions & 6 deletions django_pydantic_field/rest_framework.pyi
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import typing as ty
import typing_extensions as te

import typing_extensions as te
from django.utils.functional import _StrOrPromise
from rest_framework import parsers, renderers
from rest_framework.fields import _DefaultInitial, Field
from rest_framework.schemas.openapi import AutoSchema as _OpenAPIAutoSchema
from rest_framework.validators import Validator

from django.utils.functional import _StrOrPromise

from .fields import ST, ConfigType, _ExportKwargs
from .fields import _ExportKwargs, ConfigType, ST

__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer")
__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer", "AutoSchema")

class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False):
read_only: bool
Expand Down Expand Up @@ -45,7 +45,9 @@ class SchemaField(Field, ty.Generic[ST]):
**kwargs: te.Unpack[_SchemaFieldKwargs[ST]],
) -> None: ...
@ty.overload
@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.")
@te.deprecated(
"Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions."
)
def __init__(
self,
schema: ty.Type[ST] | ty.ForwardRef | str,
Expand All @@ -61,3 +63,5 @@ class SchemaParser(parsers.JSONParser, ty.Generic[ST]):
class SchemaRenderer(renderers.JSONRenderer, ty.Generic[ST]):
schema_context_key: ty.ClassVar[str]
config_context_key: ty.ClassVar[str]

class AutoSchema(_OpenAPIAutoSchema): ...
2 changes: 2 additions & 0 deletions django_pydantic_field/v2/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from django_pydantic_field.compat import deprecation
from . import types

__all__ = ("SchemaField",)


class SchemaField(JSONField, ty.Generic[types.ST]):
adapter: types.SchemaAdapter
Expand Down
16 changes: 15 additions & 1 deletion django_pydantic_field/v2/rest_framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from django_pydantic_field.compat import PYDANTIC_V2

from . import coreapi as coreapi
from . import openapi as openapi
from .fields import SchemaField as SchemaField
from .parsers import SchemaParser as SchemaParser
from .renderers import SchemaRenderer as SchemaRenderer
Expand All @@ -8,8 +12,18 @@
"or `django_pydantic_field.rest_framework.coreapi.AutoSchema` instead."
)

__all__ = (
"coreapi",
"openapi",
"SchemaField",
"SchemaParser",
"SchemaRenderer",
"AutoSchema", # type: ignore
)


def __getattr__(key):
if key == "AutoSchema":
if key == "AutoSchema" and PYDANTIC_V2:
import warnings

from .openapi import AutoSchema
Expand Down
10 changes: 5 additions & 5 deletions django_pydantic_field/v2/rest_framework/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework.test import APIRequestFactory

from . import fields, parsers, renderers
from ..utils import get_origin_type
from django_pydantic_field.v2 import utils

if ty.TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_request_body(self, path, method):
schema_content = {}

for parser, ct in zip(self.view.parser_classes, self.request_media_types):
if issubclass(get_origin_type(parser), parsers.SchemaParser):
if issubclass(utils.get_origin_type(parser), parsers.SchemaParser):
parser_schema = self.collected_adapter_schema_refs[repr(parser)]
else:
parser_schema = request_schema
Expand All @@ -86,7 +86,7 @@ def get_responses(self, path, method):

schema_content = {}
for renderer, ct in zip(self.view.renderer_classes, self.response_media_types):
if issubclass(get_origin_type(renderer), renderers.SchemaRenderer):
if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer):
renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]}
if is_list_view:
renderer_schema = self._get_paginated_schema(renderer_schema)
Expand All @@ -108,7 +108,7 @@ def map_parsers(self, path: str, method: str) -> list[str]:

for parser in self.view.parser_classes:
media_types.append(parser.media_type)
if issubclass(get_origin_type(parser), parsers.SchemaParser):
if issubclass(utils.get_origin_type(parser), parsers.SchemaParser):
schema_parsers.append(parser)

if schema_parsers:
Expand All @@ -125,7 +125,7 @@ def map_renderers(self, path: str, method: str) -> list[str]:

for renderer in self.view.renderer_classes:
media_types.append(renderer.media_type)
if issubclass(get_origin_type(renderer), renderers.SchemaRenderer):
if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer):
schema_renderers.append(renderer)

if schema_renderers:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema


@pytest.mark.parametrize(
"exported_primitive_name",
["SchemaField"],
)
def test_module_imports(exported_primitive_name):
assert exported_primitive_name in dir(fields)
assert getattr(fields, exported_primitive_name, None) is not None


def test_sample_field():
sample_field = fields.PydanticSchemaField(schema=InnerSchema)
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
Expand Down
53 changes: 53 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import warnings

import pytest

import django_pydantic_field
from django_pydantic_field import fields, forms, rest_framework
from django_pydantic_field.compat import PYDANTIC_V1, PYDANTIC_V2


@pytest.mark.parametrize(
"module, exported_primitive_name",
[
(django_pydantic_field, "SchemaField"),
(fields, "SchemaField"),
(forms, "SchemaField"),
(rest_framework, "SchemaParser"),
(rest_framework, "SchemaRenderer"),
(rest_framework, "SchemaField"),
(rest_framework, "AutoSchema"),
pytest.param(
rest_framework,
"openapi",
marks=pytest.mark.skipif(
not PYDANTIC_V2,
reason="`.rest_framework.openapi` module is only appearing in v2 layer",
),
),
pytest.param(
rest_framework,
"coreapi",
marks=pytest.mark.skipif(
not PYDANTIC_V2,
reason="`.rest_framework.coreapi` module is only appearing in v2 layer",
),
),
],
)
def test_module_imports(module, exported_primitive_name):
assert exported_primitive_name in dir(module)
assert getattr(module, exported_primitive_name, None) is not None


@pytest.mark.skipif(not PYDANTIC_V2, reason="AutoSchema import warning is only appearing in v2 layer")
def test_rest_framework_autoschema_warning_v2():
with pytest.deprecated_call(match="`django_pydantic_field.rest_framework.AutoSchema` is deprecated.*"):
rest_framework.AutoSchema


@pytest.mark.skipif(not PYDANTIC_V1, reason="Deprecation warning should not be raised in v1 layer")
def test_rest_framework_autoschema_no_warning_v1():
with warnings.catch_warnings():
warnings.simplefilter("error")
rest_framework.AutoSchema

0 comments on commit 233f0a3

Please sign in to comment.