Skip to content

Commit

Permalink
Fix flattening restriction base classes with unorderd sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed May 21, 2023
1 parent 5a7d851 commit 608200e
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 46 deletions.
48 changes: 29 additions & 19 deletions tests/codegen/handlers/test_flatten_class_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_process_complex_extension_removes_extension(
self.assertEqual(0, len(target.extensions))
self.assertEqual(1, len(target.attrs))

mock_should_remove_extension.assert_called_once_with(source, target)
mock_should_remove_extension.assert_called_once_with(source, target, extension)
self.assertEqual(0, mock_copy_attributes.call_count)
self.assertEqual(0, extension.type.reference)

Expand Down Expand Up @@ -343,23 +343,46 @@ def test_find_dependency(self):
def test_should_remove_extension(self):
source = ClassFactory.create()
target = ClassFactory.create()
extension = ExtensionFactory.create(tag=Tag.EXTENSION)
callback = self.processor.should_remove_extension

# source is target
self.assertTrue(self.processor.should_remove_extension(source, source))
self.assertFalse(self.processor.should_remove_extension(source, target))
self.assertTrue(callback(source, source, extension))
self.assertFalse(callback(source, target, extension))

# Source is parent class
source.inner.append(target)
self.assertTrue(self.processor.should_remove_extension(target, target))
self.assertTrue(callback(target, target, extension))

# MRO Violation
source.inner.clear()
target.extensions.append(ExtensionFactory.reference("foo"))
target.extensions.append(ExtensionFactory.reference("bar"))
self.assertFalse(self.processor.should_remove_extension(source, target))
self.assertFalse(callback(source, target, extension))

source.extensions.append(ExtensionFactory.reference("bar"))
self.assertTrue(self.processor.should_remove_extension(source, target))
self.assertTrue(callback(source, target, extension))

# Sequential violation
extension.tag = Tag.RESTRICTION
source = ClassFactory.elements(4)
target = source.clone()
self.assertFalse(callback(source, target, extension))

for attr in target.attrs:
attr.restrictions.sequence = 1

target.attrs[3].restrictions.max_occurs = 0

self.assertFalse(callback(source, target, extension))

target.attrs = [
target.attrs[1],
target.attrs[0],
target.attrs[2],
target.attrs[3],
]
self.assertTrue(callback(source, target, extension))

def test_should_flatten_extension(self):
source = ClassFactory.create()
Expand Down Expand Up @@ -388,19 +411,6 @@ def test_should_flatten_extension(self):
target = ClassFactory.elements(1)
self.assertTrue(self.processor.should_flatten_extension(source, target))

# Sequential violation
source = ClassFactory.elements(3)
target = source.clone()
self.assertFalse(self.processor.should_flatten_extension(source, target))

for attr in target.attrs:
attr.restrictions.sequence = 1

self.assertFalse(self.processor.should_flatten_extension(source, target))

target.attrs = [target.attrs[1], target.attrs[0], target.attrs[2]]
self.assertTrue(self.processor.should_flatten_extension(source, target))

def test_replace_attributes_type(self):
extension = ExtensionFactory.create()
target = ClassFactory.elements(2)
Expand Down
26 changes: 14 additions & 12 deletions tests/codegen/mappers/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,14 @@ def test_build_class_extensions(self, mock_children_extensions):
bar_type = AttrTypeFactory.create(qname="bar")
foo_type = AttrTypeFactory.create(qname="foo")

bar = ExtensionFactory.create(bar_type)
double = ExtensionFactory.create(bar_type)
foo = ExtensionFactory.create(foo_type)
bar = ExtensionFactory.create(bar_type, tag=Tag.RESTRICTION)
double = ExtensionFactory.create(bar_type, tag=Tag.RESTRICTION)
foo = ExtensionFactory.create(foo_type, tag=Tag.EXTENSION)

mock_children_extensions.return_value = [bar, double, foo]
self_ext = ExtensionFactory.reference(
qname="{xsdata}something",
tag=Tag.ELEMENT,
restrictions=Restrictions(min_occurs=1, max_occurs=1),
)

Expand Down Expand Up @@ -253,15 +254,16 @@ def test_children_extensions(self):

item = ClassFactory.create(ns_map={"bk": "book"})
children = SchemaMapper.children_extensions(complex_type, item)
expected = list(
map(
ExtensionFactory.create,
[
AttrTypeFactory.create(qname=build_qname("book", "b")),
AttrTypeFactory.create(qname=build_qname("book", "c")),
],
)
)
expected = [
ExtensionFactory.create(
AttrTypeFactory.create(qname=build_qname("book", "b")),
tag=Tag.RESTRICTION,
),
ExtensionFactory.create(
AttrTypeFactory.create(qname=build_qname("book", "c")),
tag=Tag.EXTENSION,
),
]

self.assertIsInstance(children, GeneratorType)
self.assertEqual(expected, list(children))
Expand Down
2 changes: 1 addition & 1 deletion tests/formats/dataclass/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_field_definition(self, mock_field_default_value):
def test_field_definition_with_prohibited_attr(self):
attr = AttrFactory.native(DataType.INT)
attr.restrictions.max_occurs = 0
attr.default = "foo"
attr.default = "1"

result = self.filters.field_definition(attr, {}, None, ["Root"])
expected = (
Expand Down
35 changes: 26 additions & 9 deletions xsdata/codegen/handlers/flatten_class_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def process_complex_extension(cls, source: Class, target: Class, ext: Extension)
extension completely, copy all source attributes to the target
class or leave the extension alone.
"""
if cls.should_remove_extension(source, target):
if cls.should_remove_extension(source, target, ext):
target.extensions.remove(ext)
elif cls.should_flatten_extension(source, target):
ClassUtils.copy_attributes(source, target, ext)
Expand All @@ -177,18 +177,25 @@ def find_dependency(self, attr_type: AttrType) -> Optional[Class]:
return None

@classmethod
def should_remove_extension(cls, source: Class, target: Class) -> bool:
def should_remove_extension(
cls, source: Class, target: Class, ext: Extension
) -> bool:
"""
Return whether the extension should be removed because of some
violation.
Violations:
- Circular Reference
- Forward Reference
- Unordered sequences
- MRO Violation A(B), C(B) and extensions includes A, B, C
"""
# Circular or Forward reference
if source is target or target in source.inner:
if (
source is target
or target in source.inner
or cls.have_unordered_sequences(source, target, ext)
):
return True

# MRO Violation
Expand All @@ -212,31 +219,41 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool:
source.is_simple_type
or target.has_suffix_attr
or (source.has_suffix_attr and target.attrs)
or not cls.validate_sequence_order(source, target)
):
return True

return False

@classmethod
def validate_sequence_order(cls, source: Class, target: Class) -> bool:
def have_unordered_sequences(
cls, source: Class, target: Class, ext: Extension
) -> bool:
"""
Validate sequence attributes are in the same order in the parent class.
Dataclasses fields ordering follows the python mro pattern, the
parent fields are always first and they are updated if the
parent fields are always first, and they are updated if the
subclass is overriding any of them but the overall ordering
doesn't change!
@todo This needs a complete rewrite and most likely it needs to
@todo move way down in the process chain.
"""

if ext.tag == Tag.EXTENSION or source.extensions:
return False

sequence = [
attr.name for attr in target.attrs if attr.restrictions.sequence is not None
attr.name
for attr in target.attrs
if attr.restrictions.sequence is not None and not attr.is_prohibited
]
if len(sequence) > 1:
compare = [attr.name for attr in source.attrs if attr.name in sequence]
if compare and compare != sequence:
return False
return True

return True
return False

@classmethod
def replace_attributes_type(cls, target: Class, extension: Extension):
Expand Down
4 changes: 3 additions & 1 deletion xsdata/codegen/mappers/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def build_mixed_content(cls, target: Class, content: DtdContent):
@classmethod
def build_extension(cls, target: Class, data_type: DataType):
ext_type = AttrType(qname=str(data_type), native=True)
extension = Extension(type=ext_type, restrictions=Restrictions())
extension = Extension(
tag=Tag.EXTENSION, type=ext_type, restrictions=Restrictions()
)
target.extensions.append(extension)

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions xsdata/codegen/mappers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def build_class_extensions(cls, obj: ElementBase, target: Class):

restrictions = obj.get_restrictions()
extensions = [
cls.build_class_extension(target, base, restrictions) for base in obj.bases
cls.build_class_extension(obj.class_name, target, base, restrictions)
for base in obj.bases
]
extensions.extend(cls.children_extensions(obj, target))
target.extensions = collections.unique_sequence(extensions)
Expand Down Expand Up @@ -198,17 +199,20 @@ def children_extensions(
continue

for ext in child.bases:
yield cls.build_class_extension(target, ext, child.get_restrictions())
yield cls.build_class_extension(
child.class_name, target, ext, child.get_restrictions()
)

yield from cls.children_extensions(child, target)

@classmethod
def build_class_extension(
cls, target: Class, name: str, restrictions: Dict
cls, tag: str, target: Class, name: str, restrictions: Dict
) -> Extension:
"""Create an extension for the target class."""
return Extension(
type=cls.build_data_type(target, name),
tag=tag,
restrictions=Restrictions(**restrictions),
)

Expand Down
2 changes: 2 additions & 0 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,12 @@ class Extension:
"""
Model representation of a dataclass base class.
:param tag:
:param type:
:param restrictions:
"""

tag: str
type: AttrType
restrictions: Restrictions = field(hash=False)

Expand Down
8 changes: 7 additions & 1 deletion xsdata/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,30 @@ def service(cls, attributes: int, **kwargs: Any) -> Class:

class ExtensionFactory(Factory):
counter = 65
tags = [Tag.ELEMENT, Tag.EXTENSION, Tag.RESTRICTION]

@classmethod
def create(
cls,
attr_type: Optional[AttrType] = None,
restrictions: Optional[Restrictions] = None,
tag: Optional[str] = None,
**kwargs: Any,
) -> Extension:
return Extension(
tag=tag or random.choice(cls.tags),
type=attr_type or AttrTypeFactory.create(**kwargs),
restrictions=restrictions or Restrictions(),
)

@classmethod
def reference(cls, qname: str, **kwargs: Any) -> Extension:
tag = kwargs.pop("tag", None)
restrictions = kwargs.pop("restrictions", None)
return cls.create(
AttrTypeFactory.create(qname=qname, **kwargs), restrictions=restrictions
AttrTypeFactory.create(qname=qname, **kwargs),
tag=tag,
restrictions=restrictions,
)

@classmethod
Expand Down

0 comments on commit 608200e

Please sign in to comment.