From 6eccffb886ca537c4d1b1fa0f59cac637311e43f Mon Sep 17 00:00:00 2001 From: Chris Tsou Date: Sun, 28 Jul 2024 08:49:35 +0300 Subject: [PATCH] feat: Filter element unions against fixed attribute values (#1066) --- .../dataclass/parsers/nodes/test_union.py | 48 ++++++++++++++- tests/formats/dataclass/parsers/test_utils.py | 5 +- .../formats/dataclass/parsers/nodes/union.py | 59 +++++++++++++++---- xsdata/formats/dataclass/parsers/utils.py | 5 +- 4 files changed, 101 insertions(+), 16 deletions(-) diff --git a/tests/formats/dataclass/parsers/nodes/test_union.py b/tests/formats/dataclass/parsers/nodes/test_union.py index 15a44b51b..67f533b94 100644 --- a/tests/formats/dataclass/parsers/nodes/test_union.py +++ b/tests/formats/dataclass/parsers/nodes/test_union.py @@ -1,4 +1,4 @@ -from dataclasses import make_dataclass +from dataclasses import field, make_dataclass from typing import Union from unittest import TestCase @@ -60,6 +60,41 @@ def test_bind_appends_end_event_when_level_not_zero(self): self.assertEqual(0, node.level) self.assertEqual([("end", "bar", "text", "tail")], node.events) + def test_filter_fixed_attrs(self): + a = make_dataclass( + "A", + [("x", int, field(init=False, default=1, metadata={"type": "Attribute"}))], + ) + b = make_dataclass( + "A", + [("x", int, field(init=False, default=2, metadata={"type": "Attribute"}))], + ) + + root = make_dataclass("Root", [("value", Union[a, b, int])]) + meta = self.context.build(root) + var = next(meta.find_children("value")) + node = UnionNode( + meta=meta, + var=var, + position=0, + config=self.config, + context=self.context, + attrs={"x": 2}, + ns_map={}, + ) + self.assertEqual([b], node.candidates) + + node = UnionNode( + meta=meta, + var=var, + position=0, + config=self.config, + context=self.context, + attrs={}, + ns_map={}, + ) + self.assertEqual([a, b, int], node.candidates) + def test_bind_returns_best_matching_object(self): item = make_dataclass( "Item", [("value", str), ("a", int, attribute()), ("b", 
int, attribute())] @@ -95,8 +130,15 @@ def test_bind_returns_best_matching_object(self): self.assertIsNot(node.attrs, node.events[0][2]) self.assertIs(node.ns_map, node.events[0][3]) - node.events.clear() - node.attrs.clear() + node = UnionNode( + meta=meta, + var=var, + position=0, + config=self.config, + context=self.context, + attrs={}, + ns_map=ns_map, + ) self.assertTrue(node.bind("item", "1", None, objects)) self.assertEqual(1, objects[-1][1]) diff --git a/tests/formats/dataclass/parsers/test_utils.py b/tests/formats/dataclass/parsers/test_utils.py index 618444891..4dfaca43b 100644 --- a/tests/formats/dataclass/parsers/test_utils.py +++ b/tests/formats/dataclass/parsers/test_utils.py @@ -7,7 +7,7 @@ from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.parsers.utils import ParserUtils -from xsdata.models.enums import Namespace, QNames +from xsdata.models.enums import Namespace, ProcessType, QNames from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory @@ -117,6 +117,9 @@ def test_validate_fixed_value(self): var = XmlVarFactory.create("fixed", default=lambda: float("nan")) ParserUtils.validate_fixed_value(meta, var, float("nan")) + var = XmlVarFactory.create("fixed", default=lambda: ProcessType.LAX) + ParserUtils.validate_fixed_value(meta, var, "lax") + def test_parse_var_with_error(self): meta = XmlMetaFactory.create(clazz=TypeA, qname="foo") var = XmlVarFactory.create("fixed", default="a") diff --git a/xsdata/formats/dataclass/parsers/nodes/union.py b/xsdata/formats/dataclass/parsers/nodes/union.py index 68377336b..6839ad581 100644 --- a/xsdata/formats/dataclass/parsers/nodes/union.py +++ b/xsdata/formats/dataclass/parsers/nodes/union.py @@ -1,7 +1,8 @@ import copy +import functools from contextlib import suppress from dataclasses import replace -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, 
Optional, Tuple, Type from xsdata.exceptions import ParserError from xsdata.formats.dataclass.context import XmlContext @@ -40,6 +41,7 @@ class UnionNode(XmlNode): "context", "level", "events", + "candidates", ) def __init__( @@ -60,8 +62,41 @@ def __init__( self.config = config self.context = context self.level = 0 + self.candidates = self.filter_candidates() self.events: List[Tuple[str, str, Any, Any]] = [] + def filter_candidates(self) -> List[Type]: + """Filter union candidates by fixed attributes.""" + candidates = list(self.var.types) + fixed_attribute = functools.partial( + self.filter_fixed_attrs, parent_ns=target_uri(self.var.qname) + ) + + return list(filter(fixed_attribute, candidates)) + + def filter_fixed_attrs(self, candidate: Type, parent_ns: str) -> bool: + """Return whether the node attrs are compatible with fixed attrs. + + Args: + candidate: The candidate type + parent_ns: The parent namespace + """ + if not self.context.class_type.is_model(candidate): + return not self.attrs + + meta = self.context.build(candidate, parent_ns=parent_ns) + for qname, value in self.attrs.items(): + var = meta.find_attribute(qname) + if not var or var.init: + continue + + try: + ParserUtils.validate_fixed_value(meta, var, value) + except ParserError: + return False + + return True + def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: """Record the event for the child element. 
@@ -120,29 +155,31 @@ def bind( parent_namespace = target_uri(qname) config = replace(self.config, fail_on_converter_warnings=True) - for clazz in self.var.types: - candidate = None + for candidate in self.candidates: + result = None with suppress(Exception): - if self.context.class_type.is_model(clazz): - self.context.build(clazz, parent_ns=parent_namespace) + if self.context.class_type.is_model(candidate): + self.context.build(candidate, parent_ns=parent_namespace) parser = NodeParser( - config=config, context=self.context, handler=EventsHandler + config=config, + context=self.context, + handler=EventsHandler, ) - candidate = parser.parse(self.events, clazz) + result = parser.parse(self.events, candidate) else: - candidate = ParserUtils.parse_var( + result = ParserUtils.parse_var( meta=self.meta, var=self.var, config=config, value=text, - types=[clazz], + types=[candidate], ns_map=self.ns_map, ) - score = self.context.class_type.score_object(candidate) + score = self.context.class_type.score_object(result) if score > max_score: max_score = score - obj = candidate + obj = result if obj: objects.append((self.var.qname, obj)) diff --git a/xsdata/formats/dataclass/parsers/utils.py b/xsdata/formats/dataclass/parsers/utils.py index 1a7742100..1c3b1ba41 100644 --- a/xsdata/formats/dataclass/parsers/utils.py +++ b/xsdata/formats/dataclass/parsers/utils.py @@ -228,7 +228,7 @@ def validate_fixed_value(cls, meta: XmlMeta, var: XmlVar, value: Any): Special cases - float nans are never equal in python - strings with whitespaces, need trimming - + - comparing raw str values """ default_value = var.default() if callable(var.default) else var.default @@ -244,6 +244,9 @@ def validate_fixed_value(cls, meta: XmlMeta, var: XmlVar, value: Any): ): return + if isinstance(value, str) and not isinstance(default_value, str): + default_value = converter.serialize(default_value, format=var.format) + if default_value != value: raise ParserError( f"Fixed value mismatch 
{meta.qname}:{var.qname}, `{default_value} != {value}`"