diff --git a/tests/codegen/handlers/test_unnest_inner_classes.py b/tests/codegen/handlers/test_unnest_inner_classes.py index 8b7b953cc..3dd90e10e 100644 --- a/tests/codegen/handlers/test_unnest_inner_classes.py +++ b/tests/codegen/handlers/test_unnest_inner_classes.py @@ -15,124 +15,94 @@ def setUp(self): self.container = ClassContainer(config=GeneratorConfig()) self.processor = UnnestInnerClasses(container=self.container) - def test_process(self): + def test_process_with_config_enabled(self): + self.container.config.output.unnest_classes = True + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.create() + + a.attrs.append(AttrFactory.reference(b.qname, forward=True)) + b.attrs.append(AttrFactory.reference(c.qname, forward=True)) + c.attrs.append(AttrFactory.reference(b.qname, forward=True)) + + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.container.add(a) + self.processor.process(a) + self.assertEqual(3, len(list(self.container))) + + self.assertEqual(b.qname, a.attrs[0].types[0].qname) + self.assertEqual(c.qname, b.attrs[0].types[0].qname) + self.assertEqual(b.qname, c.attrs[0].types[0].qname) + + def test_process_with_config_disabled_promotes_only_enumerations(self): self.container.config.output.unnest_classes = False + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.enumeration(2) - enumeration = ClassFactory.enumeration(2) - local_type = ClassFactory.elements(2) - target = ClassFactory.create() + a.attrs.append(AttrFactory.reference(b.qname, forward=True)) + b.attrs.append(AttrFactory.reference(c.qname, forward=True)) - target.inner.append(enumeration) - target.inner.append(local_type) - self.container.add(target) + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b - self.processor.process(target) + self.container.add(a) + self.processor.process(a) - self.assertEqual(1, len(target.inner)) - self.assertTrue(local_type in target.inner) + self.assertEqual(2, len(list(self.container))) + self.assertEqual(c.qname, b.attrs[0].types[0].qname) + self.assertEqual(1, len(a.inner)) + self.assertEqual(0, len(b.inner)) + def test_process_with_orphan_nested_class(self): self.container.config.output.unnest_classes = True - self.processor.process(target) - self.assertEqual(0, len(target.inner)) - - def test_promote_with_orphan_inner(self): - inner = ClassFactory.elements(2) - target = ClassFactory.create() - target.inner.append(inner) - self.container.add(target) - - self.processor.promote(target, inner) - - self.assertEqual(0, len(target.inner)) - self.assertEqual(1, len(self.container.data)) - - def test_promote_updates_forward_attr_types(self): - inner = ClassFactory.elements(2) - attr = AttrFactory.reference(inner.qname, forward=True) - target = ClassFactory.create() - target.attrs.append(attr) - target.inner.append(inner) - self.container.add(target) - - self.processor.promote(target, inner) - - self.assertEqual(0, len(target.inner)) - self.assertEqual(2, len(self.container.data)) - self.assertFalse(attr.types[0].forward) - self.assertEqual("{xsdata}class_C_class_B", attr.types[0].qname) - - def test_clone_class(self): - target = ClassFactory.create(qname="{a}b") - actual = self.processor.clone_class(target, "parent") - - self.assertIsNot(target, actual) - self.assertTrue(actual.local_type) - self.assertEqual("{a}parent_b", actual.qname) - - def test_clone_class_with_circular_reference(self): - target = ClassFactory.create(qname="{a}b") - target.attrs.append( - AttrFactory.create( - name="self", - types=[AttrTypeFactory.create(qname=target.qname, circular=True)], - ) - ) - - actual = self.processor.clone_class(target, "parent") - - self.assertIsNot(target, actual) - self.assertTrue(actual.local_type) - self.assertEqual("{a}parent_b", actual.qname) - - self.assertTrue(actual.attrs[0].types[0].circular) - self.assertEqual(actual.ref, actual.attrs[0].types[0].reference) - self.assertEqual(actual.qname, actual.attrs[0].types[0].qname) + a = ClassFactory.create() + b = ClassFactory.create() + c = ClassFactory.create() + + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.container.add(a) + self.processor.process(a) + self.assertEqual(1, len(list(self.container))) + self.assertEqual(0, len(a.inner)) + self.assertEqual(1, len(b.inner)) + + def test_update_inner_class(self): + a = ClassFactory.create(qname="a") + b = ClassFactory.create(qname="b") + c = ClassFactory.create(qname="c") + a.inner.append(b) + b.inner.append(c) + b.parent = a + c.parent = b + + self.processor.update_inner_class(c) + + self.assertEqual("b_c", c.qname) + self.assertTrue(c.local_type) + self.assertIsNone(c.parent) + self.assertEqual(0, len(b.inner)) def test_update_types(self): - attr = AttrFactory.create( - types=[ - AttrTypeFactory.create(qname="a", forward=True), - AttrTypeFactory.create(qname="a", forward=False), - AttrTypeFactory.create(qname="b", forward=False), - ] - ) - source = ClassFactory.create() - - self.processor.update_types(attr, "a", source) - - self.assertEqual(source.qname, attr.types[0].qname) - self.assertFalse(attr.types[0].forward) - self.assertEqual(source.ref, attr.types[0].reference) - - self.assertEqual("a", attr.types[1].qname) - self.assertFalse(attr.types[1].forward) - self.assertEqual("b", attr.types[2].qname) - self.assertFalse(attr.types[2].forward) - - def test_find_forward_attr(self): - target = ClassFactory.create( - attrs=[ - AttrFactory.create( - types=[ - AttrTypeFactory.create("a", forward=False), - AttrTypeFactory.create("b", forward=False), - ] - ), - AttrFactory.create( - types=[ - AttrTypeFactory.create("a", forward=True), - AttrTypeFactory.create("b", forward=True), - ] - ), - AttrFactory.create(), - ] - ) - - actual = self.processor.find_forward_attr(target, "a") - self.assertEqual(target.attrs[1], actual) - - actual = self.processor.find_forward_attr(target, "b") - self.assertEqual(target.attrs[1], actual) - - actual = self.processor.find_forward_attr(target, "c") - self.assertIsNone(actual) + types = [ + AttrTypeFactory.create(qname="a", forward=True), + AttrTypeFactory.create(qname="a", forward=True), + ] + + self.processor.update_types(types, "b") + + self.assertEqual("b", types[0].qname) + self.assertFalse(types[0].forward) + + self.assertEqual("b", types[1].qname) + self.assertFalse(types[1].forward) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index cdf9a5880..5811e7a43 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -125,10 +125,12 @@ def test_process_class(self): self.assertEqual(Status.FINALIZED, target.inner[1].status) def test_process_classes(self): - target = ClassFactory.create( - attrs=[AttrFactory.reference("enumeration", forward=True)], - inner=[ClassFactory.enumeration(2, qname="enumeration")], - ) + target = ClassFactory.create() + inner = ClassFactory.enumeration(2, qname="enumeration") + + target.inner.append(inner) + inner.parent = target + target.attrs.append(AttrFactory.reference("enumeration", forward=True)) self.container.add(target) self.container.process_classes(Steps.FLATTEN) diff --git a/xsdata/codegen/handlers/unnest_inner_classes.py b/xsdata/codegen/handlers/unnest_inner_classes.py index 564fe71ed..9f42cfcb5 100644 --- a/xsdata/codegen/handlers/unnest_inner_classes.py +++ b/xsdata/codegen/handlers/unnest_inner_classes.py @@ -1,7 +1,9 @@ -from typing import Optional +from collections import defaultdict +from typing import Iterator, List, Tuple from xsdata.codegen.mixins import RelativeHandlerInterface -from xsdata.codegen.models import Attr, Class +from xsdata.codegen.models import AttrType, Class +from xsdata.codegen.utils import ClassUtils from xsdata.utils.namespaces import build_qname @@ -11,97 +13,91 @@ class UnnestInnerClasses(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): - """Process entrypoint for classes. - - Process the target class inner classes recursively. - - All enumerations are promoted by default, otherwise - only if the configuration is disabled the classes - are ignored. + """Promote all inner classes recursively. Args: - target: The target class instance to inspect + target: The target class instance to process """ - for inner in target.inner.copy(): - if inner.is_enumeration or self.container.config.output.unnest_classes: - self.promote(target, inner) + inner_classes = {} + inner_references = defaultdict(list) + promote_all = self.container.config.output.unnest_classes + for attr_type, source in self.find_forward_refs(target): + inner = ClassUtils.find_nested(source, attr_type.qname) + + if not (promote_all or inner.is_enumeration): + continue + + inner_classes[inner.ref] = inner + inner_references[inner.ref].append(attr_type) + + for ref, inner in inner_classes.items(): + references = inner_references[ref] - def promote(self, target: Class, inner: Class): - """Promote the inner class to root classes. + self.update_inner_class(inner) + self.update_types(references, inner.qname) + self.container.add(inner) - Steps: - - Replace forward references to the inner class - - Remove inner class from target class - - Copy the class to the global class container. + self.remove_orphan_inner_classes(target, promote_all) + + @classmethod + def remove_orphan_inner_classes(cls, target: Class, promote_all: bool): + """Remove inner classes with no attr references. Args: - target: The target class - inner: An inner class + target: The target class instance to process + promote_all: Whether to remove all inner classes or just the enumerations """ - target.inner.remove(inner) - attr = self.find_forward_attr(target, inner.qname) - if attr: - clone = self.clone_class(inner, target.name) - self.update_types(attr, inner.qname, clone) - self.container.add(clone) + for inner in target.inner.copy(): + if promote_all or inner.is_enumeration: + target.inner.remove(inner) @classmethod - def clone_class(cls, inner: Class, name: str) -> Class: - """Clone and prepare inner class for promotion. - - Clone the inner class, mark it as promoted and pref - the qualified name with the parent class name. + def find_forward_refs(cls, target: Class) -> Iterator[Tuple[AttrType, Class]]: + """Find all forward references for all inner classes. Args: - inner: The inner class to clone and prepare - name: The parent class name to use a prefix + target: The target class instance to process - Returns: - The new class instance + Yields: + A tuple of attr type and the parent class instance. """ - clone = inner.clone() - clone.parent = None - clone.local_type = True - clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}") - - for attr in clone.attrs: + for attr in target.attrs: for tp in attr.types: - if tp.circular and tp.qname == inner.qname: - tp.qname = clone.qname - tp.reference = clone.ref + if tp.forward and not tp.native: + yield tp, target - return clone + for inner in target.inner: + yield from cls.find_forward_refs(inner) @classmethod - def update_types(cls, attr: Attr, search: str, source: Class): - """Update the references from an inner to a global class. + def update_inner_class(cls, target: Class): + """Prepare the nested class to be added as root. Args: - attr: The target attr to inspect and update - search: The current inner class qname - source: The new global class qname + target: The target class """ - for attr_type in attr.types: - if attr_type.qname == search and attr_type.forward: - attr_type.qname = source.qname - attr_type.reference = source.ref - attr_type.forward = False + assert target.parent is not None + name_parts = [target.parent.name, target.name] + new_qname = build_qname(target.target_namespace, "_".join(name_parts)) + + target.qname = new_qname + + assert target.parent is not None + + target.parent.inner.remove(target) + target.parent = None + target.local_type = True @classmethod - def find_forward_attr(cls, target: Class, qname: str) -> Optional[Attr]: - """Find the first attr that references the given inner class qname. + def update_types(cls, types: List[AttrType], qname: str): + """Search and replace forward references. - Args: - target: The target class instance - qname: An inner class qualified name + Return the number changes. - Returns: - Attr: The first attr that references the given qname - None: If no such attr exists, it can happen! + Args: + types: The types to search and replace + qname: The updated qname """ - for attr in target.attrs: - for attr_type in attr.types: - if attr_type.forward and attr_type.qname == qname: - return attr - - return None + for tp in types: + tp.qname = qname + tp.forward = False