From 608200e6da13809119c734d644efd2ca5f5627bc Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sun, 21 May 2023 11:57:13 +0300 Subject: [PATCH] Fix flattening restriction base classes with unorderd sequences --- .../handlers/test_flatten_class_extensions.py | 48 +++++++++++-------- tests/codegen/mappers/test_schema.py | 26 +++++----- tests/formats/dataclass/test_filters.py | 2 +- .../handlers/flatten_class_extensions.py | 35 ++++++++++---- xsdata/codegen/mappers/dtd.py | 4 +- xsdata/codegen/mappers/schema.py | 10 ++-- xsdata/codegen/models.py | 2 + xsdata/utils/testing.py | 8 +++- 8 files changed, 89 insertions(+), 46 deletions(-) diff --git a/tests/codegen/handlers/test_flatten_class_extensions.py b/tests/codegen/handlers/test_flatten_class_extensions.py index 903c05aab..4b2f4b0fe 100644 --- a/tests/codegen/handlers/test_flatten_class_extensions.py +++ b/tests/codegen/handlers/test_flatten_class_extensions.py @@ -293,7 +293,7 @@ def test_process_complex_extension_removes_extension( self.assertEqual(0, len(target.extensions)) self.assertEqual(1, len(target.attrs)) - mock_should_remove_extension.assert_called_once_with(source, target) + mock_should_remove_extension.assert_called_once_with(source, target, extension) self.assertEqual(0, mock_copy_attributes.call_count) self.assertEqual(0, extension.type.reference) @@ -343,23 +343,46 @@ def test_find_dependency(self): def test_should_remove_extension(self): source = ClassFactory.create() target = ClassFactory.create() + extension = ExtensionFactory.create(tag=Tag.EXTENSION) + callback = self.processor.should_remove_extension # source is target - self.assertTrue(self.processor.should_remove_extension(source, source)) - self.assertFalse(self.processor.should_remove_extension(source, target)) + self.assertTrue(callback(source, source, extension)) + self.assertFalse(callback(source, target, extension)) # Source is parent class source.inner.append(target) - self.assertTrue(self.processor.should_remove_extension(target, target)) + self.assertTrue(callback(target, target, extension)) # MRO Violation source.inner.clear() target.extensions.append(ExtensionFactory.reference("foo")) target.extensions.append(ExtensionFactory.reference("bar")) - self.assertFalse(self.processor.should_remove_extension(source, target)) + self.assertFalse(callback(source, target, extension)) source.extensions.append(ExtensionFactory.reference("bar")) - self.assertTrue(self.processor.should_remove_extension(source, target)) + self.assertTrue(callback(source, target, extension)) + + # Sequential violation + extension.tag = Tag.RESTRICTION + source = ClassFactory.elements(4) + target = source.clone() + self.assertFalse(callback(source, target, extension)) + + for attr in target.attrs: + attr.restrictions.sequence = 1 + + target.attrs[3].restrictions.max_occurs = 0 + + self.assertFalse(callback(source, target, extension)) + + target.attrs = [ + target.attrs[1], + target.attrs[0], + target.attrs[2], + target.attrs[3], + ] + self.assertTrue(callback(source, target, extension)) def test_should_flatten_extension(self): source = ClassFactory.create() @@ -388,19 +411,6 @@ def test_should_flatten_extension(self): target = ClassFactory.elements(1) self.assertTrue(self.processor.should_flatten_extension(source, target)) - # Sequential violation - source = ClassFactory.elements(3) - target = source.clone() - self.assertFalse(self.processor.should_flatten_extension(source, target)) - - for attr in target.attrs: - attr.restrictions.sequence = 1 - - self.assertFalse(self.processor.should_flatten_extension(source, target)) - - target.attrs = [target.attrs[1], target.attrs[0], target.attrs[2]] - self.assertTrue(self.processor.should_flatten_extension(source, target)) - def test_replace_attributes_type(self): extension = ExtensionFactory.create() target = ClassFactory.elements(2) diff --git a/tests/codegen/mappers/test_schema.py b/tests/codegen/mappers/test_schema.py index b89ea77b8..439b22ee2 100644 --- a/tests/codegen/mappers/test_schema.py +++ b/tests/codegen/mappers/test_schema.py @@ -180,13 +180,14 @@ def test_build_class_extensions(self, mock_children_extensions): bar_type = AttrTypeFactory.create(qname="bar") foo_type = AttrTypeFactory.create(qname="foo") - bar = ExtensionFactory.create(bar_type) - double = ExtensionFactory.create(bar_type) - foo = ExtensionFactory.create(foo_type) + bar = ExtensionFactory.create(bar_type, tag=Tag.RESTRICTION) + double = ExtensionFactory.create(bar_type, tag=Tag.RESTRICTION) + foo = ExtensionFactory.create(foo_type, tag=Tag.EXTENSION) mock_children_extensions.return_value = [bar, double, foo] self_ext = ExtensionFactory.reference( qname="{xsdata}something", + tag=Tag.ELEMENT, restrictions=Restrictions(min_occurs=1, max_occurs=1), ) @@ -253,15 +254,16 @@ def test_children_extensions(self): item = ClassFactory.create(ns_map={"bk": "book"}) children = SchemaMapper.children_extensions(complex_type, item) - expected = list( - map( - ExtensionFactory.create, - [ - AttrTypeFactory.create(qname=build_qname("book", "b")), - AttrTypeFactory.create(qname=build_qname("book", "c")), - ], - ) - ) + expected = [ + ExtensionFactory.create( + AttrTypeFactory.create(qname=build_qname("book", "b")), + tag=Tag.RESTRICTION, + ), + ExtensionFactory.create( + AttrTypeFactory.create(qname=build_qname("book", "c")), + tag=Tag.EXTENSION, + ), + ] self.assertIsInstance(children, GeneratorType) self.assertEqual(expected, list(children)) diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index e4ea3b91e..e923fb7d0 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -148,7 +148,7 @@ def test_field_definition(self, mock_field_default_value): def test_field_definition_with_prohibited_attr(self): attr = AttrFactory.native(DataType.INT) attr.restrictions.max_occurs = 0 - attr.default = "foo" + attr.default = "1" result = self.filters.field_definition(attr, {}, None, ["Root"]) expected = ( diff --git a/xsdata/codegen/handlers/flatten_class_extensions.py b/xsdata/codegen/handlers/flatten_class_extensions.py index beb37e9d3..4acc7fda4 100644 --- a/xsdata/codegen/handlers/flatten_class_extensions.py +++ b/xsdata/codegen/handlers/flatten_class_extensions.py @@ -151,7 +151,7 @@ def process_complex_extension(cls, source: Class, target: Class, ext: Extension) extension completely, copy all source attributes to the target class or leave the extension alone. """ - if cls.should_remove_extension(source, target): + if cls.should_remove_extension(source, target, ext): target.extensions.remove(ext) elif cls.should_flatten_extension(source, target): ClassUtils.copy_attributes(source, target, ext) @@ -177,7 +177,9 @@ def find_dependency(self, attr_type: AttrType) -> Optional[Class]: return None @classmethod - def should_remove_extension(cls, source: Class, target: Class) -> bool: + def should_remove_extension( + cls, source: Class, target: Class, ext: Extension + ) -> bool: """ Return whether the extension should be removed because of some violation. @@ -185,10 +187,15 @@ def should_remove_extension(cls, source: Class, target: Class) -> bool: Violations: - Circular Reference - Forward Reference + - Unordered sequences - MRO Violation A(B), C(B) and extensions includes A, B, C """ # Circular or Forward reference - if source is target or target in source.inner: + if ( + source is target + or target in source.inner + or cls.have_unordered_sequences(source, target, ext) + ): return True # MRO Violation @@ -212,31 +219,41 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool: source.is_simple_type or target.has_suffix_attr or (source.has_suffix_attr and target.attrs) - or not cls.validate_sequence_order(source, target) ): return True return False @classmethod - def validate_sequence_order(cls, source: Class, target: Class) -> bool: + def have_unordered_sequences( + cls, source: Class, target: Class, ext: Extension + ) -> bool: """ Validate sequence attributes are in the same order in the parent class. Dataclasses fields ordering follows the python mro pattern, the - parent fields are always first and they are updated if the + parent fields are always first, and they are updated if the subclass is overriding any of them but the overall ordering doesn't change! + + @todo This needs a complete rewrite and most likely it needs to + @todo move way down in the process chain. """ + + if ext.tag == Tag.EXTENSION or source.extensions: + return False + sequence = [ - attr.name for attr in target.attrs if attr.restrictions.sequence is not None + attr.name + for attr in target.attrs + if attr.restrictions.sequence is not None and not attr.is_prohibited ] if len(sequence) > 1: compare = [attr.name for attr in source.attrs if attr.name in sequence] if compare and compare != sequence: - return False + return True - return True + return False @classmethod def replace_attributes_type(cls, target: Class, extension: Extension): diff --git a/xsdata/codegen/mappers/dtd.py b/xsdata/codegen/mappers/dtd.py index 0d3c8789a..7a2370ea4 100644 --- a/xsdata/codegen/mappers/dtd.py +++ b/xsdata/codegen/mappers/dtd.py @@ -118,7 +118,9 @@ def build_mixed_content(cls, target: Class, content: DtdContent): @classmethod def build_extension(cls, target: Class, data_type: DataType): ext_type = AttrType(qname=str(data_type), native=True) - extension = Extension(type=ext_type, restrictions=Restrictions()) + extension = Extension( + tag=Tag.EXTENSION, type=ext_type, restrictions=Restrictions() + ) target.extensions.append(extension) @classmethod diff --git a/xsdata/codegen/mappers/schema.py b/xsdata/codegen/mappers/schema.py index e891d9bb5..d1376be5f 100644 --- a/xsdata/codegen/mappers/schema.py +++ b/xsdata/codegen/mappers/schema.py @@ -114,7 +114,8 @@ def build_class_extensions(cls, obj: ElementBase, target: Class): restrictions = obj.get_restrictions() extensions = [ - cls.build_class_extension(target, base, restrictions) for base in obj.bases + cls.build_class_extension(obj.class_name, target, base, restrictions) + for base in obj.bases ] extensions.extend(cls.children_extensions(obj, target)) target.extensions = collections.unique_sequence(extensions) @@ -198,17 +199,20 @@ def children_extensions( continue for ext in child.bases: - yield cls.build_class_extension(target, ext, child.get_restrictions()) + yield cls.build_class_extension( + child.class_name, target, ext, child.get_restrictions() + ) yield from cls.children_extensions(child, target) @classmethod def build_class_extension( - cls, target: Class, name: str, restrictions: Dict + cls, tag: str, target: Class, name: str, restrictions: Dict ) -> Extension: """Create an extension for the target class.""" return Extension( type=cls.build_data_type(target, name), + tag=tag, restrictions=Restrictions(**restrictions), ) diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 663c70218..76d7b1b7b 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -389,10 +389,12 @@ class Extension: """ Model representation of a dataclass base class. + :param tag: :param type: :param restrictions: """ + tag: str type: AttrType restrictions: Restrictions = field(hash=False) diff --git a/xsdata/utils/testing.py b/xsdata/utils/testing.py index 26cd753a7..1f6db409e 100644 --- a/xsdata/utils/testing.py +++ b/xsdata/utils/testing.py @@ -196,24 +196,30 @@ def service(cls, attributes: int, **kwargs: Any) -> Class: class ExtensionFactory(Factory): counter = 65 + tags = [Tag.ELEMENT, Tag.EXTENSION, Tag.RESTRICTION] @classmethod def create( cls, attr_type: Optional[AttrType] = None, restrictions: Optional[Restrictions] = None, + tag: Optional[str] = None, **kwargs: Any, ) -> Extension: return Extension( + tag=tag or random.choice(cls.tags), type=attr_type or AttrTypeFactory.create(**kwargs), restrictions=restrictions or Restrictions(), ) @classmethod def reference(cls, qname: str, **kwargs: Any) -> Extension: + tag = kwargs.pop("tag", None) restrictions = kwargs.pop("restrictions", None) return cls.create( - AttrTypeFactory.create(qname=qname, **kwargs), restrictions=restrictions + AttrTypeFactory.create(qname=qname, **kwargs), + tag=tag, + restrictions=restrictions, ) @classmethod