Skip to content

Commit

Permalink
merge changes from #138 and #134
Browse files Browse the repository at this point in the history
  • Loading branch information
rnag committed Nov 9, 2024
2 parents ca0a366 + 2a0f487 commit 13a6dca
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 10 deletions.
6 changes: 6 additions & 0 deletions dataclass_wizard/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class AbstractMeta(metaclass=ABCOrAndMeta):
# apply in a recursive manner.
recursive: ClassVar[bool] = True

# True to support cyclic or self-referential dataclasses. For example,
# the type of a dataclass field in class `A` refers to `A` itself.
#
# See https://github.com/rnag/dataclass-wizard/issues/62 for more details.
recursive_classes: ClassVar[bool] = False

# True to raise an class:`UnknownJSONKey` when an unmapped JSON key is
# encountered when `from_dict` or `from_json` is called; an unknown key is
# one that does not have a known mapping to a dataclass field.
Expand Down
9 changes: 8 additions & 1 deletion dataclass_wizard/bases_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_outer_class_name, get_class_name, create_new_class,
json_field_to_dataclass_field, dataclass_field_to_json_field
)
from .constants import TAG
from .decorators import try_with_load
from .dumpers import get_dumper
from .enums import LetterCase, DateTimeTo
Expand Down Expand Up @@ -173,10 +174,13 @@ def _as_enum_safe(cls, name: str, base_type: Type[E]) -> Optional[E]:
# noinspection PyPep8Naming
def LoadMeta(*, debug_enabled: bool = False,
recursive: bool = True,
recursive_classes: bool = False,
raise_on_unknown_json_key: bool = False,
json_key_to_field: Dict[str, str] = None,
key_transform: Union[LetterCase, str] = None,
tag: str = None) -> META:
tag: str = None,
tag_key: str = TAG,
auto_assign_tags: bool = False) -> META:
"""
Helper function to setup the ``Meta`` Config for the JSON load
(de-serialization) process, which is intended for use alongside the
Expand All @@ -198,11 +202,14 @@ def LoadMeta(*, debug_enabled: bool = False,
base_dict = {
'__slots__': (),
'raise_on_unknown_json_key': raise_on_unknown_json_key,
'recursive_classes': recursive_classes,
'key_transform_with_load': key_transform,
'json_key_to_field': json_key_to_field,
'debug_enabled': debug_enabled,
'recursive': recursive,
'tag': tag,
'tag_key': tag_key,
'auto_assign_tags': auto_assign_tags,
}

# Create a new subclass of :class:`AbstractMeta`
Expand Down
33 changes: 33 additions & 0 deletions dataclass_wizard/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,36 @@ def message(self) -> str:
msg = f'{msg}{sep}{parts}'

return msg


class RecursiveClassError(JSONWizardError):
"""
Error raised when we encounter a `RecursionError` due to cyclic
or self-referential dataclasses.
"""

_TEMPLATE = ('Failure parsing class `{cls}`. '
'Consider updating the Meta config to enable '
'the `recursive_classes` flag.\n\n'
'Example with `dataclass_wizard.LoadMeta`:\n'
' >>> LoadMeta(recursive_classes=True).bind_to({cls})\n\n'
'For more info, please see:\n'
' https://github.com/rnag/dataclass-wizard/issues/62')

def __init__(self, cls: Type):
super().__init__()

self.class_name: str = self.name(cls)

@staticmethod
def name(obj) -> str:
"""Return the type or class name of an object"""
# Uses short-circuiting with `or` to efficiently
# return the first valid name.
return (getattr(obj, '__qualname__', None)
or getattr(obj, '__name__', None)
or str(obj))

@property
def message(self) -> str:
return self._TEMPLATE.format(cls=self.class_name)
34 changes: 26 additions & 8 deletions dataclass_wizard/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
)
from .constants import _LOAD_HOOKS, SINGLE_ARG_ALIAS, IDENTITY
from .decorators import _alias, _single_arg_alias, resolve_alias_func, _identity
from .errors import ParseError, MissingFields, UnknownJSONKey, MissingData
from .errors import (ParseError, MissingFields, UnknownJSONKey,
MissingData, RecursiveClassError)
from .log import LOG
from .models import Extras, _PatternedDT
from .parsers import *
Expand Down Expand Up @@ -290,12 +291,22 @@ def get_parser_for_annotation(cls, ann_type: Type[T],
elif isinstance(base_type, type):

if is_dataclass(base_type):
base_type: Type[T]
load_hook = load_func_for_dataclass(
base_type,
is_main_class=False,
config=extras['config']
)
config: META = extras.get('config')

# enable support for cyclic / self-referential dataclasses
# see https://github.com/rnag/dataclass-wizard/issues/62
if config and config.recursive_classes:
# noinspection PyTypeChecker
return RecursionSafeParser(
base_cls, extras, base_type, hook=None
)
else: # else, logic is same as normal
base_type: Type[T]
load_hook = load_func_for_dataclass(
base_type,
is_main_class=False,
config=extras['config']
)

elif issubclass(base_type, Enum):
load_hook = hooks.get(Enum)
Expand Down Expand Up @@ -593,7 +604,14 @@ def load_func_for_dataclass(

# This contains a mapping of the original field name to the parser for its
# annotated type; the item lookup *can* be case-insensitive.
field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config)
try:
field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config)
except RecursionError as e:
if meta.recursive_classes:
# recursion-safe loader is already in use; something else must have gone wrong
raise
else:
raise RecursiveClassError(cls) from None

# A cached mapping of each key in a JSON or dictionary object to the
# resolved dataclass field name; useful so we don't need to do a case
Expand Down
42 changes: 41 additions & 1 deletion dataclass_wizard/parsers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ['IdentityParser',
'SingleArgParser',
'Parser',
'RecursionSafeParser',
'PatternedDTParser',
'LiteralParser',
'UnionParser',
Expand Down Expand Up @@ -36,6 +37,7 @@

# Type defs
GetParserType = Callable[[Type[T], Type, Extras], AbstractParser]
LoadHookType = Callable[[Any], T]
TupleOfParsers = Tuple[AbstractParser, ...]


Expand All @@ -51,7 +53,7 @@ def __call__(self, o: Any) -> T:
class SingleArgParser(AbstractParser[Type[T], T]):
__slots__ = ('hook', )

hook: Callable[[Any], T]
hook: LoadHookType

# noinspection PyDataclass
def __post_init__(self, *_):
Expand All @@ -72,6 +74,44 @@ def __call__(self, o: Any) -> T:
return self.hook(o, self.base_type)


@dataclass
class RecursionSafeParser(AbstractParser):
"""
Parser to handle cyclic or self-referential dataclasses.
For example::
@dataclass
class A:
a: A | None = None
instance = fromdict(A, {'a': {'a': {'a': None}}})
"""
__slots__ = ('extras', 'hook')

extras: Extras
hook: Optional[LoadHookType]

def load_hook_func(self) -> LoadHookType:
from .loaders import load_func_for_dataclass

return load_func_for_dataclass(
self.base_type,
is_main_class=False,
config=self.extras['config']
)

# TODO: decorating `load_hook_func` with `@cached_property` could
# be an alternate, bit cleaner approach.
def __call__(self, o: Any) -> T:
load_hook = self.hook

if not load_hook:
load_hook = self.hook = self.load_hook_func()

return load_hook(o)


@dataclass
class LiteralParser(AbstractParser[Type[M], M]):
__slots__ = ('value_to_type', )
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,3 +1828,50 @@ class Item(JSONSerializable):

assert item.a == {}
assert item.b is item.c is None


def test_with_self_referential_dataclasses_1():
"""
Test loading JSON data, when a dataclass model has cyclic
or self-referential dataclasses. For example, A -> A -> A.
"""
@dataclass
class A:
a: Optional['A'] = None

# enable support for self-referential / recursive dataclasses
LoadMeta(recursive_classes=True).bind_to(A)

# Fix for local test cases so the forward reference works
globals().update(locals())

# assert that `fromdict` with a recursive, self-referential
# input `dict` works as expected.
a = fromdict(A, {'a': {'a': {'a': None}}})
assert a == A(a=A(a=A(a=None)))


def test_with_self_referential_dataclasses_2():
"""
Test loading JSON data, when a dataclass model has cyclic
or self-referential dataclasses. For example, A -> B -> A -> B.
"""
@dataclass
class A(JSONWizard):
class _(JSONWizard.Meta):
# enable support for self-referential / recursive dataclasses
recursive_classes = True

b: Optional['B'] = None

@dataclass
class B:
a: Optional['A'] = None

# Fix for local test cases so the forward reference works
globals().update(locals())

# assert that `fromdict` with a recursive, self-referential
# input `dict` works as expected.
a = fromdict(A, {'b': {'a': {'b': {'a': None}}}})
assert a == A(b=B(a=A(b=B())))

0 comments on commit 13a6dca

Please sign in to comment.