From b6afc789fbe7bf23a43742f1bf480bb50d188710 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Tue, 16 Apr 2024 19:59:43 +0300 Subject: [PATCH 1/4] feat: Navigate from inner to outer classes --- .../handlers/test_validate_references.py | 34 +++++++++++++++++ tests/codegen/test_container.py | 3 ++ .../codegen/handlers/disambiguate_choices.py | 1 + .../codegen/handlers/unnest_inner_classes.py | 1 + .../codegen/handlers/validate_references.py | 37 ++++++++++++++++++- xsdata/codegen/mappers/definitions.py | 1 + xsdata/codegen/mappers/dict.py | 1 + xsdata/codegen/mappers/dtd.py | 2 +- xsdata/codegen/mappers/element.py | 1 + xsdata/codegen/mappers/schema.py | 1 + xsdata/codegen/models.py | 2 + xsdata/codegen/utils.py | 2 + 12 files changed, 84 insertions(+), 2 deletions(-) diff --git a/tests/codegen/handlers/test_validate_references.py b/tests/codegen/handlers/test_validate_references.py index 18b35d8c9..da773b6b9 100644 --- a/tests/codegen/handlers/test_validate_references.py +++ b/tests/codegen/handlers/test_validate_references.py @@ -60,3 +60,37 @@ def test_validate_misrepresented_references(self): with self.assertRaises(CodegenError): self.handler.run() + + def test_validate_parent_references_with_root_class_with_parent(self): + target = ClassFactory.create() + target.parent = ClassFactory.create() + self.container.add(target) + + with self.assertRaises(CodegenError): + self.handler.run() + + def test_validate_parent_references_with_wrong_parent(self): + parent = ClassFactory.create() + child = ClassFactory.create() + wrong = ClassFactory.create() + + parent.inner.append(child) + child.parent = wrong + + self.container.extend([parent, wrong]) + + with self.assertRaises(CodegenError): + self.handler.run() + + def test_validate_parent_references_with_wrong_parent_ref(self): + parent = ClassFactory.create() + child = ClassFactory.create() + wrong = parent.clone() + + parent.inner.append(child) + child.parent = wrong + + self.container.extend([parent]) + + with self.assertRaises(CodegenError): + self.handler.run() diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index a4804d80f..cdf9a5880 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -114,6 +114,9 @@ def test_process_class(self): target = ClassFactory.create( inner=[ClassFactory.elements(2), ClassFactory.elements(1)] ) + for inner in target.inner: + inner.parent = target + self.container.add(target) self.container.process() diff --git a/xsdata/codegen/handlers/disambiguate_choices.py b/xsdata/codegen/handlers/disambiguate_choices.py index d491abf72..f94628a0e 100644 --- a/xsdata/codegen/handlers/disambiguate_choices.py +++ b/xsdata/codegen/handlers/disambiguate_choices.py @@ -171,6 +171,7 @@ def disambiguate_choice(self, target: Class, choice: Attr): if not inner: self.container.add(ref_class) else: + ref_class.parent = target target.inner.append(ref_class) def is_simple_type(self, choice: Attr) -> bool: diff --git a/xsdata/codegen/handlers/unnest_inner_classes.py b/xsdata/codegen/handlers/unnest_inner_classes.py index 4b38d8cf1..bc1d8e751 100644 --- a/xsdata/codegen/handlers/unnest_inner_classes.py +++ b/xsdata/codegen/handlers/unnest_inner_classes.py @@ -60,6 +60,7 @@ def clone_class(cls, inner: Class, name: str) -> Class: The new class instance """ clone = inner.clone() + clone.parent = None clone.local_type = True clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}") diff --git a/xsdata/codegen/handlers/validate_references.py b/xsdata/codegen/handlers/validate_references.py index 5403ab6b1..de6dff404 100644 --- a/xsdata/codegen/handlers/validate_references.py +++ b/xsdata/codegen/handlers/validate_references.py @@ -1,4 +1,4 @@ -from typing import Set +from typing import Optional, Set from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.mixins import ContainerHandlerInterface @@ -21,6 +21,7 @@ def run(self): self.validate_unique_qualified_names() self.validate_unique_instances() self.validate_resolved_references() + self.validate_parent_references() def validate_unique_qualified_names(self): """Validate all root classes have unique qualified names.""" @@ -71,3 +72,37 @@ def build(target: Class): raise CodegenError( "Misrepresented reference", cls=item.qname, type=tp.qname ) + + def validate_parent_references(self): + """Validate inner to outer classes is accurate.""" + + def _validate(target: Class, parent: Optional[Class] = None): + actual_qname = actual_ref = expected_qname = expected_ref = None + if target.parent: + actual_qname = target.parent.qname + actual_ref = target.parent.ref + + if parent: + expected_qname = parent.qname + expected_ref = parent.ref + + if actual_qname != expected_qname: + raise CodegenError( + "Invalid parent class reference", + cls=target.qname, + expected=expected_qname, + actual=actual_qname, + ) + + if actual_ref != expected_ref: + raise CodegenError( + "Invalid parent class reference", + cls=target.qname, + ref=actual_qname, + ) + + for inner in target.inner: + _validate(inner, target) + + for item in self.container: + _validate(item) diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index a7effce21..1a5f8a884 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -376,6 +376,7 @@ def build_inner_class( ) attr = cls.build_attr(name, inner.qname, forward=True, namespace=namespace) + inner.parent = target target.inner.append(inner) target.attrs.append(attr) diff --git a/xsdata/codegen/mappers/dict.py b/xsdata/codegen/mappers/dict.py index a62be0ba4..8aae9b09d 100644 --- a/xsdata/codegen/mappers/dict.py +++ b/xsdata/codegen/mappers/dict.py @@ -66,6 +66,7 @@ def build_class_attribute(cls, target: Class, name: str, value: Any): else: if isinstance(value, dict): inner = cls.build_class(value, name) + inner.parent = target attr_type = AttrType(qname=inner.qname, forward=True) target.inner.append(inner) else: diff --git a/xsdata/codegen/mappers/dtd.py b/xsdata/codegen/mappers/dtd.py index 33644e0d0..016ccab1f 100644 --- a/xsdata/codegen/mappers/dtd.py +++ b/xsdata/codegen/mappers/dtd.py @@ -327,5 +327,5 @@ def build_enumeration(cls, target: Class, name: str, values: List[str]): types=[attr_type.clone()], ) ) - + inner.parent = target target.inner.append(inner) diff --git a/xsdata/codegen/mappers/element.py b/xsdata/codegen/mappers/element.py index 78ad64f03..0c319477e 100644 --- a/xsdata/codegen/mappers/element.py +++ b/xsdata/codegen/mappers/element.py @@ -106,6 +106,7 @@ def build_elements( if child.attributes or child.children: inner = cls.build_class(child, namespace) + inner.parent = target attr_type = AttrType(qname=inner.qname, forward=True) target.inner.append(inner) else: diff --git a/xsdata/codegen/mappers/schema.py b/xsdata/codegen/mappers/schema.py index d5c145e30..2fc9ed23a 100644 --- a/xsdata/codegen/mappers/schema.py +++ b/xsdata/codegen/mappers/schema.py @@ -360,6 +360,7 @@ def build_attr_types(cls, target: Class, obj: ElementBase) -> List[AttrType]: location = target.location namespace = target.target_namespace for inner in cls.build_inner_classes(obj, location, namespace): + inner.parent = target target.inner.append(inner) types.append(AttrType(qname=inner.qname, forward=True)) diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 3acbd32b1..ed9351214 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -512,6 +512,7 @@ class Class(CodegenModel): attrs: The list of all the attr instances inner: The list of all the inner class instances ns_map: The namespace prefix-URI map + parent: The parent outer class """ qname: str @@ -535,6 +536,7 @@ class Class(CodegenModel): attrs: List[Attr] = field(default_factory=list) inner: List["Class"] = field(default_factory=list) ns_map: Dict = field(default_factory=dict) + parent: Optional["Class"] = field(default=None, compare=False) @property def name(self) -> str: diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 5efedaedb..92d98ac0a 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -206,6 +206,7 @@ def copy_inner_class(cls, source: Class, target: Class, attr_type: AttrType): clone.module = target.module clone.status = Status.RAW attr_type.reference = clone.ref + clone.parent = target target.inner.append(clone) @classmethod @@ -259,6 +260,7 @@ def flatten(cls, target: Class, location: str) -> Iterator[Class]: An iterator over all the found classes. """ target.location = location + target.parent = None while target.inner: yield from cls.flatten(target.inner.pop(), location) From 21e09ce776b3f984a688b1588db09b8cce38a434 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sat, 20 Apr 2024 17:12:16 +0300 Subject: [PATCH 2/4] feat: Use bidirectional classes navigation in filters --- tests/codegen/test_utils.py | 25 +++ tests/formats/dataclass/test_filters.py | 209 +++++++++--------- xsdata/codegen/models.py | 10 + xsdata/codegen/utils.py | 44 +++- xsdata/formats/dataclass/filters.py | 52 ++--- .../formats/dataclass/templates/class.jinja2 | 8 +- 6 files changed, 211 insertions(+), 137 deletions(-) diff --git a/tests/codegen/test_utils.py b/tests/codegen/test_utils.py index d34540a21..6dd873a66 100644 --- a/tests/codegen/test_utils.py +++ b/tests/codegen/test_utils.py @@ -400,3 +400,28 @@ def test_filter_types(self): types = [xs_any] actual = ClassUtils.filter_types(types) self.assertEqual(1, len(actual)) + + def test_find_nested(self): + a = ClassFactory.create(qname="a") + b = ClassFactory.create(qname="b") + c = ClassFactory.create(qname="c") + + a.inner.append(b) + b.inner.append(c) + c.parent = b + b.parent = a + + self.assertEqual(a, ClassUtils.find_nested(a, "a")) + self.assertEqual(b, ClassUtils.find_nested(a, "b")) + self.assertEqual(b, ClassUtils.find_nested(c, "b")) + self.assertEqual(a, ClassUtils.find_nested(c, "a")) + + a2 = ClassFactory.create(qname="a") + c.inner.append(a2) + a2.parent = c + + # Breadth-first search + self.assertEqual(a2, ClassUtils.find_nested(c, "a")) + + with self.assertRaises(CodegenError): + ClassUtils.find_nested(a, "nope") diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index ac8e613b2..42f870826 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -39,6 +39,24 @@ def setUp(self) -> None: config = GeneratorConfig() self.filters = Filters(config) + obj = ClassFactory.create(qname="a") + obj_nested = ClassFactory.create(qname="b") + obj_nested_nested = ClassFactory.create(qname="c") + obj_nested_nested_nested = ClassFactory.create(qname="d") + + obj_nested_nested_nested.parent = obj_nested_nested + obj_nested_nested.parent = obj_nested + obj_nested.parent = obj + + obj.inner.append(obj_nested) + obj_nested.inner.append(obj_nested_nested) + obj_nested_nested.inner.append(obj_nested_nested_nested) + + self.obj = obj + self.obj_nested = obj_nested + self.obj_nested_nested = obj_nested_nested + self.obj_nested_nested_nested = obj_nested_nested_nested + def test_class_name(self): self.filters.substitutions[ObjectType.CLASS]["Abc"] = "Cba" @@ -231,7 +249,7 @@ def test_field_definition(self, mock_field_default_value): mock_field_default_value.side_effect = [1, False] attr = AttrFactory.native(DataType.INT) - result = self.filters.field_definition(attr, {}, None, ["Root"]) + result = self.filters.field_definition(self.obj, attr, None) expected = ( "field(\n" " default=1,\n" @@ -243,7 +261,7 @@ def test_field_definition(self, mock_field_default_value): ) self.assertEqual(expected, result) - result = self.filters.field_definition(attr, {}, None, ["Root"]) + result = self.filters.field_definition(self.obj, attr, None) expected = ( "field(\n" " metadata={\n" @@ -259,7 +277,7 @@ def test_field_definition_with_prohibited_attr(self): attr.restrictions.max_occurs = 0 attr.default = "1" - result = self.filters.field_definition(attr, {}, None, ["Root"]) + result = self.filters.field_definition(self.obj, attr, None) expected = ( "field(\n" " init=False,\n" @@ -279,7 +297,7 @@ def test_field_definition_with_restriction_pattern(self, mock_field_default_valu pattern = '([^\\ \\? > < \\* / " ": |]{1,256})' str_attr.restrictions.pattern = pattern - result = self.filters.field_definition(str_attr, {}, None, ["Root"]) + result = self.filters.field_definition(self.obj, str_attr, None) expected = ( "field(\n" " default=None,\n" @@ -294,7 +312,7 @@ def test_field_definition_with_restriction_pattern(self, mock_field_default_valu def test_field_definition_without_metadata(self, mock_field_metadata): mock_field_metadata.return_value = {} str_attr = AttrFactory.create(types=[type_str], tag=Tag.RESTRICTION) - result = self.filters.field_definition(str_attr, {}, None, ["Root"]) + result = self.filters.field_definition(self.obj, str_attr, None) expected = "field(\n" " default=None\n" " )" self.assertEqual(expected, result) @@ -429,49 +447,48 @@ def test_field_default_value_with_multiple_types(self): def test_field_metadata(self): attr = AttrFactory.element() expected = {"name": "attr_B", "type": "Element"} - self.assertEqual(expected, self.filters.field_metadata(attr, None, ["cls"])) - self.assertEqual(expected, self.filters.field_metadata(attr, "foo", ["cls"])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, None)) def test_field_metadata_namespace(self): attr = AttrFactory.element(namespace="foo") expected = {"name": "attr_B", "namespace": "foo", "type": "Element"} - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertEqual(expected, actual) - actual = self.filters.field_metadata(attr, "foo", ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, "foo") self.assertNotIn("namespace", actual) attr = AttrFactory.attribute(namespace="foo") expected = {"name": "attr_C", "namespace": "foo", "type": "Attribute"} - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertEqual(expected, actual) - actual = self.filters.field_metadata(attr, "foo", ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, "foo") self.assertIn("namespace", actual) def test_field_metadata_name(self): attr = AttrFactory.element(name="bar") attr.local_name = "foo" - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertEqual("foo", actual["name"]) attr = AttrFactory.element(name="Foo") attr.local_name = "foo" - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertNotIn("name", actual) attr = AttrFactory.create(tag=Tag.ANY, name="bar") attr.local_name = "foo" - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertNotIn("name", actual) def test_field_metadata_wrapper(self): attr = AttrFactory.element(wrapper="foo") expected = {"name": "attr_B", "wrapper": "foo", "type": "Element"} - actual = self.filters.field_metadata(attr, None, ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, None) self.assertEqual(expected, actual) def test_field_metadata_restrictions(self): @@ -482,30 +499,30 @@ def test_field_metadata_restrictions(self): attr.restrictions.max_inclusive = "2" expected = {"min_occurs": 1, "max_occurs": 2, "max_inclusive": 2} - self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, None)) attr.restrictions.min_occurs = 1 attr.restrictions.max_occurs = 1 expected = {"required": True, "max_inclusive": 2} - self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, None)) attr.restrictions.nillable = True expected = {"nillable": True, "max_inclusive": 2} - self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, None)) attr.default = None attr.restrictions.tokens = True expected = {"max_inclusive": 2, "nillable": True, "tokens": True} - self.assertEqual(expected, self.filters.field_metadata(attr, None, [])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, None)) def test_field_metadata_mixed(self): attr = AttrFactory.element(mixed=True) expected = {"mixed": True, "name": "attr_B", "type": "Element"} - self.assertEqual(expected, self.filters.field_metadata(attr, "foo", ["cls"])) + self.assertEqual(expected, self.filters.field_metadata(self.obj, attr, "foo")) def test_field_metadata_choices(self): attr = AttrFactory.create(choices=AttrFactory.list(2, tag=Tag.ELEMENT)) - actual = self.filters.field_metadata(attr, "foo", ["cls"]) + actual = self.filters.field_metadata(self.obj, attr, "foo") expected = ( {"name": "attr_B", "type": "Type[str]"}, {"name": "attr_C", "type": "Type[str]"}, @@ -529,7 +546,7 @@ def test_field_choices(self): ] ) - actual = self.filters.field_choices(attr, "foo", ["a", "b"]) + actual = self.filters.field_choices(self.obj, attr, "foo") expected = ( {"name": "$", "type": "Type[float]", "max_exclusive": 10.0}, {"name": "attr_B", "namespace": "bar", "type": "Type[str]"}, @@ -551,7 +568,7 @@ def test_field_choices(self): self.filters.docstring_style = DocstringStyle.ACCESSIBLE attr.choices[0].help = "help" - actual = self.filters.field_choices(attr, None, []) + actual = self.filters.field_choices(self.obj, attr, None) self.assertEqual(attr.choices[0].help, actual[0]["doc"]) self.assertNotIn("doc", actual[1]) @@ -560,159 +577,138 @@ def test_field_type_with_default_value(self): default="1", types=AttrTypeFactory.list(1, qname="foo_bar") ) - self.assertEqual("FooBar", self.filters.field_type(attr, [])) + self.assertEqual("FooBar", self.filters.field_type(self.obj, attr)) attr.restrictions.nillable = True - self.assertEqual("Optional[FooBar]", self.filters.field_type(attr, [])) + self.assertEqual("Optional[FooBar]", self.filters.field_type(self.obj, attr)) self.filters.union_type = True - self.assertEqual("None | FooBar", self.filters.field_type(attr, [])) + self.assertEqual("None | FooBar", self.filters.field_type(self.obj, attr)) def test_field_type_with_optional_value(self): attr = AttrFactory.create(types=AttrTypeFactory.list(1, qname="foo_bar")) - self.assertEqual("Optional[FooBar]", self.filters.field_type(attr, [])) + self.assertEqual("Optional[FooBar]", self.filters.field_type(self.obj, attr)) self.filters.format.kw_only = True - self.assertEqual("FooBar", self.filters.field_type(attr, [])) + self.assertEqual("FooBar", self.filters.field_type(self.obj, attr)) attr.restrictions.min_occurs = 0 - self.assertEqual("Optional[FooBar]", self.filters.field_type(attr, [])) + self.assertEqual("Optional[FooBar]", self.filters.field_type(self.obj, attr)) self.filters.union_type = True - self.assertEqual("None | FooBar", self.filters.field_type(attr, [])) + self.assertEqual("None | FooBar", self.filters.field_type(self.obj, attr)) def test_field_type_with_circular_reference(self): attr = AttrFactory.create( - types=AttrTypeFactory.list(1, qname="foo_bar", circular=True) + types=AttrTypeFactory.list(1, qname="c", circular=True) ) self.assertEqual( - 'Optional["FooBar"]', self.filters.field_type(attr, ["Parent"]) + 'Optional["C"]', + self.filters.field_type(self.obj_nested_nested_nested, attr), ) def test_field_type_with_forward_reference(self): attr = AttrFactory.create( - types=AttrTypeFactory.list(1, qname="foo_bar", forward=True) + types=AttrTypeFactory.list(1, qname="b", forward=True) ) self.assertEqual( - 'Optional["Parent.Inner.FooBar"]', - self.filters.field_type(attr, ["Parent", "Inner"]), + 'Optional["A.B"]', + self.filters.field_type(self.obj_nested_nested, attr), ) self.filters.postponed_annotations = True self.filters.union_type = True self.assertEqual( - "None | Parent.Inner.FooBar", - self.filters.field_type(attr, ["Parent", "Inner"]), - ) - - def test_field_type_with_forward_and_circular_reference(self): - attr = AttrFactory.create( - types=AttrTypeFactory.list(1, qname="foo_bar", forward=True, circular=True) - ) - - self.assertEqual( - 'Optional["Parent.Inner"]', - self.filters.field_type(attr, ["Parent", "Inner"]), + "None | A.B", self.filters.field_type(self.obj_nested_nested, attr) ) def test_field_type_with_array_type(self): attr = AttrFactory.create( - types=AttrTypeFactory.list(1, qname="foo_bar", forward=True) + types=AttrTypeFactory.list(1, qname="c", forward=True) ) attr.restrictions.max_occurs = 2 self.assertEqual( - 'List["A.Parent.FooBar"]', - self.filters.field_type(attr, ["A", "Parent"]), + 'List["A.B.C"]', + self.filters.field_type(self.obj, attr), ) self.filters.format.frozen = True - self.assertEqual( - 'Tuple["A.Parent.FooBar", ...]', - self.filters.field_type(attr, ["A", "Parent"]), - ) + self.assertEqual('Tuple["A.B.C", ...]', self.filters.field_type(self.obj, attr)) self.filters.subscriptable_types = True - self.assertEqual( - 'tuple["A.Parent.FooBar", ...]', - self.filters.field_type(attr, ["A", "Parent"]), - ) + self.assertEqual('tuple["A.B.C", ...]', self.filters.field_type(self.obj, attr)) self.filters.format.frozen = False - self.assertEqual( - 'list["A.Parent.FooBar"]', - self.filters.field_type(attr, ["A", "Parent"]), - ) + self.assertEqual('list["A.B.C"]', self.filters.field_type(self.obj, attr)) def test_field_type_with_token_attr(self): attr = AttrFactory.create( types=AttrTypeFactory.list(1, qname="foo_bar"), restrictions=Restrictions(tokens=True), ) - self.assertEqual("List[FooBar]", self.filters.field_type(attr, [])) + self.assertEqual("List[FooBar]", self.filters.field_type(self.obj, attr)) attr.restrictions.max_occurs = 2 - self.assertEqual("List[List[FooBar]]", self.filters.field_type(attr, [])) + self.assertEqual("List[List[FooBar]]", self.filters.field_type(self.obj, attr)) attr.restrictions.max_occurs = 1 self.filters.format.frozen = True - self.assertEqual("Tuple[FooBar, ...]", self.filters.field_type(attr, [])) + self.assertEqual("Tuple[FooBar, ...]", self.filters.field_type(self.obj, attr)) attr.restrictions.max_occurs = 2 self.assertEqual( - "Tuple[Tuple[FooBar, ...], ...]", self.filters.field_type(attr, []) + "Tuple[Tuple[FooBar, ...], ...]", self.filters.field_type(self.obj, attr) ) self.filters.subscriptable_types = True self.assertEqual( - "tuple[tuple[FooBar, ...], ...]", self.filters.field_type(attr, []) + "tuple[tuple[FooBar, ...], ...]", self.filters.field_type(self.obj, attr) ) def test_field_type_with_alias(self): attr = AttrFactory.create( - types=AttrTypeFactory.list( - 1, qname="foo_bar", forward=True, alias="Boss:Life" - ) + types=AttrTypeFactory.list(1, qname="b", forward=True, alias="Boss:Life") ) attr.restrictions.max_occurs = 2 self.assertEqual( - 'List["A.Parent.BossLife"]', - self.filters.field_type(attr, ["A", "Parent"]), + 'List["A.BossLife"]', + self.filters.field_type(self.obj_nested_nested_nested, attr), ) def test_field_type_with_multiple_types(self): attr = AttrFactory.create( types=[ - AttrTypeFactory.create(qname="life", alias="Boss:Life", forward=True), + AttrTypeFactory.create(qname="c", alias="Boss:Life", forward=True), AttrTypeFactory.native(DataType.INT), ] ) attr.restrictions.max_occurs = 2 self.assertEqual( - 'List[Union["A.Parent.BossLife", int]]', - self.filters.field_type(attr, ["A", "Parent"]), + 'List[Union["A.B.BossLife", int]]', + self.filters.field_type(self.obj_nested_nested_nested, attr), ) self.filters.union_type = True self.assertEqual( - 'List["A.Parent.BossLife" | int]', - self.filters.field_type(attr, ["A", "Parent"]), + 'List["A.B.BossLife" | int]', + self.filters.field_type(self.obj_nested_nested_nested, attr), ) self.filters.subscriptable_types = True self.assertEqual( - 'list["A.Parent.BossLife" | int]', - self.filters.field_type(attr, ["A", "Parent"]), + 'list["A.B.BossLife" | int]', + self.filters.field_type(self.obj_nested_nested_nested, attr), ) def test_field_type_with_any_attribute(self): attr = AttrFactory.any_attribute() - self.assertEqual("Dict[str, str]", self.filters.field_type(attr, ["a", "b"])) + self.assertEqual("Dict[str, str]", self.filters.field_type(self.obj, attr)) self.filters.subscriptable_types = True - self.assertEqual("dict[str, str]", self.filters.field_type(attr, ["a", "b"])) + self.assertEqual("dict[str, str]", self.filters.field_type(self.obj, attr)) def test_field_type_with_native_type(self): attr = AttrFactory.create( @@ -723,16 +719,16 @@ def test_field_type_with_native_type(self): ] ) self.assertEqual( - "Optional[Union[int, str]]", self.filters.field_type(attr, ["a", "b"]) + "Optional[Union[int, str]]", self.filters.field_type(self.obj, attr) ) self.filters.union_type = True - self.assertEqual("None | int | str", self.filters.field_type(attr, ["a", "b"])) + self.assertEqual("None | int | str", self.filters.field_type(self.obj, attr)) def test_field_type_with_prohibited_attr(self): attr = AttrFactory.create(restrictions=Restrictions(max_occurs=0)) - self.assertEqual("Any", self.filters.field_type(attr, ["a", "b"])) + self.assertEqual("Any", self.filters.field_type(self.obj, attr)) def test_field_type_with_compound_attr(self): attr = AttrFactory.create( @@ -754,68 +750,75 @@ def test_field_type_with_compound_attr(self): ) expected = "Optional[Union[str, int, List[Decimal]]]" - self.assertEqual(expected, self.filters.field_type(attr, [])) + self.assertEqual(expected, self.filters.field_type(self.obj, attr)) attr.restrictions.max_occurs = 2 expected = "List[Union[str, int, List[Decimal]]]" - self.assertEqual(expected, self.filters.field_type(attr, [])) + self.assertEqual(expected, self.filters.field_type(self.obj, attr)) attr.restrictions.min_occurs = attr.restrictions.max_occurs = 1 self.filters.format.kw_only = True expected = "Union[str, int, List[Decimal]]" - self.assertEqual(expected, self.filters.field_type(attr, [])) + self.assertEqual(expected, self.filters.field_type(self.obj, attr)) def test_choice_type(self): choice = AttrFactory.create(types=[AttrTypeFactory.create("foobar")]) - actual = self.filters.choice_type(choice, ["a", "b"]) + target = ClassFactory.create() + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[Foobar]", actual) def test_choice_type_with_forward_reference(self): choice = AttrFactory.create( types=[AttrTypeFactory.create("foobar", forward=True)] ) - actual = self.filters.choice_type(choice, ["a", "b"]) - self.assertEqual('ForwardRef("A.B.Foobar")', actual) + target = ClassFactory.create(qname="foobar") + parent = ClassFactory.create(qname="a") + parent.inner.append(target) + target.parent = parent + + actual = self.filters.choice_type(parent, choice) + self.assertEqual('ForwardRef("A.Foobar")', actual) def test_choice_type_with_circular_reference(self): - choice = AttrFactory.create( - types=[AttrTypeFactory.create("foobar", circular=True)] - ) - actual = self.filters.choice_type(choice, ["a", "b"]) - self.assertEqual('ForwardRef("Foobar")', actual) + choice = AttrFactory.create(types=[AttrTypeFactory.create("c", circular=True)]) + actual = self.filters.choice_type(self.obj_nested_nested_nested, choice) + self.assertEqual('ForwardRef("C")', actual) self.filters.postponed_annotations = True - actual = self.filters.choice_type(choice, ["a", "b"]) - self.assertEqual('ForwardRef("Foobar")', actual) + actual = self.filters.choice_type(self.obj_nested_nested_nested, choice) + self.assertEqual('ForwardRef("C")', actual) def test_choice_type_with_multiple_types(self): choice = AttrFactory.create(types=[type_str, type_bool]) - actual = self.filters.choice_type(choice, ["a", "b"]) + target = ClassFactory.create() + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[Union[str, bool]]", actual) self.filters.union_type = True - actual = self.filters.choice_type(choice, ["a", "b"]) + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[str | bool]", actual) def test_choice_type_with_list_types_are_ignored(self): choice = AttrFactory.create(types=[type_str, type_bool]) choice.restrictions.max_occurs = 200 - actual = self.filters.choice_type(choice, ["a", "b"]) + target = ClassFactory.create() + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[Union[str, bool]]", actual) def test_choice_type_with_restrictions_tokens_true(self): choice = AttrFactory.create(types=[type_str, type_bool]) choice.restrictions.tokens = True - actual = self.filters.choice_type(choice, ["a", "b"]) + target = ClassFactory.create() + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[List[Union[str, bool]]]", actual) self.filters.format.frozen = True - actual = self.filters.choice_type(choice, ["a", "b"]) + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[Tuple[Union[str, bool], ...]]", actual) self.filters.union_type = True self.filters.subscriptable_types = True - actual = self.filters.choice_type(choice, ["a", "b"]) + actual = self.filters.choice_type(target, choice) self.assertEqual("Type[tuple[str | bool, ...]]", actual) def test_default_imports_with_decimal(self): diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index ed9351214..f8b28444b 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -714,6 +714,16 @@ def has_forward_ref(self) -> bool: return any(inner.has_forward_ref() for inner in self.inner) + def parent_names(self) -> List[str]: + """Return the outer class names.""" + result = [] + target = self.parent + while target is not None: + result.append(target.name) + target = target.parent + + return list(reversed(result)) + @dataclass class Import: diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 92d98ac0a..9452c91b9 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -1,5 +1,6 @@ import sys -from typing import Iterator, List, Optional, Set +from collections import deque +from typing import Deque, Iterator, List, Optional, Set from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.models import ( @@ -495,3 +496,44 @@ def filter_types(cls, types: List[AttrType]) -> List[AttrType]: types.append(AttrType(qname=str(DataType.STRING), native=True)) return types + + @classmethod + def find_nested(cls, target: Class, qname: str) -> Class: + """Find a nested class by qname. + + Breath-first search implementation, that goes + from the current level to bottom before looking + for outer classes. + + Args: + target: The class instance to begin the search + qname: The qualified name of the nested class to find + + Raises: + CodegenException: If the nested class cannot be found. + + Returns: + The nested class instance. + """ + queue: Deque[Class] = deque() + visited: Set[int] = set() + + if target.inner: + queue.extend(target.inner) + elif target.parent: + queue.append(target.parent) + + while len(queue) > 0: + item = queue.popleft() + visited.add(item.ref) + if item.qname == qname: + return item + + for inner in item.inner: + if inner.ref not in visited: + queue.append(inner) + + if len(queue) == 0 and item.parent: + queue.append(item.parent) + + raise CodegenError("Missing inner class", parent=target, qname=qname) diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index dcd869125..acaa29c7f 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -19,6 +19,7 @@ from jinja2 import Environment from xsdata.codegen.models import Attr, AttrType, Class +from xsdata.codegen.utils import ClassUtils from xsdata.formats.converter import converter from xsdata.formats.dataclass.models.elements import XmlType from xsdata.models.config import ( @@ -243,14 +244,14 @@ def apply_substitutions(self, name: str, obj_type: ObjectType) -> str: def field_definition( self, + obj: Class, attr: Attr, - ns_map: Dict, parent_namespace: Optional[str], - parents: List[str], ) -> str: """Return the field definition with any extra metadata.""" + ns_map = obj.ns_map default_value = self.field_default_value(attr, ns_map) - metadata = self.field_metadata(attr, parent_namespace, parents) + metadata = self.field_metadata(obj, attr, parent_namespace) kwargs: Dict[str, Any] = {} if attr.fixed or attr.is_prohibited: @@ -421,9 +422,9 @@ def post_meta_hook(self, obj: Class) -> Optional[str]: def field_metadata( self, + obj: Class, attr: Attr, parent_namespace: Optional[str], - parents: List[str], ) -> Dict: """Return a metadata dictionary for the given attribute.""" if attr.is_prohibited: @@ -432,7 +433,7 @@ def field_metadata( name = namespace = None if not attr.is_nameless and attr.local_name != self.field_name( - attr.name, parents[-1] + attr.name, obj.name ): name = attr.local_name @@ -447,7 +448,7 @@ def field_metadata( "type": attr.xml_type, "namespace": namespace, "mixed": attr.mixed, - "choices": self.field_choices(attr, parent_namespace, parents), + "choices": self.field_choices(obj, attr, parent_namespace), **restrictions, } @@ -458,9 +459,9 @@ def field_metadata( def field_choices( self, + obj: Class, attr: Attr, parent_namespace: Optional[str], - parents: List[str], ) -> Optional[Tuple]: """Return a tuple of field metadata if the attr has choices.""" if not attr.choices: @@ -477,7 +478,7 @@ def field_choices( metadata = { "name": choice.local_name, "wildcard": choice.is_wildcard, - "type": self.choice_type(choice, parents), + "type": self.choice_type(obj, choice), "namespace": namespace, } @@ -742,15 +743,15 @@ def field_default_tokens( return f"lambda: {self.format_metadata(tokens, indent=8)}" - def field_type(self, attr: Attr, parents: List[str]) -> str: + def field_type(self, obj: Class, attr: Attr) -> str: """Generate type hints for the given attr.""" if attr.is_prohibited: return "Any" if attr.tag == Tag.CHOICE: - return self.compound_field_types(attr, parents) + return self.compound_field_types(obj, attr) - result = self._field_type_names(attr, parents, choice=False) + result = self._field_type_names(obj, attr, choice=False) iterable_fmt = self._get_iterable_format() if attr.is_tokens: @@ -772,12 +773,12 @@ def field_type(self, attr: Attr, parents: List[str]) -> str: return result - def compound_field_types(self, attr: Attr, parents: List[str]) -> str: + def compound_field_types(self, obj: Class, attr: Attr) -> str: """Generate type hint for a compound field. Args: + obj: The parent class instance attr: The compound attr instance - parents: A list of the parent class names Returns: The string representation of the type hint. @@ -785,7 +786,7 @@ def compound_field_types(self, attr: Attr, parents: List[str]) -> str: results = [] iterable_fmt = self._get_iterable_format() for choice in attr.choices: - names = self._field_type_names(choice, parents, choice=False) + names = self._field_type_names(obj, choice, choice=False) if choice.is_tokens: names = iterable_fmt.format(names) results.append(names) @@ -800,7 +801,7 @@ def compound_field_types(self, attr: Attr, parents: List[str]) -> str: return result - def choice_type(self, choice: Attr, parents: List[str]) -> str: + def choice_type(self, obj: Class, choice: Attr) -> str: """Generate type hints for the given choice. Choices support a subset of features from normal attributes. @@ -811,13 +812,13 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: is also ignored. Args: + obj: The parent class instance choice: The choice instance - parents: A list of the parent class names Returns: The string representation of the type hint. """ - result = self._field_type_names(choice, parents, choice=True) + result = self._field_type_names(obj, choice, choice=True) if choice.is_tokens: iterable_fmt = self._get_iterable_format() @@ -830,13 +831,11 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: def _field_type_names( self, + obj: Class, attr: Attr, - parents: List[str], choice: bool = False, ) -> str: - type_names = [ - self._field_type_name(x, parents, choice=choice) for x in attr.types - ] + type_names = [self._field_type_name(obj, x, choice=choice) for x in attr.types] return self._join_type_names(type_names) def _join_type_names(self, type_names: List[str]) -> str: @@ -850,15 +849,12 @@ def _join_type_names(self, type_names: List[str]) -> str: return f'Union[{", ".join(type_names)}]' def _field_type_name( - self, attr_type: AttrType, parents: List[str], choice: bool = False + self, obj: Class, attr_type: AttrType, choice: bool = False ) -> str: name = self.type_name(attr_type) - - if attr_type.forward and attr_type.circular: - outer_str = ".".join(map(self.class_name, parents)) - name = f'"{outer_str}"' - elif attr_type.forward: - outer_str = ".".join(map(self.class_name, parents)) + if attr_type.forward: + inner = ClassUtils.find_nested(obj, attr_type.qname) + outer_str = ".".join(map(self.class_name, inner.parent_names())) name = f'"{outer_str}.{name}"' elif attr_type.circular: name = f'"{name}"' diff --git a/xsdata/formats/dataclass/templates/class.jinja2 b/xsdata/formats/dataclass/templates/class.jinja2 index 51ba2c4a5..86249b93a 100644 --- a/xsdata/formats/dataclass/templates/class.jinja2 +++ b/xsdata/formats/dataclass/templates/class.jinja2 @@ -3,7 +3,6 @@ {%- include "docstrings." + docstring_name + ".jinja2" -%} {% endset -%} {% set parent_namespace = obj.namespace if obj.namespace is not none else parent_namespace|default(None) -%} -{% set parents = parents|default([obj.name]) -%} {% set class_name = obj.name|class_name -%} {% set class_annotations = obj | class_annotations(class_name) -%} {% set global_type = level == 0 and not obj.local_type -%} @@ -42,15 +41,14 @@ class {{ class_name }}{{"({})".format(base_classes) if base_classes }}: {{ post_meta_output|indent(4, first=True) }} {%- endif -%} {%- for attr in obj.attrs %} - {%- set field_typing = attr|field_type(parents) %} - {%- set field_definition = attr|field_definition(obj.ns_map, parent_namespace, parents) %} + {%- set field_typing = obj|field_type(attr) %} + {%- set field_definition = obj|field_definition(attr, parent_namespace) %} {{ attr.name|field_name(obj.name) }}: {{ field_typing }} = {{ field_definition }} {%- endfor -%} {%- for inner in obj.inner %} {%- set tpl = "enum.jinja2" if inner.is_enumeration else "class.jinja2" -%} - {%- set inner_parents = parents + [inner.name] -%} {%- filter indent(4) -%} - {%- with obj=inner, parents=inner_parents, level=(level + 1) -%} + {%- with obj=inner, level=(level + 1) -%} {% include tpl %} {%- endwith -%} {%- endfilter -%} From 7cacca510f23dcdb874c5c4b7f878cc30c6038c0 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sat, 20 Apr 2024 19:32:27 +0300 Subject: [PATCH 3/4] chore: Remove deprecated find_inner util method --- tests/codegen/test_utils.py | 14 -------------- xsdata/codegen/container.py | 2 +- xsdata/codegen/utils.py | 22 +--------------------- 3 files changed, 2 insertions(+), 36 deletions(-) diff --git a/tests/codegen/test_utils.py b/tests/codegen/test_utils.py index 6dd873a66..f851d7454 100644 --- a/tests/codegen/test_utils.py +++ b/tests/codegen/test_utils.py @@ -214,20 +214,6 @@ def test_copy_inner_class_with_missing_inner(self): with self.assertRaises(CodegenError): ClassUtils.copy_inner_class(source, target, attr_type) - def test_find_inner(self): - obj = ClassFactory.create(qname="{a}parent") - first = ClassFactory.create(qname="{a}a") - second = ClassFactory.create(qname="{c}c") - third = ClassFactory.enumeration(2, qname="{d}d") - obj.inner.extend((first, second, third)) - - with self.assertRaises(CodegenError): - self.assertIsNone(ClassUtils.find_inner(obj, "nope")) - - self.assertEqual(first, ClassUtils.find_inner(obj, "{a}a")) - self.assertEqual(second, ClassUtils.find_inner(obj, "{c}c")) - self.assertEqual(third, ClassUtils.find_inner(obj, "{d}d")) - def test_flatten(self): target = ClassFactory.create( qname="{xsdata}root", attrs=AttrFactory.list(3), inner=ClassFactory.list(2) diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 4207fee22..ae706abc0 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -148,7 +148,7 @@ def find_inner(self, source: Class, qname: str) -> Class: Raises: CodeGenerationError: If the inner class is not found. """ - inner = ClassUtils.find_inner(source, qname) + inner = ClassUtils.find_nested(source, qname) if inner.status < self.step: self.process_class(inner, self.step) diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 9452c91b9..c623e6bf2 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -196,7 +196,7 @@ def copy_inner_class(cls, source: Class, target: Class, attr_type: AttrType): if not attr_type.forward: return - inner = ClassUtils.find_inner(source, attr_type.qname) + inner = ClassUtils.find_nested(source, attr_type.qname) if inner is target: attr_type.circular = True attr_type.reference = target.ref @@ -210,26 +210,6 @@ def copy_inner_class(cls, source: Class, target: Class, attr_type: AttrType): clone.parent = target target.inner.append(clone) - @classmethod - def find_inner(cls, source: Class, qname: str) -> Class: - """Find an inner class in the source class by its qualified name. - - Args: - source: The parent class instance - qname: The inner class qualified name - - Returns: - The inner class instance - - Raises: - CodeGenerationError: If no inner class matched. - """ - for inner in source.inner: - if inner.qname == qname: - return inner - - raise CodegenError("Missing inner class", parent=source, qname=qname) - @classmethod def find_attr(cls, source: Class, name: str) -> Optional[Attr]: """Find an attr in the source class by its name. From 00fe560b2ff26de2d2b2c71d40ac074046aa5065 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sun, 21 Apr 2024 10:38:15 +0300 Subject: [PATCH 4/4] fix: Avoid recursive error on nested group references --- .../handlers/test_flatten_attribute_groups.py | 84 ++++++------------- .../handlers/flatten_attribute_groups.py | 5 +- xsdata/codegen/mixins.py | 4 + xsdata/codegen/utils.py | 9 +- 4 files changed, 41 insertions(+), 61 deletions(-) diff --git a/tests/codegen/handlers/test_flatten_attribute_groups.py b/tests/codegen/handlers/test_flatten_attribute_groups.py index 93c46b90d..4b4b6af33 100644 --- a/tests/codegen/handlers/test_flatten_attribute_groups.py +++ b/tests/codegen/handlers/test_flatten_attribute_groups.py @@ -1,10 +1,7 @@ -from unittest import mock - from xsdata.codegen.container import ClassContainer from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.handlers import FlattenAttributeGroups -from xsdata.codegen.models import Attr, Status -from xsdata.codegen.utils import ClassUtils +from xsdata.codegen.models import Status from xsdata.models.config import GeneratorConfig from xsdata.models.enums import Tag from xsdata.utils.testing import AttrFactory, ClassFactory, FactoryTestCase @@ -17,65 +14,38 @@ def setUp(self): container = ClassContainer(config=GeneratorConfig()) self.processor = FlattenAttributeGroups(container=container) - @mock.patch.object(Attr, "is_group", new_callable=mock.PropertyMock) - @mock.patch.object(FlattenAttributeGroups, "process_attribute") - def test_process(self, mock_process_attribute, mock_is_group): - mock_is_group.side_effect = [ - True, - False, - True, - True, - False, - False, + def test_process(self): + group = ClassFactory.create(qname="group", tag=Tag.GROUP) + group.attrs = [ + AttrFactory.reference(name="one", qname="inner_one", forward=True), + AttrFactory.reference(name="two", qname="inner_two", forward=True), ] - target = ClassFactory.elements(2) - - self.processor.process(target) - self.assertEqual(6, mock_is_group.call_count) - - mock_process_attribute.assert_has_calls( - [ - mock.call(target, target.attrs[0]), - mock.call(target, target.attrs[0]), - mock.call(target, target.attrs[1]), - ] + inner_one = ClassFactory.create( + qname="inner_one", + attrs=[ + AttrFactory.reference(qname="group", tag=Tag.GROUP), + ], ) - - @mock.patch.object(ClassUtils, "copy_group_attributes") - def test_process_attribute_with_group(self, mock_copy_group_attributes): - complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE) - group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) - group_attr = AttrFactory.attribute_group(name="bar") - target = ClassFactory.create() - target.attrs.append(group_attr) - - self.processor.container.add(complex_bar) - self.processor.container.add(group_bar) - self.processor.container.add(target) - - self.processor.process_attribute(target, group_attr) - mock_copy_group_attributes.assert_called_once_with( - group_bar, target, group_attr + inner_two = inner_one.clone(qname="inner_two") + inner_one.parent = group + inner_two.parent = group + group.inner.extend([inner_one, inner_two]) + target = ClassFactory.create( + attrs=[ + AttrFactory.reference(qname="group", tag=Tag.GROUP), + ] ) + self.processor.container.extend([group, target]) + self.processor.container.process() - @mock.patch.object(ClassUtils, "copy_group_attributes") - def test_process_attribute_with_attribute_group(self, mock_copy_group_attributes): - complex_bar = ClassFactory.create(qname="bar", tag=Tag.COMPLEX_TYPE) - group_bar = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) - group_attr = AttrFactory.attribute_group(name="bar") - target = ClassFactory.create() - target.attrs.append(group_attr) - - self.processor.container.add(complex_bar) - self.processor.container.add(group_bar) - self.processor.container.add(target) + self.assertEqual(["one", "two"], [x.name for x in target.attrs]) + self.assertEqual(["inner_one", "inner_two"], [x.name for x in target.inner]) - self.processor.process_attribute(target, group_attr) - mock_copy_group_attributes.assert_called_once_with( - group_bar, target, group_attr - ) + for inner in target.inner: + self.assertEqual(["one", "two"], [x.name for x in inner.attrs]) + self.assertEqual(0, len(inner.inner)) - def test_process_attribute_with_circular_reference(self): + def test_process_attribute_with_self_reference(self): group_attr = AttrFactory.attribute_group(name="bar") target = ClassFactory.create(qname="bar", tag=Tag.ATTRIBUTE_GROUP) target.attrs.append(group_attr) diff --git a/xsdata/codegen/handlers/flatten_attribute_groups.py b/xsdata/codegen/handlers/flatten_attribute_groups.py index 703aeaf22..cfbbf213e 100644 --- a/xsdata/codegen/handlers/flatten_attribute_groups.py +++ b/xsdata/codegen/handlers/flatten_attribute_groups.py @@ -1,6 +1,6 @@ from xsdata.codegen.exceptions import CodegenError from xsdata.codegen.mixins import RelativeHandlerInterface -from xsdata.codegen.models import Attr, Class +from xsdata.codegen.models import Attr, Class, Status from xsdata.codegen.utils import ClassUtils @@ -51,4 +51,5 @@ def process_attribute(self, target: Class, attr: Attr): if source is target: ClassUtils.remove_attribute(target, attr) else: - ClassUtils.copy_group_attributes(source, target, attr) + is_circular_ref = source.status == Status.UNGROUPING + ClassUtils.copy_group_attributes(source, target, attr, is_circular_ref) diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index 11dcd5ea3..3f47de3d2 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -24,6 +24,10 @@ def __init__(self, config: GeneratorConfig): def __iter__(self) -> Iterator[Class]: """Yield an iterator for the class map values.""" + @abc.abstractmethod + def process(self): + """Run the processor and filter steps.""" + @abc.abstractmethod def find(self, qname: str, condition: Callable = return_true) -> Optional[Class]: """Find class that matches the given qualified name and condition callable. diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index c623e6bf2..c7438a862 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -115,7 +115,9 @@ def copy_attributes(cls, source: Class, target: Class, extension: Extension): index += 1 @classmethod - def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): + def copy_group_attributes( + cls, source: Class, target: Class, attr: Attr, skip_inner_classes: bool = False + ): """Copy the attrs of the source class to the target class. The attr represents a reference to the source class which is @@ -125,6 +127,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): source: The source class instance target: The target class instance attr: The group attr instance + skip_inner_classes: Whether the attr is circular reference, which + means we can skip copying the inner classes. """ index = target.attrs.index(attr) target.attrs.pop(index) @@ -134,7 +138,8 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr): target.attrs.insert(index, clone) index += 1 - cls.copy_inner_classes(source, target, clone) + if not skip_inner_classes: + cls.copy_inner_classes(source, target, clone) @classmethod def copy_extensions(cls, source: Class, target: Class, extension: Extension):