Skip to content

Commit

Permalink
fix: Avoid using not-threadsafe arnings.catch_warning
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed May 19, 2024
1 parent c544fbd commit 346afc1
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 129 deletions.
1 change: 0 additions & 1 deletion tests/codegen/handlers/test_disambiguate_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def test_disambiguate_choice_with_circular_ref(self):

def test_find_ambiguous_choices_ignore_wildcards(self):
"""Wildcards are merged."""

attr = AttrFactory.create()
attr.choices.append(AttrFactory.any())
attr.choices.append(AttrFactory.any())
Expand Down
16 changes: 9 additions & 7 deletions tests/formats/dataclass/parsers/nodes/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.fixtures.artists import Artist
from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.utils.testing import XmlMetaFactory, XmlVarFactory
Expand All @@ -12,6 +13,7 @@ class PrimitiveNodeTests(TestCase):
def setUp(self):
super().setUp()
self.meta = XmlMetaFactory.create(clazz=Artist)
self.config = ParserConfig()

@mock.patch.object(ParserUtils, "parse_var")
def test_bind(self, mock_parse_var):
Expand All @@ -20,22 +22,22 @@ def test_bind(self, mock_parse_var):
xml_type=XmlType.TEXT, name="foo", types=(int,), format="Nope"
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "Impossible", objects))
self.assertEqual(("foo", 13), objects[-1])

mock_parse_var.assert_called_once_with(
meta=self.meta, var=var, value="13", ns_map=ns_map
meta=self.meta, var=var, config=self.config, value="13", ns_map=ns_map
)

def test_bind_nillable_content(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", types=(str,), nillable=False
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -53,7 +55,7 @@ def test_bind_nillable_bytes_content(self):
nillable=False,
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -66,7 +68,7 @@ def test_bind_nillable_bytes_content(self):
def test_bind_mixed_with_tail_content(self):
self.meta.mixed_content = True
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "tail", objects))
Expand All @@ -76,15 +78,15 @@ def test_bind_mixed_with_tail_content(self):
def test_bind_mixed_without_tail_content(self):
self.meta.mixed_content = True
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "", objects))
self.assertEqual(13, objects[-1][1])

def test_child(self):
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo")
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)

with self.assertRaises(XmlContextError):
node.child("foo", {}, {}, 0)
20 changes: 15 additions & 5 deletions tests/formats/dataclass/parsers/nodes/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.fixtures.artists import Artist
from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.nodes import StandardNode
from xsdata.models.enums import DataType
from xsdata.utils.testing import XmlMetaFactory, XmlVarFactory
Expand All @@ -13,34 +14,41 @@ def setUp(self):
super().setUp()
self.meta = XmlMetaFactory.create(clazz=Artist)
self.var = XmlVarFactory.create()
self.config = ParserConfig()

def test_bind_simple(self):
datatype = DataType.INT
node = StandardNode(self.meta, self.var, datatype, {}, False, False)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, False
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", 13), objects[-1])

def test_bind_derived(self):
datatype = DataType.INT
node = StandardNode(self.meta, self.var, datatype, {}, False, DerivedElement)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, DerivedElement
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", DerivedElement("a", 13)), objects[-1])

def test_bind_wrapper_type(self):
datatype = DataType.HEX_BINARY
node = StandardNode(self.meta, self.var, datatype, {}, False, DerivedElement)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, DerivedElement
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", DerivedElement(qname="a", value=b"\x13")), objects[-1])

def test_bind_nillable(self):
datatype = DataType.STRING
node = StandardNode(self.meta, self.var, datatype, {}, True, None)
node = StandardNode(self.meta, self.var, datatype, {}, self.config, True, None)
objects = []

self.assertTrue(node.bind("a", None, None, objects))
Expand All @@ -52,7 +60,9 @@ def test_bind_nillable(self):

def test_child(self):
datatype = DataType.STRING
node = StandardNode(self.meta, self.var, datatype, {}, False, False)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, False
)

with self.assertRaises(XmlContextError):
node.child("foo", {}, {}, 0)
2 changes: 1 addition & 1 deletion tests/formats/dataclass/parsers/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_end(self, mock_assemble):
objects = [("q", "result")]
queue = []
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo")
queue.append(PrimitiveNode(var, {}, False))
queue.append(PrimitiveNode(var, {}, False, parser.config))

self.assertTrue(parser.end(queue, objects, "author", "foobar", None))
self.assertEqual(0, len(queue))
Expand Down
13 changes: 10 additions & 3 deletions tests/formats/dataclass/parsers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xsdata.exceptions import ParserError
from xsdata.formats.converter import ConverterFactory
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import Namespace, QNames
from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory
Expand Down Expand Up @@ -116,17 +117,23 @@ def test_validate_fixed_value(self):
var = XmlVarFactory.create("fixed", default=lambda: float("nan"))
ParserUtils.validate_fixed_value(meta, var, float("nan"))

def test_parse_var_with_warnings(self):
def test_parse_var_with_error(self):
meta = XmlMetaFactory.create(clazz=TypeA, qname="foo")
var = XmlVarFactory.create("fixed", default="a")
config = ParserConfig()

with warnings.catch_warnings(record=True) as w:
result = ParserUtils.parse_var(meta, var, "a", types=[int, float])
result = ParserUtils.parse_var(meta, var, config, "a", types=[int, float])

expected = (
"Failed to convert value for `TypeA.fixed`\n"
" `a` is not a valid `int | float`"
)
self.assertEqual("a", result)

self.assertEqual(expected, str(w[-1].message))

config.fail_on_converter_warnings = True
with self.assertRaises(ParserError) as cm:
ParserUtils.parse_var(meta, var, config, "a", types=[int, float])

self.assertEqual(expected, str(cm.exception))
2 changes: 1 addition & 1 deletion tests/formats/dataclass/parsers/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_end(self, mock_emit_event):
queue = []
meta = XmlMetaFactory.create(clazz=Artist)
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(bool,))
queue.append(PrimitiveNode(meta, var, {}))
queue.append(PrimitiveNode(meta, var, {}, self.parser.config))

result = self.parser.end(queue, objects, "enabled", "true", None)
self.assertTrue(result)
Expand Down
14 changes: 5 additions & 9 deletions xsdata/formats/dataclass/parsers/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,11 @@ def parse(
"""
handler = self.handler(clazz=clazz, parser=self)

with warnings.catch_warnings():
if self.config.fail_on_converter_warnings:
warnings.filterwarnings("error", category=ConverterWarning)

try:
ns_map = self.ns_map if ns_map is None else ns_map
result = handler.parse(source, ns_map)
except (ConverterWarning, SyntaxError) as e:
raise ParserError(e)
try:
ns_map = self.ns_map if ns_map is None else ns_map
result = handler.parse(source, ns_map)
except SyntaxError as e:
raise ParserError(e)

if result is not None:
return result
Expand Down
51 changes: 14 additions & 37 deletions xsdata/formats/dataclass/parsers/dict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import warnings
from dataclasses import dataclass, field
from contextlib import suppress
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Iterable, List, Optional, Type, Union

from typing_extensions import get_args, get_origin

from xsdata.exceptions import ConverterWarning, ParserError
from xsdata.exceptions import ParserError
from xsdata.formats.converter import converter
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar
Expand Down Expand Up @@ -41,18 +41,10 @@ def decode(self, data: Union[List, Dict], clazz: Optional[Type[T]] = None) -> T:
An instance of the specified class representing the decoded content.
"""
tp = self.verify_type(clazz, data)
if not isinstance(data, list):
return self.bind_dataclass(data, tp)

with warnings.catch_warnings():
if self.config.fail_on_converter_warnings:
warnings.filterwarnings("error", category=ConverterWarning)

try:
if not isinstance(data, list):
return self.bind_dataclass(data, tp)

return [self.bind_dataclass(obj, tp) for obj in data] # type: ignore
except ConverterWarning as e:
raise ParserError(e)
return [self.bind_dataclass(obj, tp) for obj in data] # type: ignore

def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type[T]:
"""Verify the given data matches the given clazz.
Expand Down Expand Up @@ -206,12 +198,18 @@ def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
obj = None
keys = set(data.keys())
max_score = -1.0
config = replace(self.config, fail_on_converter_warnings=True)
decoder = DictDecoder(config=config, context=self.context)

for clazz in classes:
if not self.context.class_type.is_model(clazz):
continue

if self.context.local_names_match(keys, clazz):
candidate = self.bind_optional_dataclass(data, clazz)
candidate = None
with suppress(Exception):
candidate = decoder.bind_dataclass(data, clazz)

score = self.context.class_type.score_object(candidate)
if score > max_score:
max_score = score
Expand All @@ -225,28 +223,6 @@ def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
f"to any of the {[cls.__qualname__ for cls in classes]}"
)

def bind_optional_dataclass(self, data: Dict, clazz: Type[T]) -> Optional[T]:
"""Bind the input data to the given class type.
This is a strict process, if there is any warning the process
returns None. This method is used to test if te data fit into
the class type.
Args:
data: The derived element dictionary
clazz: The target class type to bind the input data
Returns:
An instance of the class type representing the parsed content
or None if there is any warning or error.
"""
try:
with warnings.catch_warnings():
warnings.filterwarnings("error", category=ConverterWarning)
return self.bind_dataclass(data, clazz)
except Exception:
return None

def bind_value(
self,
meta: XmlMeta,
Expand Down Expand Up @@ -328,6 +304,7 @@ def bind_text(self, meta: XmlMeta, var: XmlVar, value: Any) -> Any:
return ParserUtils.parse_var(
meta=meta,
var=var,
config=self.config,
value=value,
ns_map=EMPTY_MAP,
)
Expand Down
5 changes: 4 additions & 1 deletion xsdata/formats/dataclass/parsers/nodes/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def bind_attr(self, params: Dict, var: XmlVar, value: Any):
value = ParserUtils.parse_var(
meta=self.meta,
var=var,
config=self.config,
value=value,
ns_map=self.ns_map,
)
Expand Down Expand Up @@ -372,6 +373,7 @@ def bind_text(self, params: Dict, text: Optional[str]) -> bool:
value = ParserUtils.parse_var(
meta=self.meta,
var=var,
config=self.config,
value=text,
ns_map=self.ns_map,
)
Expand Down Expand Up @@ -518,7 +520,7 @@ def build_node(
)

if not var.any_type and not var.is_wildcard:
return nodes.PrimitiveNode(self.meta, var, ns_map)
return nodes.PrimitiveNode(self.meta, var, ns_map, self.config)

datatype = DataType.from_qname(xsi_type) if xsi_type else None
derived = var.is_wildcard
Expand All @@ -528,6 +530,7 @@ def build_node(
var,
datatype,
ns_map,
self.config,
var.nillable,
derived_factory if derived else None,
)
Expand Down
13 changes: 10 additions & 3 deletions xsdata/formats/dataclass/parsers/nodes/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.mixins import XmlNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils

Expand All @@ -13,14 +14,16 @@ class PrimitiveNode(XmlNode):
meta: The parent xml meta instance
var: The xml var instance
ns_map: The element namespace prefix-URI map
config: The parser config instance
"""

__slots__ = "meta", "var", "ns_map"
__slots__ = "meta", "var", "ns_map", "config"

def __init__(self, meta: XmlMeta, var: XmlVar, ns_map: Dict):
def __init__(self, meta: XmlMeta, var: XmlVar, ns_map: Dict, config: ParserConfig):
self.meta = meta
self.var = var
self.ns_map = ns_map
self.config = config

def bind(
self,
Expand All @@ -45,7 +48,11 @@ def bind(
Whether the binding process was successful or not.
"""
obj = ParserUtils.parse_var(
meta=self.meta, var=self.var, value=text, ns_map=self.ns_map
meta=self.meta,
var=self.var,
config=self.config,
value=text,
ns_map=self.ns_map,
)

if obj is None and not self.var.nillable:
Expand Down
Loading

0 comments on commit 346afc1

Please sign in to comment.