Skip to content

Commit

Permalink
fix: Unnest classes should update inner classes recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Jun 15, 2024
1 parent 791a98e commit c022af6
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 186 deletions.
194 changes: 82 additions & 112 deletions tests/codegen/handlers/test_unnest_inner_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,124 +15,94 @@ def setUp(self):
self.container = ClassContainer(config=GeneratorConfig())
self.processor = UnnestInnerClasses(container=self.container)

def test_process(self):
def test_process_with_config_enabled(self):
self.container.config.output.unnest_classes = True
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.create()

a.attrs.append(AttrFactory.reference(b.qname, forward=True))
b.attrs.append(AttrFactory.reference(c.qname, forward=True))
c.attrs.append(AttrFactory.reference(b.qname, forward=True))

a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.container.add(a)
self.processor.process(a)
self.assertEqual(3, len(list(self.container)))

self.assertEqual(b.qname, a.attrs[0].types[0].qname)
self.assertEqual(c.qname, b.attrs[0].types[0].qname)
self.assertEqual(b.qname, c.attrs[0].types[0].qname)

def test_process_with_config_disabled_promotes_only_enumerations(self):
self.container.config.output.unnest_classes = False
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.enumeration(2)

enumeration = ClassFactory.enumeration(2)
local_type = ClassFactory.elements(2)
target = ClassFactory.create()
a.attrs.append(AttrFactory.reference(b.qname, forward=True))
b.attrs.append(AttrFactory.reference(c.qname, forward=True))

target.inner.append(enumeration)
target.inner.append(local_type)
self.container.add(target)
a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.processor.process(target)
self.container.add(a)
self.processor.process(a)

self.assertEqual(1, len(target.inner))
self.assertTrue(local_type in target.inner)
self.assertEqual(2, len(list(self.container)))
self.assertEqual(c.qname, b.attrs[0].types[0].qname)
self.assertEqual(1, len(a.inner))
self.assertEqual(0, len(b.inner))

def test_process_with_orphan_nested_class(self):
self.container.config.output.unnest_classes = True
self.processor.process(target)
self.assertEqual(0, len(target.inner))

def test_promote_with_orphan_inner(self):
inner = ClassFactory.elements(2)
target = ClassFactory.create()
target.inner.append(inner)
self.container.add(target)

self.processor.promote(target, inner)

self.assertEqual(0, len(target.inner))
self.assertEqual(1, len(self.container.data))

def test_promote_updates_forward_attr_types(self):
inner = ClassFactory.elements(2)
attr = AttrFactory.reference(inner.qname, forward=True)
target = ClassFactory.create()
target.attrs.append(attr)
target.inner.append(inner)
self.container.add(target)

self.processor.promote(target, inner)

self.assertEqual(0, len(target.inner))
self.assertEqual(2, len(self.container.data))
self.assertFalse(attr.types[0].forward)
self.assertEqual("{xsdata}class_C_class_B", attr.types[0].qname)

def test_clone_class(self):
target = ClassFactory.create(qname="{a}b")
actual = self.processor.clone_class(target, "parent")

self.assertIsNot(target, actual)
self.assertTrue(actual.local_type)
self.assertEqual("{a}parent_b", actual.qname)

def test_clone_class_with_circular_reference(self):
target = ClassFactory.create(qname="{a}b")
target.attrs.append(
AttrFactory.create(
name="self",
types=[AttrTypeFactory.create(qname=target.qname, circular=True)],
)
)

actual = self.processor.clone_class(target, "parent")

self.assertIsNot(target, actual)
self.assertTrue(actual.local_type)
self.assertEqual("{a}parent_b", actual.qname)

self.assertTrue(actual.attrs[0].types[0].circular)
self.assertEqual(actual.ref, actual.attrs[0].types[0].reference)
self.assertEqual(actual.qname, actual.attrs[0].types[0].qname)
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.create()

a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.container.add(a)
self.processor.process(a)
self.assertEqual(1, len(list(self.container)))
self.assertEqual(0, len(a.inner))
self.assertEqual(1, len(b.inner))

def test_update_inner_class(self):
a = ClassFactory.create(qname="a")
b = ClassFactory.create(qname="b")
c = ClassFactory.create(qname="c")
a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.processor.update_inner_class(c)

self.assertEqual("b_c", c.qname)
self.assertTrue(c.local_type)
self.assertIsNone(c.parent)
self.assertEqual(0, len(b.inner))

def test_update_types(self):
attr = AttrFactory.create(
types=[
AttrTypeFactory.create(qname="a", forward=True),
AttrTypeFactory.create(qname="a", forward=False),
AttrTypeFactory.create(qname="b", forward=False),
]
)
source = ClassFactory.create()

self.processor.update_types(attr, "a", source)

self.assertEqual(source.qname, attr.types[0].qname)
self.assertFalse(attr.types[0].forward)
self.assertEqual(source.ref, attr.types[0].reference)

self.assertEqual("a", attr.types[1].qname)
self.assertFalse(attr.types[1].forward)
self.assertEqual("b", attr.types[2].qname)
self.assertFalse(attr.types[2].forward)

def test_find_forward_attr(self):
target = ClassFactory.create(
attrs=[
AttrFactory.create(
types=[
AttrTypeFactory.create("a", forward=False),
AttrTypeFactory.create("b", forward=False),
]
),
AttrFactory.create(
types=[
AttrTypeFactory.create("a", forward=True),
AttrTypeFactory.create("b", forward=True),
]
),
AttrFactory.create(),
]
)

actual = self.processor.find_forward_attr(target, "a")
self.assertEqual(target.attrs[1], actual)

actual = self.processor.find_forward_attr(target, "b")
self.assertEqual(target.attrs[1], actual)

actual = self.processor.find_forward_attr(target, "c")
self.assertIsNone(actual)
types = [
AttrTypeFactory.create(qname="a", forward=True),
AttrTypeFactory.create(qname="a", forward=True),
]

self.processor.update_types(types, "b")

self.assertEqual("b", types[0].qname)
self.assertFalse(types[0].forward)

self.assertEqual("b", types[1].qname)
self.assertFalse(types[1].forward)
10 changes: 6 additions & 4 deletions tests/codegen/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ def test_process_class(self):
self.assertEqual(Status.FINALIZED, target.inner[1].status)

def test_process_classes(self):
target = ClassFactory.create(
attrs=[AttrFactory.reference("enumeration", forward=True)],
inner=[ClassFactory.enumeration(2, qname="enumeration")],
)
target = ClassFactory.create()
inner = ClassFactory.enumeration(2, qname="enumeration")

target.inner.append(inner)
inner.parent = target
target.attrs.append(AttrFactory.reference("enumeration", forward=True))

self.container.add(target)
self.container.process_classes(Steps.FLATTEN)
Expand Down
136 changes: 66 additions & 70 deletions xsdata/codegen/handlers/unnest_inner_classes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Optional
from collections import defaultdict
from typing import Iterator, List, Tuple

from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr, Class
from xsdata.codegen.models import AttrType, Class
from xsdata.codegen.utils import ClassUtils
from xsdata.utils.namespaces import build_qname


Expand All @@ -11,97 +13,91 @@ class UnnestInnerClasses(RelativeHandlerInterface):
__slots__ = ()

def process(self, target: Class):
"""Process entrypoint for classes.
Process the target class inner classes recursively.
All enumerations are promoted by default, otherwise
only if the configuration is disabled the classes
are ignored.
"""Promote all inner classes recursively.
Args:
target: The target class instance to inspect
target: The target class instance to process
"""
for inner in target.inner.copy():
if inner.is_enumeration or self.container.config.output.unnest_classes:
self.promote(target, inner)
inner_classes = {}
inner_references = defaultdict(list)
promote_all = self.container.config.output.unnest_classes
for attr_type, source in self.find_forward_refs(target):
inner = ClassUtils.find_nested(source, attr_type.qname)

if not (promote_all or inner.is_enumeration):
continue

inner_classes[inner.ref] = inner
inner_references[inner.ref].append(attr_type)

for ref, inner in inner_classes.items():
references = inner_references[ref]

def promote(self, target: Class, inner: Class):
"""Promote the inner class to root classes.
self.update_inner_class(inner)
self.update_types(references, inner.qname)
self.container.add(inner)

Steps:
- Replace forward references to the inner class
- Remove inner class from target class
- Copy the class to the global class container.
self.remove_orphan_inner_classes(target, promote_all)

@classmethod
def remove_orphan_inner_classes(cls, target: Class, promote_all: bool):
"""Remove inner classes with no attr references.
Args:
target: The target class
inner: An inner class
target: The target class instance to process
promote_all: Whether to remove all inner classes or just the enumerations
"""
target.inner.remove(inner)
attr = self.find_forward_attr(target, inner.qname)
if attr:
clone = self.clone_class(inner, target.name)
self.update_types(attr, inner.qname, clone)
self.container.add(clone)
for inner in target.inner.copy():
if promote_all or inner.is_enumeration:
target.inner.remove(inner)

@classmethod
def clone_class(cls, inner: Class, name: str) -> Class:
"""Clone and prepare inner class for promotion.
Clone the inner class, mark it as promoted and pref
the qualified name with the parent class name.
def find_forward_refs(cls, target: Class) -> Iterator[Tuple[AttrType, Class]]:
"""Find all forward references for all inner classes.
Args:
inner: The inner class to clone and prepare
name: The parent class name to use a prefix
target: The target class instance to process
Returns:
The new class instance
Yields:
A tuple of attr type and the parent class instance.
"""
clone = inner.clone()
clone.parent = None
clone.local_type = True
clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}")

for attr in clone.attrs:
for attr in target.attrs:
for tp in attr.types:
if tp.circular and tp.qname == inner.qname:
tp.qname = clone.qname
tp.reference = clone.ref
if tp.forward and not tp.native:
yield tp, target

return clone
for inner in target.inner:
yield from cls.find_forward_refs(inner)

@classmethod
def update_types(cls, attr: Attr, search: str, source: Class):
"""Update the references from an inner to a global class.
def update_inner_class(cls, target: Class):
"""Prepare the nested class to be added as root.
Args:
attr: The target attr to inspect and update
search: The current inner class qname
source: The new global class qname
target: The target class
"""
for attr_type in attr.types:
if attr_type.qname == search and attr_type.forward:
attr_type.qname = source.qname
attr_type.reference = source.ref
attr_type.forward = False
assert target.parent is not None
name_parts = [target.parent.name, target.name]
new_qname = build_qname(target.target_namespace, "_".join(name_parts))

target.qname = new_qname

assert target.parent is not None

target.parent.inner.remove(target)
target.parent = None
target.local_type = True

@classmethod
def find_forward_attr(cls, target: Class, qname: str) -> Optional[Attr]:
"""Find the first attr that references the given inner class qname.
def update_types(cls, types: List[AttrType], qname: str):
"""Search and replace forward references.
Args:
target: The target class instance
qname: An inner class qualified name
Return the number changes.
Returns:
Attr: The first attr that references the given qname
None: If no such attr exists, it can happen!
Args:
types: The types to search and replace
qname: The updated qname
"""
for attr in target.attrs:
for attr_type in attr.types:
if attr_type.forward and attr_type.qname == qname:
return attr

return None
for tp in types:
tp.qname = qname
tp.forward = False

0 comments on commit c022af6

Please sign in to comment.