From 2a0f4873924d5b6eb8db9decb52ac1931dcca75a Mon Sep 17 00:00:00 2001 From: Dan Lenski Date: Sat, 9 Nov 2024 12:44:28 -0800 Subject: [PATCH] Support cyclic references (#138) * support cyclic or self-referential dataclasses * update to add custom recursion error * Don't replace RecursionError with RecursiveClassError if recursion-safe loader is already enabled If we get a `RecursionError` in this case, something else must have gone wrong --------- Co-authored-by: Ritvik Nag --- dataclass_wizard/bases.py | 6 +++++ dataclass_wizard/bases_meta.py | 9 ++++++- dataclass_wizard/errors.py | 30 ++++++++++++++++++++++ dataclass_wizard/loaders.py | 34 ++++++++++++++++++------ dataclass_wizard/parsers.py | 42 +++++++++++++++++++++++++++++- tests/unit/test_load.py | 47 ++++++++++++++++++++++++++++++++++ 6 files changed, 158 insertions(+), 10 deletions(-) diff --git a/dataclass_wizard/bases.py b/dataclass_wizard/bases.py index b151d855..9c0d5ff7 100644 --- a/dataclass_wizard/bases.py +++ b/dataclass_wizard/bases.py @@ -134,6 +134,12 @@ class AbstractMeta(metaclass=ABCOrAndMeta): # apply in a recursive manner. recursive: ClassVar[bool] = True + # True to support cyclic or self-referential dataclasses. For example, + # the type of a dataclass field in class `A` refers to `A` itself. + # + # See https://github.com/rnag/dataclass-wizard/issues/62 for more details. + recursive_classes: ClassVar[bool] = False + # True to raise an class:`UnknownJSONKey` when an unmapped JSON key is # encountered when `from_dict` or `from_json` is called; an unknown key is # one that does not have a known mapping to a dataclass field. diff --git a/dataclass_wizard/bases_meta.py b/dataclass_wizard/bases_meta.py index 52352e6c..99ac5beb 100644 --- a/dataclass_wizard/bases_meta.py +++ b/dataclass_wizard/bases_meta.py @@ -14,6 +14,7 @@ get_outer_class_name, get_class_name, create_new_class, json_field_to_dataclass_field, dataclass_field_to_json_field ) +from .constants import TAG from .decorators import try_with_load from .dumpers import get_dumper from .enums import LetterCase, DateTimeTo @@ -173,10 +174,13 @@ def _as_enum_safe(cls, name: str, base_type: Type[E]) -> Optional[E]: # noinspection PyPep8Naming def LoadMeta(*, debug_enabled: bool = False, recursive: bool = True, + recursive_classes: bool = False, raise_on_unknown_json_key: bool = False, json_key_to_field: Dict[str, str] = None, key_transform: Union[LetterCase, str] = None, - tag: str = None) -> META: + tag: str = None, + tag_key: str = TAG, + auto_assign_tags: bool = False) -> META: """ Helper function to setup the ``Meta`` Config for the JSON load (de-serialization) process, which is intended for use alongside the @@ -198,11 +202,14 @@ def LoadMeta(*, debug_enabled: bool = False, base_dict = { '__slots__': (), 'raise_on_unknown_json_key': raise_on_unknown_json_key, + 'recursive_classes': recursive_classes, 'key_transform_with_load': key_transform, 'json_key_to_field': json_key_to_field, 'debug_enabled': debug_enabled, 'recursive': recursive, 'tag': tag, + 'tag_key': tag_key, + 'auto_assign_tags': auto_assign_tags, } # Create a new subclass of :class:`AbstractMeta` diff --git a/dataclass_wizard/errors.py b/dataclass_wizard/errors.py index ba2b39f2..f1a924ff 100644 --- a/dataclass_wizard/errors.py +++ b/dataclass_wizard/errors.py @@ -266,3 +266,33 @@ def message(self) -> str: msg = f'{msg}{sep}{parts}' return msg + + +class RecursiveClassError(JSONWizardError): + """ + Error raised when we encounter a `RecursionError` due to cyclic + or self-referential dataclasses. + """ + + _TEMPLATE = ('Failure parsing class `{cls}`. ' + 'Consider updating the Meta config to enable ' + 'the `recursive_classes` flag.\n\n' + 'Example with `dataclass_wizard.LoadMeta`:\n' + ' >>> LoadMeta(recursive_classes=True).bind_to({cls})\n\n' + 'For more info, please see:\n' + ' https://github.com/rnag/dataclass-wizard/issues/62') + + def __init__(self, cls: Type): + super().__init__() + + self.class_name: str = self.name(cls) + + # TODO: update to use `type_name` once changes are merged + @staticmethod + def name(obj) -> str: + """Return the type or class name of an object""" + return getattr(obj, '__qualname__', getattr(obj, '__name__', obj)) + + @property + def message(self) -> str: + return self._TEMPLATE.format(cls=self.class_name) diff --git a/dataclass_wizard/loaders.py b/dataclass_wizard/loaders.py index 3407d45f..fc616b41 100644 --- a/dataclass_wizard/loaders.py +++ b/dataclass_wizard/loaders.py @@ -20,7 +20,8 @@ ) from .constants import _LOAD_HOOKS, SINGLE_ARG_ALIAS, IDENTITY from .decorators import _alias, _single_arg_alias, resolve_alias_func, _identity -from .errors import ParseError, MissingFields, UnknownJSONKey, MissingData +from .errors import (ParseError, MissingFields, UnknownJSONKey, + MissingData, RecursiveClassError) from .log import LOG from .models import Extras, _PatternedDT from .parsers import * @@ -290,12 +291,22 @@ def get_parser_for_annotation(cls, ann_type: Type[T], elif isinstance(base_type, type): if is_dataclass(base_type): - base_type: Type[T] - load_hook = load_func_for_dataclass( - base_type, - is_main_class=False, - config=extras['config'] - ) + config: META = extras.get('config') + + # enable support for cyclic / self-referential dataclasses + # see https://github.com/rnag/dataclass-wizard/issues/62 + if config and config.recursive_classes: + # noinspection PyTypeChecker + return RecursionSafeParser( + base_cls, extras, base_type, hook=None + ) + else: # else, logic is same as normal + base_type: Type[T] + load_hook = load_func_for_dataclass( + base_type, + is_main_class=False, + config=extras['config'] + ) elif issubclass(base_type, Enum): load_hook = hooks.get(Enum) @@ -593,7 +604,14 @@ def load_func_for_dataclass( # This contains a mapping of the original field name to the parser for its # annotated type; the item lookup *can* be case-insensitive. - field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config) + try: + field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config) + except RecursionError as e: + if meta.recursive_classes: + # recursion-safe loader is already in use; something else must have gone wrong + raise + else: + raise RecursiveClassError(cls) from None # A cached mapping of each key in a JSON or dictionary object to the # resolved dataclass field name; useful so we don't need to do a case diff --git a/dataclass_wizard/parsers.py b/dataclass_wizard/parsers.py index 23cbb599..aa3fd1fd 100644 --- a/dataclass_wizard/parsers.py +++ b/dataclass_wizard/parsers.py @@ -1,6 +1,7 @@ __all__ = ['IdentityParser', 'SingleArgParser', 'Parser', + 'RecursionSafeParser', 'PatternedDTParser', 'LiteralParser', 'UnionParser', @@ -36,6 +37,7 @@ # Type defs GetParserType = Callable[[Type[T], Type, Extras], AbstractParser] +LoadHookType = Callable[[Any], T] TupleOfParsers = Tuple[AbstractParser, ...] @@ -51,7 +53,7 @@ def __call__(self, o: Any) -> T: class SingleArgParser(AbstractParser[Type[T], T]): __slots__ = ('hook', ) - hook: Callable[[Any], T] + hook: LoadHookType # noinspection PyDataclass def __post_init__(self, *_): @@ -72,6 +74,44 @@ def __call__(self, o: Any) -> T: return self.hook(o, self.base_type) +@dataclass +class RecursionSafeParser(AbstractParser): + """ + Parser to handle cyclic or self-referential dataclasses. + + For example:: + + @dataclass + class A: + a: A | None = None + + instance = fromdict(A, {'a': {'a': {'a': None}}}) + """ + __slots__ = ('extras', 'hook') + + extras: Extras + hook: Optional[LoadHookType] + + def load_hook_func(self) -> LoadHookType: + from .loaders import load_func_for_dataclass + + return load_func_for_dataclass( + self.base_type, + is_main_class=False, + config=self.extras['config'] + ) + + # TODO: decorating `load_hook_func` with `@cached_property` could + # be an alternate, bit cleaner approach. + def __call__(self, o: Any) -> T: + load_hook = self.hook + + if not load_hook: + load_hook = self.hook = self.load_hook_func() + + return load_hook(o) + + @dataclass class LiteralParser(AbstractParser[Type[M], M]): __slots__ = ('value_to_type', ) diff --git a/tests/unit/test_load.py b/tests/unit/test_load.py index 9703fd5b..07ae18b1 100644 --- a/tests/unit/test_load.py +++ b/tests/unit/test_load.py @@ -1801,3 +1801,50 @@ class Outer(JSONWizard): # the error should mention that we want a dict, but get a list assert e.ann_type == dict assert e.obj_type == list + + +def test_with_self_referential_dataclasses_1(): + """ + Test loading JSON data, when a dataclass model has cyclic + or self-referential dataclasses. For example, A -> A -> A. + """ + @dataclass + class A: + a: Optional['A'] = None + + # enable support for self-referential / recursive dataclasses + LoadMeta(recursive_classes=True).bind_to(A) + + # Fix for local test cases so the forward reference works + globals().update(locals()) + + # assert that `fromdict` with a recursive, self-referential + # input `dict` works as expected. + a = fromdict(A, {'a': {'a': {'a': None}}}) + assert a == A(a=A(a=A(a=None))) + + +def test_with_self_referential_dataclasses_2(): + """ + Test loading JSON data, when a dataclass model has cyclic + or self-referential dataclasses. For example, A -> B -> A -> B. + """ + @dataclass + class A(JSONWizard): + class _(JSONWizard.Meta): + # enable support for self-referential / recursive dataclasses + recursive_classes = True + + b: Optional['B'] = None + + @dataclass + class B: + a: Optional['A'] = None + + # Fix for local test cases so the forward reference works + globals().update(locals()) + + # assert that `fromdict` with a recursive, self-referential + # input `dict` works as expected. + a = fromdict(A, {'b': {'a': {'b': {'a': None}}}}) + assert a == A(b=B(a=A(b=B())))