Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Avoid using not-threadsafe arnings.catch_warning #1042

Merged
merged 1 commit into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/codegen/handlers/test_disambiguate_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def test_disambiguate_choice_with_circular_ref(self):

def test_find_ambiguous_choices_ignore_wildcards(self):
"""Wildcards are merged."""

attr = AttrFactory.create()
attr.choices.append(AttrFactory.any())
attr.choices.append(AttrFactory.any())
Expand Down
16 changes: 9 additions & 7 deletions tests/formats/dataclass/parsers/nodes/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.fixtures.artists import Artist
from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.utils.testing import XmlMetaFactory, XmlVarFactory
Expand All @@ -12,6 +13,7 @@ class PrimitiveNodeTests(TestCase):
def setUp(self):
super().setUp()
self.meta = XmlMetaFactory.create(clazz=Artist)
self.config = ParserConfig()

@mock.patch.object(ParserUtils, "parse_var")
def test_bind(self, mock_parse_var):
Expand All @@ -20,22 +22,22 @@ def test_bind(self, mock_parse_var):
xml_type=XmlType.TEXT, name="foo", types=(int,), format="Nope"
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "Impossible", objects))
self.assertEqual(("foo", 13), objects[-1])

mock_parse_var.assert_called_once_with(
meta=self.meta, var=var, value="13", ns_map=ns_map
meta=self.meta, var=var, config=self.config, value="13", ns_map=ns_map
)

def test_bind_nillable_content(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", types=(str,), nillable=False
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -53,7 +55,7 @@ def test_bind_nillable_bytes_content(self):
nillable=False,
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(self.meta, var, ns_map)
node = PrimitiveNode(self.meta, var, ns_map, self.config)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -66,7 +68,7 @@ def test_bind_nillable_bytes_content(self):
def test_bind_mixed_with_tail_content(self):
self.meta.mixed_content = True
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "tail", objects))
Expand All @@ -76,15 +78,15 @@ def test_bind_mixed_with_tail_content(self):
def test_bind_mixed_without_tail_content(self):
self.meta.mixed_content = True
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)
objects = []

self.assertTrue(node.bind("foo", "13", "", objects))
self.assertEqual(13, objects[-1][1])

def test_child(self):
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo")
node = PrimitiveNode(self.meta, var, {})
node = PrimitiveNode(self.meta, var, {}, self.config)

with self.assertRaises(XmlContextError):
node.child("foo", {}, {}, 0)
20 changes: 15 additions & 5 deletions tests/formats/dataclass/parsers/nodes/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.fixtures.artists import Artist
from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.nodes import StandardNode
from xsdata.models.enums import DataType
from xsdata.utils.testing import XmlMetaFactory, XmlVarFactory
Expand All @@ -13,34 +14,41 @@ def setUp(self):
super().setUp()
self.meta = XmlMetaFactory.create(clazz=Artist)
self.var = XmlVarFactory.create()
self.config = ParserConfig()

def test_bind_simple(self):
datatype = DataType.INT
node = StandardNode(self.meta, self.var, datatype, {}, False, False)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, False
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", 13), objects[-1])

def test_bind_derived(self):
datatype = DataType.INT
node = StandardNode(self.meta, self.var, datatype, {}, False, DerivedElement)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, DerivedElement
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", DerivedElement("a", 13)), objects[-1])

def test_bind_wrapper_type(self):
datatype = DataType.HEX_BINARY
node = StandardNode(self.meta, self.var, datatype, {}, False, DerivedElement)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, DerivedElement
)
objects = []

self.assertTrue(node.bind("a", "13", None, objects))
self.assertEqual(("a", DerivedElement(qname="a", value=b"\x13")), objects[-1])

def test_bind_nillable(self):
datatype = DataType.STRING
node = StandardNode(self.meta, self.var, datatype, {}, True, None)
node = StandardNode(self.meta, self.var, datatype, {}, self.config, True, None)
objects = []

self.assertTrue(node.bind("a", None, None, objects))
Expand All @@ -52,7 +60,9 @@ def test_bind_nillable(self):

def test_child(self):
datatype = DataType.STRING
node = StandardNode(self.meta, self.var, datatype, {}, False, False)
node = StandardNode(
self.meta, self.var, datatype, {}, self.config, False, False
)

with self.assertRaises(XmlContextError):
node.child("foo", {}, {}, 0)
2 changes: 1 addition & 1 deletion tests/formats/dataclass/parsers/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_end(self, mock_assemble):
objects = [("q", "result")]
queue = []
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo")
queue.append(PrimitiveNode(var, {}, False))
queue.append(PrimitiveNode(var, {}, False, parser.config))

self.assertTrue(parser.end(queue, objects, "author", "foobar", None))
self.assertEqual(0, len(queue))
Expand Down
13 changes: 10 additions & 3 deletions tests/formats/dataclass/parsers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xsdata.exceptions import ParserError
from xsdata.formats.converter import ConverterFactory
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.models.enums import Namespace, QNames
from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory
Expand Down Expand Up @@ -116,17 +117,23 @@ def test_validate_fixed_value(self):
var = XmlVarFactory.create("fixed", default=lambda: float("nan"))
ParserUtils.validate_fixed_value(meta, var, float("nan"))

def test_parse_var_with_warnings(self):
def test_parse_var_with_error(self):
meta = XmlMetaFactory.create(clazz=TypeA, qname="foo")
var = XmlVarFactory.create("fixed", default="a")
config = ParserConfig()

with warnings.catch_warnings(record=True) as w:
result = ParserUtils.parse_var(meta, var, "a", types=[int, float])
result = ParserUtils.parse_var(meta, var, config, "a", types=[int, float])

expected = (
"Failed to convert value for `TypeA.fixed`\n"
" `a` is not a valid `int | float`"
)
self.assertEqual("a", result)

self.assertEqual(expected, str(w[-1].message))

config.fail_on_converter_warnings = True
with self.assertRaises(ParserError) as cm:
ParserUtils.parse_var(meta, var, config, "a", types=[int, float])

self.assertEqual(expected, str(cm.exception))
2 changes: 1 addition & 1 deletion tests/formats/dataclass/parsers/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_end(self, mock_emit_event):
queue = []
meta = XmlMetaFactory.create(clazz=Artist)
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(bool,))
queue.append(PrimitiveNode(meta, var, {}))
queue.append(PrimitiveNode(meta, var, {}, self.parser.config))

result = self.parser.end(queue, objects, "enabled", "true", None)
self.assertTrue(result)
Expand Down
14 changes: 5 additions & 9 deletions xsdata/formats/dataclass/parsers/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,11 @@ def parse(
"""
handler = self.handler(clazz=clazz, parser=self)

with warnings.catch_warnings():
if self.config.fail_on_converter_warnings:
warnings.filterwarnings("error", category=ConverterWarning)

try:
ns_map = self.ns_map if ns_map is None else ns_map
result = handler.parse(source, ns_map)
except (ConverterWarning, SyntaxError) as e:
raise ParserError(e)
try:
ns_map = self.ns_map if ns_map is None else ns_map
result = handler.parse(source, ns_map)
except SyntaxError as e:
raise ParserError(e)

if result is not None:
return result
Expand Down
51 changes: 14 additions & 37 deletions xsdata/formats/dataclass/parsers/dict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import warnings
from dataclasses import dataclass, field
from contextlib import suppress
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Iterable, List, Optional, Type, Union

from typing_extensions import get_args, get_origin

from xsdata.exceptions import ConverterWarning, ParserError
from xsdata.exceptions import ParserError
from xsdata.formats.converter import converter
from xsdata.formats.dataclass.context import XmlContext
from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar
Expand Down Expand Up @@ -41,18 +41,10 @@ def decode(self, data: Union[List, Dict], clazz: Optional[Type[T]] = None) -> T:
An instance of the specified class representing the decoded content.
"""
tp = self.verify_type(clazz, data)
if not isinstance(data, list):
return self.bind_dataclass(data, tp)

with warnings.catch_warnings():
if self.config.fail_on_converter_warnings:
warnings.filterwarnings("error", category=ConverterWarning)

try:
if not isinstance(data, list):
return self.bind_dataclass(data, tp)

return [self.bind_dataclass(obj, tp) for obj in data] # type: ignore
except ConverterWarning as e:
raise ParserError(e)
return [self.bind_dataclass(obj, tp) for obj in data] # type: ignore

def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type[T]:
"""Verify the given data matches the given clazz.
Expand Down Expand Up @@ -206,12 +198,18 @@ def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
obj = None
keys = set(data.keys())
max_score = -1.0
config = replace(self.config, fail_on_converter_warnings=True)
decoder = DictDecoder(config=config, context=self.context)

for clazz in classes:
if not self.context.class_type.is_model(clazz):
continue

if self.context.local_names_match(keys, clazz):
candidate = self.bind_optional_dataclass(data, clazz)
candidate = None
with suppress(Exception):
candidate = decoder.bind_dataclass(data, clazz)

score = self.context.class_type.score_object(candidate)
if score > max_score:
max_score = score
Expand All @@ -225,28 +223,6 @@ def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
f"to any of the {[cls.__qualname__ for cls in classes]}"
)

def bind_optional_dataclass(self, data: Dict, clazz: Type[T]) -> Optional[T]:
"""Bind the input data to the given class type.

This is a strict process, if there is any warning the process
returns None. This method is used to test if te data fit into
the class type.

Args:
data: The derived element dictionary
clazz: The target class type to bind the input data

Returns:
An instance of the class type representing the parsed content
or None if there is any warning or error.
"""
try:
with warnings.catch_warnings():
warnings.filterwarnings("error", category=ConverterWarning)
return self.bind_dataclass(data, clazz)
except Exception:
return None

def bind_value(
self,
meta: XmlMeta,
Expand Down Expand Up @@ -328,6 +304,7 @@ def bind_text(self, meta: XmlMeta, var: XmlVar, value: Any) -> Any:
return ParserUtils.parse_var(
meta=meta,
var=var,
config=self.config,
value=value,
ns_map=EMPTY_MAP,
)
Expand Down
5 changes: 4 additions & 1 deletion xsdata/formats/dataclass/parsers/nodes/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def bind_attr(self, params: Dict, var: XmlVar, value: Any):
value = ParserUtils.parse_var(
meta=self.meta,
var=var,
config=self.config,
value=value,
ns_map=self.ns_map,
)
Expand Down Expand Up @@ -372,6 +373,7 @@ def bind_text(self, params: Dict, text: Optional[str]) -> bool:
value = ParserUtils.parse_var(
meta=self.meta,
var=var,
config=self.config,
value=text,
ns_map=self.ns_map,
)
Expand Down Expand Up @@ -518,7 +520,7 @@ def build_node(
)

if not var.any_type and not var.is_wildcard:
return nodes.PrimitiveNode(self.meta, var, ns_map)
return nodes.PrimitiveNode(self.meta, var, ns_map, self.config)

datatype = DataType.from_qname(xsi_type) if xsi_type else None
derived = var.is_wildcard
Expand All @@ -528,6 +530,7 @@ def build_node(
var,
datatype,
ns_map,
self.config,
var.nillable,
derived_factory if derived else None,
)
Expand Down
13 changes: 10 additions & 3 deletions xsdata/formats/dataclass/parsers/nodes/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar
from xsdata.formats.dataclass.parsers.config import ParserConfig
from xsdata.formats.dataclass.parsers.mixins import XmlNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils

Expand All @@ -13,14 +14,16 @@ class PrimitiveNode(XmlNode):
meta: The parent xml meta instance
var: The xml var instance
ns_map: The element namespace prefix-URI map
config: The parser config instance
"""

__slots__ = "meta", "var", "ns_map"
__slots__ = "meta", "var", "ns_map", "config"

def __init__(self, meta: XmlMeta, var: XmlVar, ns_map: Dict):
def __init__(self, meta: XmlMeta, var: XmlVar, ns_map: Dict, config: ParserConfig):
self.meta = meta
self.var = var
self.ns_map = ns_map
self.config = config

def bind(
self,
Expand All @@ -45,7 +48,11 @@ def bind(
Whether the binding process was successful or not.
"""
obj = ParserUtils.parse_var(
meta=self.meta, var=self.var, value=text, ns_map=self.ns_map
meta=self.meta,
var=self.var,
config=self.config,
value=text,
ns_map=self.ns_map,
)

if obj is None and not self.var.nillable:
Expand Down
Loading
Loading