Skip to content

Commit

Permalink
Merge pull request #1016 from tefra/improve-nested-class-dependencies
Browse files Browse the repository at this point in the history
fix: Avoid recursive error on nested group references
  • Loading branch information
tefra authored Apr 21, 2024
2 parents 762ecb8 + 00fe560 commit 7628204
Show file tree
Hide file tree
Showing 20 changed files with 338 additions and 236 deletions.
84 changes: 27 additions & 57 deletions tests/codegen/handlers/test_flatten_attribute_groups.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from unittest import mock

from xsdata.codegen.container import ClassContainer
from xsdata.codegen.exceptions import CodegenError
from xsdata.codegen.handlers import FlattenAttributeGroups
from xsdata.codegen.models import Attr, Status
from xsdata.codegen.utils import ClassUtils
from xsdata.codegen.models import Status
from xsdata.models.config import GeneratorConfig
from xsdata.models.enums import Tag
from xsdata.utils.testing import AttrFactory, ClassFactory, FactoryTestCase
Expand All @@ -17,65 +14,38 @@ def setUp(self):
container = ClassContainer(config=GeneratorConfig())
self.processor = FlattenAttributeGroups(container=container)

@mock.patch.object(Attr, "is_group", new_callable=mock.PropertyMock)
@mock.patch.object(FlattenAttributeGroups, "process_attribute")
def test_process(self, mock_process_attribute, mock_is_group):
mock_is_group.side_effect = [
True,
False,
True,
True,
False,
False,
def test_process(self):
group = ClassFactory.create(qname="group", tag=Tag.GROUP)
group.attrs = [
AttrFactory.reference(name="one", qname="inner_one", forward=True),
AttrFactory.reference(name="two", qname="inner_two", forward=True),
]
target = ClassFactory.elements(2)

self.processor.process(target)
self.assertEqual(6, mock_is_group.call_count)

mock_process_attribute.assert_has_calls(
[
mock.call(target, target.attrs[0]),
mock.call(target, target.attrs[0]),
mock.call(target, target.attrs[1]),
]
inner_one = ClassFactory.create(
qname="inner_one",
attrs=[
AttrFactory.reference(qname="group", tag=Tag.GROUP),
],
)

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE)
group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)

self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
inner_two = inner_one.clone(qname="inner_two")
inner_one.parent = group
inner_two.parent = group
group.inner.extend([inner_one, inner_two])
target = ClassFactory.create(
attrs=[
AttrFactory.reference(qname="group", tag=Tag.GROUP),
]
)
self.processor.container.extend([group, target])
self.processor.container.process()

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_attribute_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE)
group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)
self.assertEqual(["one", "two"], [x.name for x in target.attrs])
self.assertEqual(["inner_one", "inner_two"], [x.name for x in target.inner])

self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
)
for inner in target.inner:
self.assertEqual(["one", "two"], [x.name for x in inner.attrs])
self.assertEqual(0, len(inner.inner))

def test_process_attribute_with_circular_reference(self):
def test_process_attribute_with_self_reference(self):
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP)
target.attrs.append(group_attr)
Expand Down
34 changes: 34 additions & 0 deletions tests/codegen/handlers/test_validate_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,37 @@ def test_validate_misrepresented_references(self):

with self.assertRaises(CodegenError):
self.handler.run()

def test_validate_parent_references_with_root_class_with_parent(self):
target = ClassFactory.create()
target.parent = ClassFactory.create()
self.container.add(target)

with self.assertRaises(CodegenError):
self.handler.run()

def test_validate_parent_references_with_wrong_parent(self):
parent = ClassFactory.create()
child = ClassFactory.create()
wrong = ClassFactory.create()

parent.inner.append(child)
child.parent = wrong

self.container.extend([parent, wrong])

with self.assertRaises(CodegenError):
self.handler.run()

def test_validate_parent_references_with_wrong_parent_ref(self):
parent = ClassFactory.create()
child = ClassFactory.create()
wrong = parent.clone()

parent.inner.append(child)
child.parent = wrong

self.container.extend([parent])

with self.assertRaises(CodegenError):
self.handler.run()
3 changes: 3 additions & 0 deletions tests/codegen/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def test_process_class(self):
target = ClassFactory.create(
inner=[ClassFactory.elements(2), ClassFactory.elements(1)]
)
for inner in target.inner:
inner.parent = target

self.container.add(target)

self.container.process()
Expand Down
39 changes: 25 additions & 14 deletions tests/codegen/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,6 @@ def test_copy_inner_class_with_missing_inner(self):
with self.assertRaises(CodegenError):
ClassUtils.copy_inner_class(source, target, attr_type)

def test_find_inner(self):
obj = ClassFactory.create(qname="{a}parent")
first = ClassFactory.create(qname="{a}a")
second = ClassFactory.create(qname="{c}c")
third = ClassFactory.enumeration(2, qname="{d}d")
obj.inner.extend((first, second, third))

with self.assertRaises(CodegenError):
self.assertIsNone(ClassUtils.find_inner(obj, "nope"))

self.assertEqual(first, ClassUtils.find_inner(obj, "{a}a"))
self.assertEqual(second, ClassUtils.find_inner(obj, "{c}c"))
self.assertEqual(third, ClassUtils.find_inner(obj, "{d}d"))

def test_flatten(self):
target = ClassFactory.create(
qname="{xsdata}root", attrs=AttrFactory.list(3), inner=ClassFactory.list(2)
Expand Down Expand Up @@ -400,3 +386,28 @@ def test_filter_types(self):
types = [xs_any]
actual = ClassUtils.filter_types(types)
self.assertEqual(1, len(actual))

def test_find_nested(self):
a = ClassFactory.create(qname="a")
b = ClassFactory.create(qname="b")
c = ClassFactory.create(qname="c")

a.inner.append(b)
b.inner.append(c)
c.parent = b
b.parent = a

self.assertEqual(a, ClassUtils.find_nested(a, "a"))
self.assertEqual(b, ClassUtils.find_nested(a, "b"))
self.assertEqual(b, ClassUtils.find_nested(c, "b"))
self.assertEqual(a, ClassUtils.find_nested(c, "a"))

a2 = ClassFactory.create(qname="a")
c.inner.append(a2)
a2.parent = c

# Breadth-first search
self.assertEqual(a2, ClassUtils.find_nested(c, "a"))

with self.assertRaises(CodegenError):
ClassUtils.find_nested(a, "nope")
Loading

0 comments on commit 7628204

Please sign in to comment.