diff --git a/django_pydantic_field/compat/imports.py b/django_pydantic_field/compat/imports.py index 1587d91..ccb6c71 100644 --- a/django_pydantic_field/compat/imports.py +++ b/django_pydantic_field/compat/imports.py @@ -14,7 +14,15 @@ def compat_getattr(module_name: str): def compat_dir(module_name: str): compat_module = _import_compat_module(module_name) - return dir(compat_module) + module_ns = vars(compat_module) + + if "__dir__" in module_ns: + return module_ns["__dir__"] + + if "__all__" in module_ns: + return functools.partial(list, module_ns["__all__"]) + + return functools.partial(dir, compat_module) def _import_compat_module(module_name: str) -> types.ModuleType: diff --git a/django_pydantic_field/rest_framework.pyi b/django_pydantic_field/rest_framework.pyi index e4a3b87..e7e5acb 100644 --- a/django_pydantic_field/rest_framework.pyi +++ b/django_pydantic_field/rest_framework.pyi @@ -1,15 +1,15 @@ import typing as ty -import typing_extensions as te +import typing_extensions as te +from django.utils.functional import _StrOrPromise from rest_framework import parsers, renderers from rest_framework.fields import _DefaultInitial, Field +from rest_framework.schemas.openapi import AutoSchema as _OpenAPIAutoSchema from rest_framework.validators import Validator -from django.utils.functional import _StrOrPromise - -from .fields import ST, ConfigType, _ExportKwargs +from .fields import _ExportKwargs, ConfigType, ST -__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer") +__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer", "AutoSchema") class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False): read_only: bool @@ -45,7 +45,9 @@ class SchemaField(Field, ty.Generic[ST]): **kwargs: te.Unpack[_SchemaFieldKwargs[ST]], ) -> None: ... @ty.overload - @te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") + @te.deprecated( + "Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions." + ) def __init__( self, schema: ty.Type[ST] | ty.ForwardRef | str, @@ -61,3 +63,5 @@ class SchemaParser(parsers.JSONParser, ty.Generic[ST]): class SchemaRenderer(renderers.JSONRenderer, ty.Generic[ST]): schema_context_key: ty.ClassVar[str] config_context_key: ty.ClassVar[str] + +class AutoSchema(_OpenAPIAutoSchema): ... diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index 510682c..cb2b2eb 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -10,6 +10,8 @@ from django_pydantic_field.compat import deprecation from . import types +__all__ = ("SchemaField",) + class SchemaField(JSONField, ty.Generic[types.ST]): adapter: types.SchemaAdapter diff --git a/django_pydantic_field/v2/rest_framework/__init__.py b/django_pydantic_field/v2/rest_framework/__init__.py index f151648..a607818 100644 --- a/django_pydantic_field/v2/rest_framework/__init__.py +++ b/django_pydantic_field/v2/rest_framework/__init__.py @@ -1,3 +1,7 @@ +from django_pydantic_field.compat import PYDANTIC_V2 + +from . import coreapi as coreapi +from . import openapi as openapi from .fields import SchemaField as SchemaField from .parsers import SchemaParser as SchemaParser from .renderers import SchemaRenderer as SchemaRenderer @@ -8,8 +12,18 @@ "or `django_pydantic_field.rest_framework.coreapi.AutoSchema` instead." ) +__all__ = ( + "coreapi", + "openapi", + "SchemaField", + "SchemaParser", + "SchemaRenderer", + "AutoSchema", # type: ignore +) + + def __getattr__(key): - if key == "AutoSchema": + if key == "AutoSchema" and PYDANTIC_V2: import warnings from .openapi import AutoSchema diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 4ccb265..6faa5fa 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -9,7 +9,7 @@ from rest_framework.test import APIRequestFactory from . import fields, parsers, renderers -from ..utils import get_origin_type +from django_pydantic_field.v2 import utils if ty.TYPE_CHECKING: from collections.abc import Iterable @@ -60,7 +60,7 @@ def get_request_body(self, path, method): schema_content = {} for parser, ct in zip(self.view.parser_classes, self.request_media_types): - if issubclass(get_origin_type(parser), parsers.SchemaParser): + if issubclass(utils.get_origin_type(parser), parsers.SchemaParser): parser_schema = self.collected_adapter_schema_refs[repr(parser)] else: parser_schema = request_schema @@ -86,7 +86,7 @@ def get_responses(self, path, method): schema_content = {} for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): - if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): + if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer): renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]} if is_list_view: renderer_schema = self._get_paginated_schema(renderer_schema) @@ -108,7 +108,7 @@ def map_parsers(self, path: str, method: str) -> list[str]: for parser in self.view.parser_classes: media_types.append(parser.media_type) - if issubclass(get_origin_type(parser), parsers.SchemaParser): + if issubclass(utils.get_origin_type(parser), parsers.SchemaParser): schema_parsers.append(parser) if schema_parsers: @@ -125,7 +125,7 @@ def map_renderers(self, path: str, method: str) -> list[str]: for renderer in self.view.renderer_classes: media_types.append(renderer.media_type) - if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): + if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer): schema_renderers.append(renderer) if schema_renderers: diff --git a/tests/test_fields.py b/tests/test_fields.py index 621e9aa..d274c8e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -19,6 +19,15 @@ from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema +@pytest.mark.parametrize( + "exported_primitive_name", + ["SchemaField"], +) +def test_module_imports(exported_primitive_name): + assert exported_primitive_name in dir(fields) + assert getattr(fields, exported_primitive_name, None) is not None + + def test_sample_field(): sample_field = fields.PydanticSchemaField(schema=InnerSchema) existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..29fdf6e --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,53 @@ +import warnings + +import pytest + +import django_pydantic_field +from django_pydantic_field import fields, forms, rest_framework +from django_pydantic_field.compat import PYDANTIC_V1, PYDANTIC_V2 + + +@pytest.mark.parametrize( + "module, exported_primitive_name", + [ + (django_pydantic_field, "SchemaField"), + (fields, "SchemaField"), + (forms, "SchemaField"), + (rest_framework, "SchemaParser"), + (rest_framework, "SchemaRenderer"), + (rest_framework, "SchemaField"), + (rest_framework, "AutoSchema"), + pytest.param( + rest_framework, + "openapi", + marks=pytest.mark.skipif( + not PYDANTIC_V2, + reason="`.rest_framework.openapi` module is only appearing in v2 layer", + ), + ), + pytest.param( + rest_framework, + "coreapi", + marks=pytest.mark.skipif( + not PYDANTIC_V2, + reason="`.rest_framework.coreapi` module is only appearing in v2 layer", + ), + ), + ], +) +def test_module_imports(module, exported_primitive_name): + assert exported_primitive_name in dir(module) + assert getattr(module, exported_primitive_name, None) is not None + + +@pytest.mark.skipif(not PYDANTIC_V2, reason="AutoSchema import warning is only appearing in v2 layer") +def test_rest_framework_autoschema_warning_v2(): + with pytest.deprecated_call(match="`django_pydantic_field.rest_framework.AutoSchema` is deprecated.*"): + rest_framework.AutoSchema + + +@pytest.mark.skipif(not PYDANTIC_V1, reason="Deprecation warning should not be raised in v1 layer") +def test_rest_framework_autoschema_no_warning_v1(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + rest_framework.AutoSchema