diff --git a/tests/codegen/handlers/test_unnest_inner_classes.py b/tests/codegen/handlers/test_unnest_inner_classes.py index 8b7b953cc..0e216e2b3 100644 --- a/tests/codegen/handlers/test_unnest_inner_classes.py +++ b/tests/codegen/handlers/test_unnest_inner_classes.py @@ -15,124 +15,99 @@ def setUp(self): self.container = ClassContainer(config=GeneratorConfig()) self.processor = UnnestInnerClasses(container=self.container) - def test_process(self): + def test_process_with_config_enabled(self): + self.container.config.output.unnest_classes = True + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.create() + + a.attrs.append(AttrFactory.reference(b.qname, forward=True)) + b.attrs.append(AttrFactory.reference(c.qname, forward=True)) + c.attrs.append(AttrFactory.reference(b.qname, forward=True)) + + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.container.add(a) + self.processor.process(a) + self.assertEqual(3, len(list(self.container))) + + self.assertEqual(b.qname, a.attrs[0].types[0].qname) + self.assertEqual(c.qname, b.attrs[0].types[0].qname) + self.assertEqual(b.qname, c.attrs[0].types[0].qname) + + def test_process_with_config_disabled_promotes_only_enumerations(self): self.container.config.output.unnest_classes = False + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.enumeration(2) - enumeration = ClassFactory.enumeration(2) - local_type = ClassFactory.elements(2) - target = ClassFactory.create() + a.attrs.append(AttrFactory.reference(b.qname, forward=True)) + b.attrs.append(AttrFactory.reference(c.qname, forward=True)) - target.inner.append(enumeration) - target.inner.append(local_type) - self.container.add(target) + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b - self.processor.process(target) + self.container.add(a) + self.processor.process(a) - self.assertEqual(1, len(target.inner)) - self.assertTrue(local_type in target.inner) + self.assertEqual(2, len(list(self.container))) + self.assertEqual(c.qname, b.attrs[0].types[0].qname) + self.assertEqual(1, len(a.inner)) + self.assertEqual(0, len(b.inner)) + def test_process_with_orphan_nested_class(self): self.container.config.output.unnest_classes = True - self.processor.process(target) - self.assertEqual(0, len(target.inner)) - - def test_promote_with_orphan_inner(self): - inner = ClassFactory.elements(2) - target = ClassFactory.create() - target.inner.append(inner) - self.container.add(target) - - self.processor.promote(target, inner) - - self.assertEqual(0, len(target.inner)) - self.assertEqual(1, len(self.container.data)) - - def test_promote_updates_forward_attr_types(self): - inner = ClassFactory.elements(2) - attr = AttrFactory.reference(inner.qname, forward=True) - target = ClassFactory.create() - target.attrs.append(attr) - target.inner.append(inner) - self.container.add(target) - - self.processor.promote(target, inner) - - self.assertEqual(0, len(target.inner)) - self.assertEqual(2, len(self.container.data)) - self.assertFalse(attr.types[0].forward) - self.assertEqual("{xsdata}class_C_class_B", attr.types[0].qname) - - def test_clone_class(self): - target = ClassFactory.create(qname="{a}b") - actual = self.processor.clone_class(target, "parent") - - self.assertIsNot(target, actual) - self.assertTrue(actual.local_type) - self.assertEqual("{a}parent_b", actual.qname) - - def test_clone_class_with_circular_reference(self): - target = ClassFactory.create(qname="{a}b") - target.attrs.append( - AttrFactory.create( - name="self", - types=[AttrTypeFactory.create(qname=target.qname, circular=True)], - ) - ) - - actual = self.processor.clone_class(target, "parent") - - self.assertIsNot(target, actual) - self.assertTrue(actual.local_type) - self.assertEqual("{a}parent_b", actual.qname) - - self.assertTrue(actual.attrs[0].types[0].circular) - self.assertEqual(actual.ref, actual.attrs[0].types[0].reference) - self.assertEqual(actual.qname, actual.attrs[0].types[0].qname) + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.create() + + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.container.add(a) + self.processor.process(a) + self.assertEqual(1, len(list(self.container))) + self.assertEqual(0, len(a.inner)) + self.assertEqual(0, len(b.inner)) + + def test_update_nested_class(self): + a = ClassFactory.create(qname="a") + b = ClassFactory.create(qname="b") + c = ClassFactory.create(qname="c") + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.processor.update_nested_class(c) + + self.assertEqual("b_c", c.qname) + self.assertTrue(c.local_type) + self.assertIsNone(c.parent) + self.assertEqual(0, len(b.inner)) def test_update_types(self): - attr = AttrFactory.create( - types=[ - AttrTypeFactory.create(qname="a", forward=True), - AttrTypeFactory.create(qname="a", forward=False), - AttrTypeFactory.create(qname="b", forward=False), - ] - ) - source = ClassFactory.create() - - self.processor.update_types(attr, "a", source) - - self.assertEqual(source.qname, attr.types[0].qname) - self.assertFalse(attr.types[0].forward) - self.assertEqual(source.ref, attr.types[0].reference) - - self.assertEqual("a", attr.types[1].qname) - self.assertFalse(attr.types[1].forward) - self.assertEqual("b", attr.types[2].qname) - self.assertFalse(attr.types[2].forward) - - def test_find_forward_attr(self): - target = ClassFactory.create( - attrs=[ - AttrFactory.create( - types=[ - AttrTypeFactory.create("a", forward=False), - AttrTypeFactory.create("b", forward=False), - ] - ), - AttrFactory.create( - types=[ - AttrTypeFactory.create("a", forward=True), - AttrTypeFactory.create("b", forward=True), - ] - ), - AttrFactory.create(), - ] - ) - - actual = self.processor.find_forward_attr(target, "a") - self.assertEqual(target.attrs[1], actual) - - actual = self.processor.find_forward_attr(target, "b") - self.assertEqual(target.attrs[1], actual) - - actual = self.processor.find_forward_attr(target, "c") - self.assertIsNone(actual) + types = [ + AttrTypeFactory.create(qname="a", forward=True), + AttrTypeFactory.create(qname="a", forward=False), + AttrTypeFactory.create(qname="b", forward=True), + ] + + result = self.processor.update_types(types, "a", "b") + self.assertEqual(1, result) + + self.assertEqual("b", types[0].qname) + self.assertFalse(types[0].forward) + + self.assertEqual("a", types[1].qname) + self.assertFalse(types[1].forward) + + self.assertEqual("b", types[2].qname) + self.assertTrue(types[2].forward) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index cdf9a5880..5811e7a43 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -125,10 +125,12 @@ def test_process_class(self): self.assertEqual(Status.FINALIZED, target.inner[1].status) def test_process_classes(self): - target = ClassFactory.create( - attrs=[AttrFactory.reference("enumeration", forward=True)], - inner=[ClassFactory.enumeration(2, qname="enumeration")], - ) + target = ClassFactory.create() + inner = ClassFactory.enumeration(2, qname="enumeration") + + target.inner.append(inner) + inner.parent = target + target.attrs.append(AttrFactory.reference("enumeration", forward=True)) self.container.add(target) self.container.process_classes(Steps.FLATTEN) diff --git a/xsdata/codegen/handlers/unnest_inner_classes.py b/xsdata/codegen/handlers/unnest_inner_classes.py index 564fe71ed..c1edcb081 100644 --- a/xsdata/codegen/handlers/unnest_inner_classes.py +++ b/xsdata/codegen/handlers/unnest_inner_classes.py @@ -1,7 +1,7 @@ -from typing import Optional +from typing import Iterator, List from xsdata.codegen.mixins import RelativeHandlerInterface -from xsdata.codegen.models import Attr, Class +from xsdata.codegen.models import AttrType, Class from xsdata.utils.namespaces import build_qname @@ -22,86 +22,62 @@ def process(self, target: Class): Args: target: The target class instance to inspect """ - for inner in target.inner.copy(): - if inner.is_enumeration or self.container.config.output.unnest_classes: - self.promote(target, inner) + all_types = list(target.types()) + promote_all = self.container.config.output.unnest_classes + for nested in list(self.find_nested_classes(target)): + if not (promote_all or nested.is_enumeration): + continue - def promote(self, target: Class, inner: Class): - """Promote the inner class to root classes. - - Steps: - - Replace forward references to the inner class - - Remove inner class from target class - - Copy the class to the global class container. - - Args: - target: The target class - inner: An inner class - """ - target.inner.remove(inner) - attr = self.find_forward_attr(target, inner.qname) - if attr: - clone = self.clone_class(inner, target.name) - self.update_types(attr, inner.qname, clone) - self.container.add(clone) + cur_qname = nested.qname + self.update_nested_class(nested) + if self.update_types(all_types, cur_qname, nested.qname): + self.container.add(nested) @classmethod - def clone_class(cls, inner: Class, name: str) -> Class: - """Clone and prepare inner class for promotion. + def find_nested_classes(cls, target: Class) -> Iterator[Class]: + """Return all inner classes recursively.""" + for inner in target.inner: + yield inner + yield from cls.find_nested_classes(inner) - Clone the inner class, mark it as promoted and pref - the qualified name with the parent class name. + @classmethod + def update_nested_class(cls, target: Class): + """Prepare the nested class to be added as root. Args: - inner: The inner class to clone and prepare - name: The parent class name to use a prefix - - Returns: - The new class instance + target: The target class """ - clone = inner.clone() - clone.parent = None - clone.local_type = True - clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}") + assert target.parent is not None + name_parts = [target.parent.name, target.name] + new_qname = build_qname(target.target_namespace, "_".join(name_parts)) - for attr in clone.attrs: - for tp in attr.types: - if tp.circular and tp.qname == inner.qname: - tp.qname = clone.qname - tp.reference = clone.ref + target.qname = new_qname - return clone + assert target.parent is not None - @classmethod - def update_types(cls, attr: Attr, search: str, source: Class): - """Update the references from an inner to a global class. - - Args: - attr: The target attr to inspect and update - search: The current inner class qname - source: The new global class qname - """ - for attr_type in attr.types: - if attr_type.qname == search and attr_type.forward: - attr_type.qname = source.qname - attr_type.reference = source.ref - attr_type.forward = False + target.parent.inner.remove(target) + target.parent = None + target.local_type = True @classmethod - def find_forward_attr(cls, target: Class, qname: str) -> Optional[Attr]: - """Find the first attr that references the given inner class qname. + def update_types(cls, types: List[AttrType], search: str, replace: str) -> int: + """Search and replace forward references. + + Return the number changes. Args: - target: The target class instance - qname: An inner class qualified name + types: The types to search and replace + search: The search qname + replace: The replacement qname Returns: - Attr: The first attr that references the given qname - None: If no such attr exists, it can happen! + The number of changed attr types """ - for attr in target.attrs: - for attr_type in attr.types: - if attr_type.forward and attr_type.qname == qname: - return attr - - return None + updated = 0 + for tp in types: + if tp.forward and tp.qname == search: + tp.qname = replace + tp.forward = False + updated += 1 + + return updated