Skip to content

Commit

Permalink
feat: Filter element unions against fixed attribute values (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra authored Jul 28, 2024
1 parent cce0a16 commit 6eccffb
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 16 deletions.
48 changes: 45 additions & 3 deletions tests/formats/dataclass/parsers/nodes/test_union.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import make_dataclass
from dataclasses import field, make_dataclass
from typing import Union
from unittest import TestCase

Expand Down Expand Up @@ -60,6 +60,41 @@ def test_bind_appends_end_event_when_level_not_zero(self):
self.assertEqual(0, node.level)
self.assertEqual([("end", "bar", "text", "tail")], node.events)

def test_filter_fixed_attrs(self):
a = make_dataclass(
"A",
[("x", int, field(init=False, default=1, metadata={"type": "Attribute"}))],
)
b = make_dataclass(
"A",
[("x", int, field(init=False, default=2, metadata={"type": "Attribute"}))],
)

root = make_dataclass("Root", [("value", Union[a, b, int])])
meta = self.context.build(root)
var = next(meta.find_children("value"))
node = UnionNode(
meta=meta,
var=var,
position=0,
config=self.config,
context=self.context,
attrs={"x": 2},
ns_map={},
)
self.assertEqual([b], node.candidates)

node = UnionNode(
meta=meta,
var=var,
position=0,
config=self.config,
context=self.context,
attrs={},
ns_map={},
)
self.assertEqual([a, b, int], node.candidates)

def test_bind_returns_best_matching_object(self):
item = make_dataclass(
"Item", [("value", str), ("a", int, attribute()), ("b", int, attribute())]
Expand Down Expand Up @@ -95,8 +130,15 @@ def test_bind_returns_best_matching_object(self):
self.assertIsNot(node.attrs, node.events[0][2])
self.assertIs(node.ns_map, node.events[0][3])

node.events.clear()
node.attrs.clear()
node = UnionNode(
meta=meta,
var=var,
position=0,
config=self.config,
context=self.context,
attrs={},
ns_map=ns_map,
)
self.assertTrue(node.bind("item", "1", None, objects))
self.assertEqual(1, objects[-1][1])

Expand Down
5 changes: 4 additions & 1 deletion tests/formats/dataclass/parsers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import Namespace, QNames
from xsdata.models.enums import Namespace, ProcessType, QNames
from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory


Expand Down Expand Up @@ -117,6 +117,9 @@ def test_validate_fixed_value(self):
var = XmlVarFactory.create("fixed", default=lambda: float("nan"))
ParserUtils.validate_fixed_value(meta, var, float("nan"))

var = XmlVarFactory.create("fixed", default=lambda: ProcessType.LAX)
ParserUtils.validate_fixed_value(meta, var, "lax")

def test_parse_var_with_error(self):
meta = XmlMetaFactory.create(clazz=TypeA, qname="foo")
var = XmlVarFactory.create("fixed", default="a")
Expand Down
59 changes: 48 additions & 11 deletions xsdata/formats/dataclass/parsers/nodes/union.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import functools
from contextlib import suppress
from dataclasses import replace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

from xsdata.exceptions import ParserError
from xsdata.formats.dataclass.context import XmlContext
Expand Down Expand Up @@ -40,6 +41,7 @@ class UnionNode(XmlNode):
"context",
"level",
"events",
"candidates",
)

def __init__(
Expand All @@ -60,8 +62,41 @@ def __init__(
self.config = config
self.context = context
self.level = 0
self.candidates = self.filter_candidates()
self.events: List[Tuple[str, str, Any, Any]] = []

def filter_candidates(self) -> List[Type]:
"""Filter union candidates by fixed attributes."""
candidates = list(self.var.types)
fixed_attribute = functools.partial(
self.filter_fixed_attrs, parent_ns=target_uri(self.var.qname)
)

return list(filter(fixed_attribute, candidates))

def filter_fixed_attrs(self, candidate: Type, parent_ns: str) -> bool:
"""Return whether the node attrs are incompatible with fixed attrs.
Args:
candidate: The candidate type
parent_ns: The parent namespace
"""
if not self.context.class_type.is_model(candidate):
return not self.attrs

meta = self.context.build(candidate, parent_ns=parent_ns)
for qname, value in self.attrs.items():
var = meta.find_attribute(qname)
if not var or var.init:
continue

try:
ParserUtils.validate_fixed_value(meta, var, value)
except ParserError:
return False

return True

def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode:
"""Record the event for the child element.
Expand Down Expand Up @@ -120,29 +155,31 @@ def bind(
parent_namespace = target_uri(qname)
config = replace(self.config, fail_on_converter_warnings=True)

for clazz in self.var.types:
candidate = None
for candidate in self.candidates:
result = None
with suppress(Exception):
if self.context.class_type.is_model(clazz):
self.context.build(clazz, parent_ns=parent_namespace)
if self.context.class_type.is_model(candidate):
self.context.build(candidate, parent_ns=parent_namespace)
parser = NodeParser(
config=config, context=self.context, handler=EventsHandler
config=config,
context=self.context,
handler=EventsHandler,
)
candidate = parser.parse(self.events, clazz)
result = parser.parse(self.events, candidate)
else:
candidate = ParserUtils.parse_var(
result = ParserUtils.parse_var(
meta=self.meta,
var=self.var,
config=config,
value=text,
types=[clazz],
types=[candidate],
ns_map=self.ns_map,
)

score = self.context.class_type.score_object(candidate)
score = self.context.class_type.score_object(result)
if score > max_score:
max_score = score
obj = candidate
obj = result

if obj:
objects.append((self.var.qname, obj))
Expand Down
5 changes: 4 additions & 1 deletion xsdata/formats/dataclass/parsers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def validate_fixed_value(cls, meta: XmlMeta, var: XmlVar, value: Any):
Special cases
- float nans are never equal in python
- strings with whitespaces, need trimming
- comparing raw str values
"""

default_value = var.default() if callable(var.default) else var.default
Expand All @@ -244,6 +244,9 @@ def validate_fixed_value(cls, meta: XmlMeta, var: XmlVar, value: Any):
):
return

if isinstance(value, str) and not isinstance(default_value, str):
default_value = converter.serialize(default_value, format=var.format)

if default_value != value:
raise ParserError(
f"Fixed value mismatch {meta.qname}:{var.qname}, `{default_value} != {value}`"
Expand Down

0 comments on commit 6eccffb

Please sign in to comment.