Skip to content

Commit

Permalink
fix: Unnest classes should update inner classes recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed May 27, 2024
1 parent cbec5fa commit e434101
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 184 deletions.
199 changes: 87 additions & 112 deletions tests/codegen/handlers/test_unnest_inner_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,124 +15,99 @@ def setUp(self):
self.container = ClassContainer(config=GeneratorConfig())
self.processor = UnnestInnerClasses(container=self.container)

def test_process(self):
def test_process_with_config_enabled(self):
self.container.config.output.unnest_classes = True
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.create()

a.attrs.append(AttrFactory.reference(b.qname, forward=True))
b.attrs.append(AttrFactory.reference(c.qname, forward=True))
c.attrs.append(AttrFactory.reference(b.qname, forward=True))

a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.container.add(a)
self.processor.process(a)
self.assertEqual(3, len(list(self.container)))

self.assertEqual(b.qname, a.attrs[0].types[0].qname)
self.assertEqual(c.qname, b.attrs[0].types[0].qname)
self.assertEqual(b.qname, c.attrs[0].types[0].qname)

def test_process_with_config_disabled_promotes_only_enumerations(self):
self.container.config.output.unnest_classes = False
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.enumeration(2)

enumeration = ClassFactory.enumeration(2)
local_type = ClassFactory.elements(2)
target = ClassFactory.create()
a.attrs.append(AttrFactory.reference(b.qname, forward=True))
b.attrs.append(AttrFactory.reference(c.qname, forward=True))

target.inner.append(enumeration)
target.inner.append(local_type)
self.container.add(target)
a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.processor.process(target)
self.container.add(a)
self.processor.process(a)

self.assertEqual(1, len(target.inner))
self.assertTrue(local_type in target.inner)
self.assertEqual(2, len(list(self.container)))
self.assertEqual(c.qname, b.attrs[0].types[0].qname)
self.assertEqual(1, len(a.inner))
self.assertEqual(0, len(b.inner))

def test_process_with_orphan_nested_class(self):
self.container.config.output.unnest_classes = True
self.processor.process(target)
self.assertEqual(0, len(target.inner))

def test_promote_with_orphan_inner(self):
inner = ClassFactory.elements(2)
target = ClassFactory.create()
target.inner.append(inner)
self.container.add(target)

self.processor.promote(target, inner)

self.assertEqual(0, len(target.inner))
self.assertEqual(1, len(self.container.data))

def test_promote_updates_forward_attr_types(self):
inner = ClassFactory.elements(2)
attr = AttrFactory.reference(inner.qname, forward=True)
target = ClassFactory.create()
target.attrs.append(attr)
target.inner.append(inner)
self.container.add(target)

self.processor.promote(target, inner)

self.assertEqual(0, len(target.inner))
self.assertEqual(2, len(self.container.data))
self.assertFalse(attr.types[0].forward)
self.assertEqual("{xsdata}class_C_class_B", attr.types[0].qname)

def test_clone_class(self):
target = ClassFactory.create(qname="{a}b")
actual = self.processor.clone_class(target, "parent")

self.assertIsNot(target, actual)
self.assertTrue(actual.local_type)
self.assertEqual("{a}parent_b", actual.qname)

def test_clone_class_with_circular_reference(self):
target = ClassFactory.create(qname="{a}b")
target.attrs.append(
AttrFactory.create(
name="self",
types=[AttrTypeFactory.create(qname=target.qname, circular=True)],
)
)

actual = self.processor.clone_class(target, "parent")

self.assertIsNot(target, actual)
self.assertTrue(actual.local_type)
self.assertEqual("{a}parent_b", actual.qname)

self.assertTrue(actual.attrs[0].types[0].circular)
self.assertEqual(actual.ref, actual.attrs[0].types[0].reference)
self.assertEqual(actual.qname, actual.attrs[0].types[0].qname)
a = ClassFactory.create()
b = ClassFactory.create()
c = ClassFactory.create()

a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.container.add(a)
self.processor.process(a)
self.assertEqual(1, len(list(self.container)))
self.assertEqual(0, len(a.inner))
self.assertEqual(0, len(b.inner))

def test_update_nested_class(self):
a = ClassFactory.create(qname="a")
b = ClassFactory.create(qname="b")
c = ClassFactory.create(qname="c")
a.inner.append(b)
b.inner.append(c)
b.parent = a
c.parent = b

self.processor.update_nested_class(c)

self.assertEqual("b_c", c.qname)
self.assertTrue(c.local_type)
self.assertIsNone(c.parent)
self.assertEqual(0, len(b.inner))

def test_update_types(self):
attr = AttrFactory.create(
types=[
AttrTypeFactory.create(qname="a", forward=True),
AttrTypeFactory.create(qname="a", forward=False),
AttrTypeFactory.create(qname="b", forward=False),
]
)
source = ClassFactory.create()

self.processor.update_types(attr, "a", source)

self.assertEqual(source.qname, attr.types[0].qname)
self.assertFalse(attr.types[0].forward)
self.assertEqual(source.ref, attr.types[0].reference)

self.assertEqual("a", attr.types[1].qname)
self.assertFalse(attr.types[1].forward)
self.assertEqual("b", attr.types[2].qname)
self.assertFalse(attr.types[2].forward)

def test_find_forward_attr(self):
target = ClassFactory.create(
attrs=[
AttrFactory.create(
types=[
AttrTypeFactory.create("a", forward=False),
AttrTypeFactory.create("b", forward=False),
]
),
AttrFactory.create(
types=[
AttrTypeFactory.create("a", forward=True),
AttrTypeFactory.create("b", forward=True),
]
),
AttrFactory.create(),
]
)

actual = self.processor.find_forward_attr(target, "a")
self.assertEqual(target.attrs[1], actual)

actual = self.processor.find_forward_attr(target, "b")
self.assertEqual(target.attrs[1], actual)

actual = self.processor.find_forward_attr(target, "c")
self.assertIsNone(actual)
types = [
AttrTypeFactory.create(qname="a", forward=True),
AttrTypeFactory.create(qname="a", forward=False),
AttrTypeFactory.create(qname="b", forward=True),
]

result = self.processor.update_types(types, "a", "b")
self.assertEqual(1, result)

self.assertEqual("b", types[0].qname)
self.assertFalse(types[0].forward)

self.assertEqual("a", types[1].qname)
self.assertFalse(types[1].forward)

self.assertEqual("b", types[2].qname)
self.assertTrue(types[2].forward)
10 changes: 6 additions & 4 deletions tests/codegen/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ def test_process_class(self):
self.assertEqual(Status.FINALIZED, target.inner[1].status)

def test_process_classes(self):
target = ClassFactory.create(
attrs=[AttrFactory.reference("enumeration", forward=True)],
inner=[ClassFactory.enumeration(2, qname="enumeration")],
)
target = ClassFactory.create()
inner = ClassFactory.enumeration(2, qname="enumeration")

target.inner.append(inner)
inner.parent = target
target.attrs.append(AttrFactory.reference("enumeration", forward=True))

self.container.add(target)
self.container.process_classes(Steps.FLATTEN)
Expand Down
112 changes: 44 additions & 68 deletions xsdata/codegen/handlers/unnest_inner_classes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from typing import Iterator, List

from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr, Class
from xsdata.codegen.models import AttrType, Class
from xsdata.utils.namespaces import build_qname


Expand All @@ -22,86 +22,62 @@ def process(self, target: Class):
Args:
target: The target class instance to inspect
"""
for inner in target.inner.copy():
if inner.is_enumeration or self.container.config.output.unnest_classes:
self.promote(target, inner)
all_types = list(target.types())
promote_all = self.container.config.output.unnest_classes
for nested in list(self.find_nested_classes(target)):
if not (promote_all or nested.is_enumeration):
continue

def promote(self, target: Class, inner: Class):
"""Promote the inner class to root classes.
Steps:
- Replace forward references to the inner class
- Remove inner class from target class
- Copy the class to the global class container.
Args:
target: The target class
inner: An inner class
"""
target.inner.remove(inner)
attr = self.find_forward_attr(target, inner.qname)
if attr:
clone = self.clone_class(inner, target.name)
self.update_types(attr, inner.qname, clone)
self.container.add(clone)
cur_qname = nested.qname
self.update_nested_class(nested)
if self.update_types(all_types, cur_qname, nested.qname):
self.container.add(nested)

@classmethod
def clone_class(cls, inner: Class, name: str) -> Class:
"""Clone and prepare inner class for promotion.
def find_nested_classes(cls, target: Class) -> Iterator[Class]:
"""Return all inner classes recursively."""
for inner in target.inner:
yield inner
yield from cls.find_nested_classes(inner)

Clone the inner class, mark it as promoted and pref
the qualified name with the parent class name.
@classmethod
def update_nested_class(cls, target: Class):
"""Prepare the nested class to be added as root.
Args:
inner: The inner class to clone and prepare
name: The parent class name to use a prefix
Returns:
The new class instance
target: The target class
"""
clone = inner.clone()
clone.parent = None
clone.local_type = True
clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}")
assert target.parent is not None
name_parts = [target.parent.name, target.name]
new_qname = build_qname(target.target_namespace, "_".join(name_parts))

for attr in clone.attrs:
for tp in attr.types:
if tp.circular and tp.qname == inner.qname:
tp.qname = clone.qname
tp.reference = clone.ref
target.qname = new_qname

return clone
assert target.parent is not None

@classmethod
def update_types(cls, attr: Attr, search: str, source: Class):
"""Update the references from an inner to a global class.
Args:
attr: The target attr to inspect and update
search: The current inner class qname
source: The new global class qname
"""
for attr_type in attr.types:
if attr_type.qname == search and attr_type.forward:
attr_type.qname = source.qname
attr_type.reference = source.ref
attr_type.forward = False
target.parent.inner.remove(target)
target.parent = None
target.local_type = True

@classmethod
def find_forward_attr(cls, target: Class, qname: str) -> Optional[Attr]:
"""Find the first attr that references the given inner class qname.
def update_types(cls, types: List[AttrType], search: str, replace: str) -> int:
"""Search and replace forward references.
Return the number changes.
Args:
target: The target class instance
qname: An inner class qualified name
types: The types to search and replace
search: The search qname
replace: The replacement qname
Returns:
Attr: The first attr that references the given qname
None: If no such attr exists, it can happen!
The number of changed attr types
"""
for attr in target.attrs:
for attr_type in attr.types:
if attr_type.forward and attr_type.qname == qname:
return attr

return None
updated = 0
for tp in types:
if tp.forward and tp.qname == search:
tp.qname = replace
tp.forward = False
updated += 1

return updated

0 comments on commit e434101

Please sign in to comment.