From 00fe560b2ff26de2d2b2c71d40ac074046aa5065 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sun, 21 Apr 2024 10:38:15 +0300 Subject: [PATCH] fix: Avoid recursive error on nested group references --- .../handlers/test_flatten_attribute_groups.py | 84 ++++++------------- .../handlers/flatten_attribute_groups.py | 5 +- xsdata/codegen/mixins.py | 4 + xsdata/codegen/utils.py | 9 +- 4 files changed, 41 insertions(+), 61 deletions(-) diff --git a/tests/codegen/handlers/test_flatten_attribute_groups.py b/tests/codegen/handlers/test_flatten_attribute_groups.py index 93c46b90d..4b4b6af33 100644 --- a/tests/codegen/handlers/test_flatten_attribute_groups.py +++ b/tests/codegen/handlers/test_flatten_attribute_groups.py @@ -1,10 +1,7 @@ -from unittest import mock - from xsdata.codegen.container import ClassContainer from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.handlers import FlattenAttributeGroups -from xsdata.codegen.models import Attr, Status -from xsdata.codegen.utils import ClassUtils +from xsdata.codegen.models import Status from xsdata.models.config import GeneratorConfig from xsdata.models.enums import Tag from xsdata.utils.testing import AttrFactory, ClassFactory, FactoryTestCase @@ -17,65 +14,38 @@ def setUp(self): container = ClassContainer(config=GeneratorConfig()) self.processor = FlattenAttributeGroups(container=container) - @mock.patch.object(Attr, "is_group", new_callable=mock.PropertyMock) - @mock.patch.object(FlattenAttributeGroups, "process_attribute") - def test_process(self, mock_process_attribute, mock_is_group): - mock_is_group.side_effect = [ - True, - False, - True, - True, - False, - False, + def test_process(self): + group = ClassFactory.create(qname="group", tag=Tag.GROUP) + group.attrs = [ + AttrFactory.reference(name="one", qname="inner_one", forward=True), + AttrFactory.reference(name="two", qname="inner_two", forward=True), ] - target = ClassFactory.elements(2) - - self.processor.process(target) - self.assertEqual(6, mock_is_group.call_count) - - mock_process_attribute.assert_has_calls( - [ - mock.call(target, target.attrs[0]), - mock.call(target, target.attrs[0]), - mock.call(target, target.attrs[1]), - ] + inner_one = ClassFactory.create( + qname="inner_one", + attrs=[ + AttrFactory.reference(qname="group", tag=Tag.GROUP), + ], ) - - @mock.patch.object(ClassUtils, "copy_group_attributes") - def test_process_attribute_with_group(self, mock_copy_group_attributes): - complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE) - group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) - group_attr = AttrFactory.attribute_group(name="bar") - target = ClassFactory.create() - target.attrs.append(group_attr) - - self.processor.container.add(complex_bar) - self.processor.container.add(group_bar) - self.processor.container.add(target) - - self.processor.process_attribute(target, group_attr) - mock_copy_group_attributes.assert_called_once_with( - group_bar, target, group_attr + inner_two = inner_one.clone(qname="inner_two") + inner_one.parent = group + inner_two.parent = group + group.inner.extend([inner_one, inner_two]) + target = ClassFactory.create( + attrs=[ + AttrFactory.reference(qname="group", tag=Tag.GROUP), + ] ) + self.processor.container.extend([group, target]) + self.processor.container.process() - @mock.patch.object(ClassUtils, "copy_group_attributes") - def test_process_attribute_with_attribute_group(self, mock_copy_group_attributes): - complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE) - group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) - group_attr = AttrFactory.attribute_group(name="bar") - target = ClassFactory.create() - target.attrs.append(group_attr) - - self.processor.container.add(complex_bar) - self.processor.container.add(group_bar) - self.processor.container.add(target) + self.assertEqual(["one", "two"], [x.name for x in target.attrs]) + self.assertEqual(["inner_one", "inner_two"], [x.name for x in target.inner]) - self.processor.process_attribute(target, group_attr) - mock_copy_group_attributes.assert_called_once_with( - group_bar, target, group_attr - ) + for inner in target.inner: + self.assertEqual(["one", "two"], [x.name for x in inner.attrs]) + self.assertEqual(0, len(inner.inner)) - def test_process_attribute_with_circular_reference(self): + def test_process_attribute_with_self_reference(self): group_attr = AttrFactory.attribute_group(name="bar") target = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) target.attrs.append(group_attr) diff --git a/xsdata/codegen/handlers/flatten_attribute_groups.py b/xsdata/codegen/handlers/flatten_attribute_groups.py index 703aeaf22..cfbbf213e 100644 --- a/xsdata/codegen/handlers/flatten_attribute_groups.py +++ b/xsdata/codegen/handlers/flatten_attribute_groups.py @@ -1,6 +1,6 @@ from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.mixins import RelativeHandlerInterface -from xsdata.codegen.models import Attr, Class +from xsdata.codegen.models import Attr, Class, Status from xsdata.codegen.utils import ClassUtils @@ -51,4 +51,5 @@ def process_attribute(self, target: Class, attr: Attr): if source is target: ClassUtils.remove_attribute(target, attr) else: - ClassUtils.copy_group_attributes(source, target, attr) + is_circular_ref = source.status == Status.UNGROUPING + ClassUtils.copy_group_attributes(source, target, attr, is_circular_ref) diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index 11dcd5ea3..3f47de3d2 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -24,6 +24,10 @@ def __init__(self, config: GeneratorConfig): def __iter__(self) -> Iterator[Class]: """Yield an iterator for the class map values.""" + @abc.abstractmethod + def process(self): + """Run the processor and filter steps.""" + @abc.abstractmethod def find(self, qname: str, condition: Callable = return_true) -> Optional[Class]: """Find class that matches the given qualified name and condition callable. diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index c623e6bf2..c7438a862 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -115,7 +115,9 @@ def copy_attributes(cls, source: Class, target: Class, extension: Extension): index += 1 @classmethod - def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): + def copy_group_attributes( + cls, source: Class, target: Class, attr: Attr, skip_inner_classes: bool = False + ): """Copy the attrs of the source class to the target class. The attr represents a reference to the source class which is @@ -125,6 +127,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): source: The source class instance target: The target class instance attr: The group attr instance + skip_inner_classes: Whether the attr is circular reference, which + means we can skip copying the inner classes. """ index = target.attrs.index(attr) target.attrs.pop(index) @@ -134,7 +138,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): target.attrs.insert(index, clone) index += 1 - cls.copy_inner_classes(source, target, clone) + if not skip_inner_classes: + cls.copy_inner_classes(source, target, clone) @classmethod def copy_extensions(cls, source: Class, target: Class, extension: Extension):