Skip to content

Commit

Permalink
Implement containers for field info and dataclass serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Mar 25, 2024
1 parent 77fd863 commit 251e641
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 114 deletions.
242 changes: 203 additions & 39 deletions django_pydantic_field/compat/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,209 @@
only with `X | Y` syntax. Both cases require a dedicated serializer for migration writes.
"""

from __future__ import annotations

import abc
import dataclasses
import sys
import types
import typing as ty

import typing_extensions as te
from django.db.migrations.serializer import BaseSerializer, serializer_factory
from django.db.migrations.writer import MigrationWriter
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from .typing import get_args, get_origin

try:
from pydantic._internal._repr import Representation
from pydantic.fields import _DefaultValues as FieldInfoDefaultValues
from pydantic_core import PydanticUndefined
except ImportError:
# Assuming this is a Pydantic v1
from pydantic.fields import Undefined as PydanticUndefined
from pydantic.utils import Representation

FieldInfoDefaultValues = FieldInfo.__field_constraints__


class GenericContainer:
class BaseContainer(abc.ABC):
__slot__ = ()

@classmethod
def unwrap(cls, value):
if isinstance(value, BaseContainer) and type(value) is not BaseContainer:
return value.unwrap(value)
return value

def __eq__(self, other):
if isinstance(other, self.__class__):
return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__)
return NotImplemented

def __str__(self):
return repr(self.unwrap(self))

def __repr__(self):
attrs = tuple(getattr(self, attr) for attr in self.__slots__)
return f"{self.__class__.__name__}{attrs}"


class GenericContainer(BaseContainer):
__slots__ = "origin", "args"

def __init__(self, origin, args: tuple = ()):
self.origin = origin
self.args = args

@classmethod
def wrap(cls, typ_):
if isinstance(typ_, GenericTypes):
wrapped_args = tuple(map(cls.wrap, get_args(typ_)))
return cls(get_origin(typ_), wrapped_args)
return typ_
def wrap(cls, value):
if isinstance(value, GenericTypes):
wrapped_args = tuple(map(cls.wrap, get_args(value)))
return cls(get_origin(value), wrapped_args)
if isinstance(value, FieldInfo):
return FieldInfoContainer.wrap(value)
return value

@classmethod
def unwrap(cls, type_):
if not isinstance(type_, cls):
return type_
def unwrap(cls, value):
if not isinstance(value, cls):
return value

if not type_.args:
return type_.origin
if not value.args:
return value.origin

unwrapped_args = tuple(map(cls.unwrap, type_.args))
unwrapped_args = tuple(map(BaseContainer.unwrap, value.args))
try:
# This is a fallback for Python < 3.8, please be careful with that
return type_.origin[unwrapped_args]
return value.origin[unwrapped_args]
except TypeError:
return GenericAlias(type_.origin, unwrapped_args)
return GenericAlias(value.origin, unwrapped_args)

def __repr__(self):
return repr(self.unwrap(self))
def __eq__(self, other):
if isinstance(other, GenericTypes):
return self == self.wrap(other)
return super().__eq__(other)


class DataclassContainer(BaseContainer):
__slots__ = "datacls", "kwargs"

def __init__(self, datacls: type, kwargs: ty.Dict[str, ty.Any]):
self.datacls = datacls
self.kwargs = kwargs

@classmethod
def wrap(cls, value):
if cls._is_dataclass_instance(value):
return cls(type(value), dataclasses.asdict(value))
if isinstance(value, GenericTypes):
return GenericContainer.wrap(value)
return value

__str__ = __repr__
@classmethod
def unwrap(cls, value):
if isinstance(value, cls):
return value.datacls(**value.kwargs)
return value

@staticmethod
def _is_dataclass_instance(obj: ty.Any):
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.origin == other.origin and self.args == other.args
if isinstance(other, GenericTypes):
if self._is_dataclass_instance(other):
return self == self.wrap(other)
return NotImplemented
return super().__eq__(other)


class FieldInfoContainer(BaseContainer):
__slots__ = "origin", "metadata", "kwargs"

class GenericSerializer(BaseSerializer):
value: GenericContainer
def __init__(self, origin, metadata, kwargs):
self.origin = origin
self.metadata = metadata
self.kwargs = kwargs

@classmethod
def wrap(cls, field: FieldInfo):
if not isinstance(field, FieldInfo):
return field

# `getattr` is important to preserve compatibility with Pydantic v1
origin = GenericContainer.wrap(getattr(field, "annotation", None))
metadata = getattr(field, "metadata", ())
metadata = tuple(map(DataclassContainer.wrap, metadata))

kwargs = dict(cls._iter_field_attrs(field))
return cls(origin, metadata, kwargs)

@classmethod
def unwrap(cls, value):
if not isinstance(value, cls):
return value

origin = GenericContainer.unwrap(value.origin)
metadata = tuple(map(BaseContainer.unwrap, value.metadata))
annotation = te._AnnotatedAlias(origin, metadata)
return FieldInfo(annotation=annotation, **value.kwargs)

def __eq__(self, other):
if isinstance(other, FieldInfo):
return self == self.wrap(other)
return super().__eq__(other)

@staticmethod
def _iter_field_attrs(field: FieldInfo):
available_attrs = set(field.__slots__) - {"annotation", "metadata", "_attributes_set"}

for attr in available_attrs:
attr_value = getattr(field, attr)
if attr_value is not PydanticUndefined and attr_value != FieldInfoDefaultValues.get(attr):
yield attr, getattr(field, attr)

@staticmethod
def _wrap_metadata_object(obj):
return DataclassContainer.wrap(obj)


class BaseContainerSerializer(BaseSerializer):
value: BaseContainer

def serialize(self):
value = self.value
tp_repr, imports = serializer_factory(type(self.value)).serialize()
attrs = []

for attr in self._iter_container_attrs():
attr_repr, attr_imports = serializer_factory(attr).serialize()
attrs.append(attr_repr)
imports.update(attr_imports)

attrs_repr = ", ".join(attrs)
return f"{tp_repr}({attrs_repr})", imports

tp_repr, imports = serializer_factory(type(value)).serialize()
orig_repr, orig_imports = serializer_factory(value.origin).serialize()
imports.update(orig_imports)
def _iter_container_attrs(self):
container = self.value
for attr in container.__slots__:
yield getattr(container, attr)

args = []
for arg in value.args:
arg_repr, arg_imports = serializer_factory(arg).serialize()
args.append(arg_repr)
imports.update(arg_imports)

if args:
args_repr = ", ".join(args)
generic_repr = "%s(%s, (%s,))" % (tp_repr, orig_repr, args_repr)
else:
generic_repr = "%s(%s)" % (tp_repr, orig_repr)
class DataclassContainerSerializer(BaseSerializer):
value: DataclassContainer

return generic_repr, imports
def serialize(self):
tp_repr, imports = serializer_factory(self.value.datacls).serialize()

kwarg_pairs = []
for arg, value in self.value.kwargs.items():
value_repr, value_imports = serializer_factory(value).serialize()
kwarg_pairs.append(f"{arg}={value_repr}")
imports.update(value_imports)

kwargs_repr = ", ".join(kwarg_pairs)
return f"{tp_repr}({kwargs_repr})", imports


class TypingSerializer(BaseSerializer):
Expand All @@ -103,6 +231,34 @@ def serialize(self):
return orig_repr, {f"import {orig_module}"}


class FieldInfoSerializer(BaseSerializer):
value: FieldInfo

def serialize(self):
container = FieldInfoContainer.wrap(self.value)
return serializer_factory(container).serialize()


class RepresentationSerializer(BaseSerializer):
value: Representation

def serialize(self):
tp_repr, imports = serializer_factory(type(self.value)).serialize()
repr_args = []

for arg_name, arg_value in self.value.__repr_args__():
arg_value_repr, arg_value_imports = serializer_factory(arg_value).serialize()
imports.update(arg_value_imports)

if arg_name is None:
repr_args.append(arg_value_repr)
else:
repr_args.append(f"{arg_name}={arg_value_repr}")

final_args_repr = ", ".join(repr_args)
return f"{tp_repr}({final_args_repr})"


if sys.version_info >= (3, 9):
GenericAlias = types.GenericAlias
GenericTypes: ty.Tuple[ty.Any, ...] = (
Expand All @@ -117,7 +273,15 @@ def serialize(self):
GenericTypes = GenericAlias, type(ty.List) # noqa


MigrationWriter.register_serializer(GenericContainer, GenericSerializer)
# BaseContainerSerializer *must be* registered after all specialized container serializers
MigrationWriter.register_serializer(DataclassContainer, DataclassContainerSerializer)
MigrationWriter.register_serializer(BaseContainer, BaseContainerSerializer)

# Pydantic-specific datastructures serializers
MigrationWriter.register_serializer(FieldInfo, FieldInfoSerializer)
MigrationWriter.register_serializer(Representation, RepresentationSerializer)

# Typing serializers
MigrationWriter.register_serializer(ty.ForwardRef, TypingSerializer)
MigrationWriter.register_serializer(type(ty.Union), TypingSerializer) # type: ignore

Expand Down
6 changes: 3 additions & 3 deletions django_pydantic_field/v1/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from django.db.models.fields.json import JSONField
from django.db.models.query_utils import DeferredAttribute

from django_pydantic_field.compat.django import GenericContainer
from django_pydantic_field.compat.django import BaseContainer, GenericContainer

from . import base, forms, utils

Expand Down Expand Up @@ -44,7 +44,7 @@ class PydanticSchemaField(JSONField, t.Generic[base.ST]):
def __init__(
self,
*args,
schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None,
schema: t.Union[t.Type["base.ST"], "BaseContainer", "t.ForwardRef", str, None] = None,
config: t.Optional["base.ConfigType"] = None,
**kwargs,
):
Expand Down Expand Up @@ -137,7 +137,7 @@ def value_to_string(self, obj):
return self.get_prep_value(value)

def _resolve_schema(self, schema):
schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema))
schema = t.cast(t.Type["base.ST"], BaseContainer.unwrap(schema))

self.schema = schema
if schema is not None:
Expand Down
10 changes: 5 additions & 5 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import typing as ty
import typing_extensions as te

import pydantic
import typing_extensions as te
from django.core import checks, exceptions
from django.core.serializers.json import DjangoJSONEncoder
from django.db.models.expressions import BaseExpression, Col, Value
Expand All @@ -13,7 +13,7 @@
from django.db.models.query_utils import DeferredAttribute

from django_pydantic_field.compat import deprecation
from django_pydantic_field.compat.django import GenericContainer
from django_pydantic_field.compat.django import BaseContainer, GenericContainer

from . import forms, types

Expand Down Expand Up @@ -70,15 +70,15 @@ class PydanticSchemaField(JSONField, ty.Generic[types.ST]):
def __init__(
self,
*args,
schema: type[types.ST] | GenericContainer | ty.ForwardRef | str | None = None,
schema: type[types.ST] | BaseContainer | ty.ForwardRef | str | None = None,
config: pydantic.ConfigDict | None = None,
**kwargs,
):
kwargs.setdefault("encoder", DjangoJSONEncoder)
self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs)
super().__init__(*args, **kwargs)

self.schema = GenericContainer.unwrap(schema)
self.schema = BaseContainer.unwrap(schema)
self.config = config
self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs)

Expand Down Expand Up @@ -116,7 +116,7 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]:
f"Please consider using field annotation syntax, e.g. `{annot_hint} = SchemaField(...)`; "
"or a fallback to `pydantic.RootModel` with annotation instead."
)
performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.W004"))
performed_checks.append(checks.Error(message, obj=self, hint=hint, id="pydantic.E004"))

try:
# Test that the schema could be resolved in runtime, even if it contains forward references.
Expand Down
4 changes: 2 additions & 2 deletions django_pydantic_field/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pydantic
import typing_extensions as te

from django_pydantic_field.compat.django import GenericContainer
from django_pydantic_field.compat.django import BaseContainer, GenericContainer
from django_pydantic_field.compat.functools import cached_property

from . import utils
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
allow_null: bool | None = None,
**export_kwargs: ty.Unpack[ExportKwargs],
):
self.schema = GenericContainer.unwrap(schema)
self.schema = BaseContainer.unwrap(schema)
self.config = config
self.parent_type = parent_type
self.attname = attname
Expand Down
Loading

0 comments on commit 251e641

Please sign in to comment.