Skip to content

Commit

Permalink
fix: Validate min < max occurs
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Mar 17, 2024
1 parent 353855c commit dee98a2
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 55 deletions.
8 changes: 2 additions & 6 deletions tests/codegen/parsers/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,11 @@ def test_end_schema(
schema.elements.append(Element())
schema.elements.append(Element())

for el in schema.elements:
self.assertEqual(1, el.min_occurs)
self.assertEqual(1, el.max_occurs)

self.parser.end_schema(schema)

for el in schema.elements:
self.assertIsNone(el.min_occurs)
self.assertIsNone(el.max_occurs)
self.assertEqual(1, el.min_occurs)
self.assertEqual(1, el.max_occurs)

self.parser.end_schema(ComplexType())

Expand Down
10 changes: 10 additions & 0 deletions tests/models/xsd/test_all.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sys
from unittest import TestCase

from xsdata.models.xsd import All


class AllTests(TestCase):
def test_normalize_max_occurs(self):
obj = All(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = All(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = All(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("a", id(obj), 1, 2)]}, obj.get_restrictions())
10 changes: 10 additions & 0 deletions tests/models/xsd/test_any.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import sys
from unittest import TestCase

from xsdata.models.enums import Namespace, NamespaceType
from xsdata.models.xsd import Any


class AnyTests(TestCase):
def test_normalize_max_occurs(self):
obj = Any(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Any(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
self.assertTrue(Any().is_property)

Expand Down
14 changes: 9 additions & 5 deletions tests/models/xsd/test_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@


class ChoiceTests(TestCase):
def test_normalize_max_occurs(self):
obj = Choice(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Choice(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = Choice(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("c", id(obj), 1, 2)]}, obj.get_restrictions())

obj = Choice(max_occurs="unbounded")
self.assertEqual(
{"path": [("c", id(obj), 1, sys.maxsize)]}, obj.get_restrictions()
)
10 changes: 10 additions & 0 deletions tests/models/xsd/test_element.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from unittest import TestCase

from xsdata.codegen.exceptions import CodegenError
Expand All @@ -13,6 +14,15 @@


class ElementTests(TestCase):
def test_normalize_max_occurs(self):
obj = Element(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Element(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
obj = Element()
self.assertTrue(obj)
Expand Down
10 changes: 10 additions & 0 deletions tests/models/xsd/test_group.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sys
from unittest import TestCase

from xsdata.models.xsd import Group


class GroupTests(TestCase):
def test_normalize_max_occurs(self):
obj = Group(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Group(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
obj = Group()
self.assertTrue(obj.is_property)
Expand Down
14 changes: 9 additions & 5 deletions tests/models/xsd/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@


class SequenceTests(TestCase):
def test_normalize_max_occurs(self):
obj = Sequence(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Sequence(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = Sequence(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("s", id(obj), 1, 2)]}, obj.get_restrictions())

obj = Sequence(min_occurs=1, max_occurs="unbounded")
self.assertEqual(
{"path": [("s", id(obj), 1, sys.maxsize)]}, obj.get_restrictions()
)
6 changes: 5 additions & 1 deletion xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def asdict(self, types: Optional[List[Type]] = None) -> Dict:
result["min_occurs"] = self.min_occurs
if self.max_occurs is not None and self.max_occurs < sys.maxsize:
result["max_occurs"] = self.max_occurs
elif self.min_occurs == self.max_occurs == 1 and not self.nillable:
elif (
self.min_occurs == self.max_occurs == 1
and not self.nillable
and not self.tokens
):
result["required"] = True

for key, value in asdict(self).items():
Expand Down
14 changes: 0 additions & 14 deletions xsdata/codegen/parsers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def end_schema(self, obj: T):
self.set_schema_namespaces(obj)
self.add_default_imports(obj)
self.resolve_schemas_locations(obj)
self.reset_element_occurs(obj)

def end_attribute(self, obj: T):
"""End attribute element entrypoint.
Expand Down Expand Up @@ -411,16 +410,3 @@ def add_default_imports(cls, obj: xsd.Schema):
xsi_ns = Namespace.XSI.uri
if xsi_ns in obj.ns_map.values() and xsi_ns not in imp_namespaces:
obj.imports.insert(0, xsd.Import(namespace=xsi_ns))

@classmethod
def reset_element_occurs(cls, obj: xsd.Schema):
"""Reset the root elements occurs restrictions.
The root elements don't get those.
Args:
obj: The xsd schema instance
"""
for element in obj.elements:
element.min_occurs = None
element.max_occurs = None
69 changes: 45 additions & 24 deletions xsdata/models/xsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
)


def validate_max_occurs(min_occurs: int, max_occurs: UnionType[str, int]) -> int:
"""Validate max occurs."""
if max_occurs == "unbounded":
max_occurs = sys.maxsize

assert isinstance(max_occurs, int)

return max(max_occurs, min_occurs)


@dataclass(frozen=True)
class Docstring:
"""Docstring model representation.
Expand Down Expand Up @@ -121,7 +131,7 @@ class AnyAttribute(AnnotationBase):
)

def __post_init__(self):
"""Clean the namespace value."""
"""Post initialization validations."""
self.namespace = " ".join(unique_sequence(self.namespace.split()))

@property
Expand Down Expand Up @@ -351,14 +361,15 @@ class Any(AnnotationBase):

namespace: str = attribute(default="##any")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
process_contents: ProcessType = attribute(
default=ProcessType.STRICT, name="processContents"
)

def __post_init__(self):
"""Clean the namespace value."""
"""Post initialization validations."""
self.namespace = " ".join(unique_sequence(self.namespace.split()))
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def is_property(self) -> bool:
Expand Down Expand Up @@ -397,17 +408,19 @@ class All(AnnotationBase):
"""XSD All model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
any: Array[Any] = array_element(name="any")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("a", id(self), self.min_occurs, max_occurs)],
"path": [("a", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -416,19 +429,21 @@ class Sequence(AnnotationBase):
"""XSD Sequence model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")
choices: Array["Choice"] = array_element(name="choice")
sequences: Array["Sequence"] = array_element(name="sequence")
any: Array["Any"] = array_element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("s", id(self), self.min_occurs, max_occurs)],
"path": [("s", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -437,19 +452,21 @@ class Choice(AnnotationBase):
"""XSD Choice model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")
choices: Array["Choice"] = array_element(name="choice")
sequences: Array[Sequence] = array_element(name="sequence")
any: Array["Any"] = array_element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("c", id(self), self.min_occurs, max_occurs)],
"path": [("c", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -460,11 +477,15 @@ class Group(AnnotationBase):
name: Optional[str] = attribute()
ref: str = attribute(default="")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
all: Optional[All] = element()
choice: Optional[Choice] = element()
sequence: Optional[Sequence] = element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def is_property(self) -> bool:
"""Specify it is qualified to be a class property."""
Expand All @@ -478,10 +499,8 @@ def attr_types(self) -> Iterator[str]:

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("g", id(self), self.min_occurs, max_occurs)],
"path": [("g", id(self), self.min_occurs, self.max_occurs)],
}


Expand Down Expand Up @@ -862,11 +881,15 @@ class Element(AnnotationBase):
uniques: Array[Unique] = array_element(name="unique")
keys: Array[Key] = array_element(name="key")
keyrefs: Array[Keyref] = array_element(name="keyref")
min_occurs: Optional[int] = attribute(default=1, name="minOccurs")
max_occurs: UnionType[None, int, str] = attribute(default=1, name="maxOccurs")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
nillable: bool = attribute(default=False)
abstract: bool = attribute(default=False)

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def bases(self) -> Iterator[str]:
"""Return an iterator of all the base types."""
Expand Down Expand Up @@ -910,11 +933,9 @@ def substitutions(self) -> Array[str]:

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

restrictions = {
"min_occurs": self.min_occurs,
"max_occurs": max_occurs,
"max_occurs": self.max_occurs,
}

if self.simple_type:
Expand Down

0 comments on commit dee98a2

Please sign in to comment.