Skip to content

Commit

Permalink
fix: Avoid recursive error on nested group references
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Apr 21, 2024
1 parent 7cacca5 commit 00fe560
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 61 deletions.
84 changes: 27 additions & 57 deletions tests/codegen/handlers/test_flatten_attribute_groups.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from unittest import mock

from xsdata.codegen.container import ClassContainer
from xsdata.codegen.exceptions import CodegenError
from xsdata.codegen.handlers import FlattenAttributeGroups
from xsdata.codegen.models import Attr, Status
from xsdata.codegen.utils import ClassUtils
from xsdata.codegen.models import Status
from xsdata.models.config import GeneratorConfig
from xsdata.models.enums import Tag
from xsdata.utils.testing import AttrFactory, ClassFactory, FactoryTestCase
Expand All @@ -17,65 +14,38 @@ def setUp(self):
container = ClassContainer(config=GeneratorConfig())
self.processor = FlattenAttributeGroups(container=container)

@mock.patch.object(Attr, "is_group", new_callable=mock.PropertyMock)
@mock.patch.object(FlattenAttributeGroups, "process_attribute")
def test_process(self, mock_process_attribute, mock_is_group):
mock_is_group.side_effect = [
True,
False,
True,
True,
False,
False,
def test_process(self):
group = ClassFactory.create(qname="group", tag=Tag.GROUP)
group.attrs = [
AttrFactory.reference(name="one", qname="inner_one", forward=True),
AttrFactory.reference(name="two", qname="inner_two", forward=True),
]
target = ClassFactory.elements(2)

self.processor.process(target)
self.assertEqual(6, mock_is_group.call_count)

mock_process_attribute.assert_has_calls(
[
mock.call(target, target.attrs[0]),
mock.call(target, target.attrs[0]),
mock.call(target, target.attrs[1]),
]
inner_one = ClassFactory.create(
qname="inner_one",
attrs=[
AttrFactory.reference(qname="group", tag=Tag.GROUP),
],
)

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE)
group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)

self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
inner_two = inner_one.clone(qname="inner_two")
inner_one.parent = group
inner_two.parent = group
group.inner.extend([inner_one, inner_two])
target = ClassFactory.create(
attrs=[
AttrFactory.reference(qname="group", tag=Tag.GROUP),
]
)
self.processor.container.extend([group, target])
self.processor.container.process()

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_attribute_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE)
group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)
self.assertEqual(["one", "two"], [x.name for x in target.attrs])
self.assertEqual(["inner_one", "inner_two"], [x.name for x in target.inner])

self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
)
for inner in target.inner:
self.assertEqual(["one", "two"], [x.name for x in inner.attrs])
self.assertEqual(0, len(inner.inner))

def test_process_attribute_with_circular_reference(self):
def test_process_attribute_with_self_reference(self):
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
target.attrs.append(group_attr)
Expand Down
5 changes: 3 additions & 2 deletions xsdata/codegen/handlers/flatten_attribute_groups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xsdata.codegen.exceptions import CodegenError
from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr, Class
from xsdata.codegen.models import Attr, Class, Status
from xsdata.codegen.utils import ClassUtils


Expand Down Expand Up @@ -51,4 +51,5 @@ def process_attribute(self, target: Class, attr: Attr):
if source is target:
ClassUtils.remove_attribute(target, attr)
else:
ClassUtils.copy_group_attributes(source, target, attr)
is_circular_ref = source.status == Status.UNGROUPING
ClassUtils.copy_group_attributes(source, target, attr, is_circular_ref)
4 changes: 4 additions & 0 deletions xsdata/codegen/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(self, config: GeneratorConfig):
def __iter__(self) -> Iterator[Class]:
"""Yield an iterator for the class map values."""

@abc.abstractmethod
def process(self):
"""Run the processor and filter steps."""

@abc.abstractmethod
def find(self, qname: str, condition: Callable = return_true) -> Optional[Class]:
"""Find class that matches the given qualified name and condition callable.
Expand Down
9 changes: 7 additions & 2 deletions xsdata/codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def copy_attributes(cls, source: Class, target: Class, extension: Extension):
index += 1

@classmethod
def copy_group_attributes(cls, source: Class, target: Class, attr: Attr):
def copy_group_attributes(
cls, source: Class, target: Class, attr: Attr, skip_inner_classes: bool = False
):
"""Copy the attrs of the source class to the target class.
The attr represents a reference to the source class which is
Expand All @@ -125,6 +127,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr):
source: The source class instance
target: The target class instance
attr: The group attr instance
skip_inner_classes: Whether the attr is circular reference, which
means we can skip copying the inner classes.
"""
index = target.attrs.index(attr)
target.attrs.pop(index)
Expand All @@ -134,7 +138,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr):
target.attrs.insert(index, clone)
index += 1

cls.copy_inner_classes(source, target, clone)
if not skip_inner_classes:
cls.copy_inner_classes(source, target, clone)

@classmethod
def copy_extensions(cls, source: Class, target: Class, extension: Extension):
Expand Down

0 comments on commit 00fe560

Please sign in to comment.