diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index bbfbd57941..a2e3651779 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -694,7 +694,8 @@ def __getattr__(self, name): raise AttributeError(f"{name} hasn't been set yet") if name not in self._field_names: raise AttributeError( - f"Task {self._task.name} has no {self._attr_type} attribute {name}" + f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " + "available: '" + "', '".join(self._field_names) + "'" ) type_ = self._get_type(name) splits = self._get_task_splits() diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py index 77a0f690b7..8221751d01 100644 --- a/pydra/engine/tests/test_specs.py +++ b/pydra/engine/tests/test_specs.py @@ -124,7 +124,10 @@ def test_lazy_getvale(): lf = LazyIn(task=tn) with pytest.raises(Exception) as excinfo: lf.inp_c - assert str(excinfo.value) == "Task tn has no input attribute inp_c" + assert ( + str(excinfo.value) + == "Task 'tn' has no input attribute 'inp_c', available: 'inp_a', 'inp_b'" + ) def test_input_file_hash_1(tmp_path): diff --git a/pydra/engine/tests/test_workflow.py b/pydra/engine/tests/test_workflow.py index 598021c832..c6aab6544f 100644 --- a/pydra/engine/tests/test_workflow.py +++ b/pydra/engine/tests/test_workflow.py @@ -37,6 +37,7 @@ from ..core import Workflow from ... import mark from ..specs import SpecInfo, BaseSpec, ShellSpec +from pydra.utils import exc_info_matches def test_wf_no_input_spec(): @@ -102,13 +103,15 @@ def test_wf_dict_input_and_output_spec(): wf.inputs.a = "any-string" wf.inputs.b = {"foo": 1, "bar": False} - with pytest.raises(TypeError, match="Cannot coerce 1.0 into "): + with pytest.raises(TypeError) as exc_info: wf.inputs.a = 1.0 - with pytest.raises( - TypeError, - match=("Could not coerce object, 'bad-value', to any of the union types "), - ): + assert exc_info_matches(exc_info, "Cannot coerce 1.0 into ") + + with pytest.raises(TypeError) as exc_info: wf.inputs.b = {"foo": 1, "bar": "bad-value"} + assert exc_info_matches( + exc_info, "Could not coerce object, 'bad-value', to any of the union types" + ) result = wf() assert result.output.a == "any-string" @@ -5002,14 +5005,13 @@ def test_wf_input_output_typing(): output_spec={"alpha": int, "beta": ty.List[int]}, ) - with pytest.raises( - TypeError, match="Cannot coerce into " - ): + with pytest.raises(TypeError) as exc_info: list_mult_sum( scalar=wf.lzin.y, in_list=wf.lzin.y, name="A", ) + exc_info_matches(exc_info, "Cannot coerce into ") wf.add( # Split over workflow input "x" on "scalar" input list_mult_sum( diff --git a/pydra/utils/__init__.py b/pydra/utils/__init__.py index 9008779e27..cfde94dbf8 100644 --- a/pydra/utils/__init__.py +++ b/pydra/utils/__init__.py @@ -1 +1 @@ -from .misc import user_cache_dir, add_exc_note # noqa: F401 +from .misc import user_cache_dir, add_exc_note, exc_info_matches # noqa: F401 diff --git a/pydra/utils/misc.py b/pydra/utils/misc.py index 9a40769c9d..45b6a5c3ba 100644 --- a/pydra/utils/misc.py +++ b/pydra/utils/misc.py @@ -1,4 +1,5 @@ from pathlib import Path +import re import platformdirs from pydra._version import __version__ @@ -31,3 +32,14 @@ def add_exc_note(e: Exception, note: str) -> Exception: else: e.args = (e.args[0] + "\n" + note,) return e + + +def exc_info_matches(exc_info, match, regex=False): + if exc_info.value.__cause__ is not None: + msg = str(exc_info.value.__cause__) + else: + msg = str(exc_info.value) + if regex: + return re.match(".*" + match, msg) + else: + return match in msg diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index b41aefd2a8..f83eedbd8c 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -6,7 +6,7 @@ import tempfile import pytest from pydra import mark -from ...engine.specs import File, LazyOutField +from ...engine.specs import File, LazyOutField, MultiInputObj from ..typing import TypeParser from pydra import Workflow from fileformats.application import Json, Yaml, Xml @@ -21,6 +21,7 @@ MyOtherFormatX, MyHeader, ) +from pydra.utils import exc_info_matches def lz(tp: ty.Type): @@ -36,8 +37,9 @@ def test_type_check_basic1(): def test_type_check_basic2(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(int, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic3(): @@ -45,8 +47,9 @@ def test_type_check_basic3(): def test_type_check_basic4(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic5(): @@ -54,8 +57,9 @@ def test_type_check_basic5(): def test_type_check_basic6(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=None, not_coercible=[(float, int)])(lz(float)) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_basic7(): @@ -63,18 +67,22 @@ def test_type_check_basic7(): path_coercer(lz(Path)) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer(lz(str)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") + def test_type_check_basic8(): TypeParser(Path, coercible=[(PathTypes, PathTypes)])(lz(str)) + + +def test_type_check_basic8a(): TypeParser(str, coercible=[(PathTypes, PathTypes)])(lz(Path)) def test_type_check_basic9(): file_coercer = TypeParser(File, coercible=[(PathTypes, File)]) - file_coercer(lz(Path)) file_coercer(lz(str)) @@ -82,12 +90,16 @@ def test_type_check_basic9(): def test_type_check_basic10(): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(lz(File)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic11(): TypeParser(str, coercible=[(PathTypes, PathTypes)])(lz(File)) + + +def test_type_check_basic11a(): TypeParser(File, coercible=[(PathTypes, PathTypes)])(lz(str)) @@ -108,12 +120,13 @@ def test_type_check_basic13(): def test_type_check_basic14(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )(lz(str)) + assert exc_info_matches(exc_info, match="explicitly excluded") def test_type_check_basic15(): @@ -126,16 +139,18 @@ def test_type_check_basic15a(): def test_type_check_basic16(): - with pytest.raises( - TypeError, match="Cannot coerce to any of the union types" - ): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File, bool, int])(lz(float)) + assert exc_info_matches( + exc_info, match="Cannot coerce to any of the union types" + ) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_check_basic16a(): with pytest.raises( - TypeError, match="Cannot coerce to any of the union types" + TypeError, + match="Incorrect type for lazy field: is not a subclass of", ): TypeParser(Path | File | bool | int)(lz(float)) @@ -173,16 +188,18 @@ def test_type_check_nested7(): def test_type_check_nested7a(): - with pytest.raises(TypeError, match="Wrong number of type arguments"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int])) + assert exc_info_matches(exc_info, "Wrong number of type arguments") def test_type_check_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], not_coercible=[(ty.Sequence, ty.Tuple)], )(lz(ty.List[float])) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_permit_superclass(): @@ -190,49 +207,60 @@ def test_type_check_permit_superclass(): TypeParser(ty.List[File])(lz(ty.List[Json])) # Permissive super class, as File is superclass of Json TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File])) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) + assert exc_info_matches(exc_info, "Cannot coerce") # Fails because Yaml is neither sub or super class of Json - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) + assert exc_info_matches(exc_info, "Cannot coerce") def test_type_check_fail1(): - with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float])) + assert exc_info_matches(exc_info, "Wrong number of type arguments in tuple") def test_type_check_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File])(lz(int)) + assert exc_info_matches(exc_info, "to any of the union types") @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_check_fail2a(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError, match="Incorrect type for lazy field: "): TypeParser(Path | File)(lz(int)) def test_type_check_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( lz(ty.Dict[str, int]) ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_fail4(): - with pytest.raises(TypeError, match="Cannot coerce into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) + assert exc_info_matches( + exc_info, + "Cannot coerce typing.Dict[str, int] into ", + ) def test_type_check_fail5(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int])(lz(int)) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_check_fail6(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]])(lz(ty.Tuple[int, int, int])) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_coercion_basic(): @@ -240,8 +268,9 @@ def test_type_coercion_basic(): def test_type_coercion_basic1(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(float, coercible=[(ty.Any, int)])(1) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic2(): @@ -254,8 +283,9 @@ def test_type_coercion_basic2(): def test_type_coercion_basic3(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, ty.Any)], not_coercible=[(float, int)])(1.0) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_basic4(): @@ -263,8 +293,9 @@ def test_type_coercion_basic4(): assert path_coercer(Path("/a/path")) == Path("/a/path") - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer("/a/path") + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic5(): @@ -296,8 +327,9 @@ def test_type_coercion_basic7(a_file): def test_type_coercion_basic8(a_file): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(File(a_file)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic9(a_file): @@ -321,25 +353,25 @@ def test_type_coercion_basic11(): def test_type_coercion_basic12(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )("a-string") - + assert exc_info_matches(exc_info, "explicitly excluded") assert TypeParser(ty.Union[Path, File, int], coercible=[(ty.Any, ty.Any)])(1.0) == 1 @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_coercion_basic12a(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )("a-string") - + assert exc_info_matches(exc_info, "explicitly excluded") assert TypeParser(Path | File | int, coercible=[(ty.Any, ty.Any)])(1.0) == 1 @@ -422,52 +454,60 @@ def test_type_coercion_nested7(): def test_type_coercion_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], coercible=[(ty.Any, ty.Any)], not_coercible=[(ty.Sequence, ty.Tuple)], )([1.0, 2.0, 3.0]) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_fail1(): - with pytest.raises(TypeError, match="Incorrect number of items"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int], coercible=[(ty.Any, ty.Any)])( [1.0, 2.0, 3.0, 4.0] ) + assert exc_info_matches(exc_info, "Incorrect number of items") def test_type_coercion_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "to any of the union types") @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_coercion_fail2a(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(Path | File, coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "to any of the union types") def test_type_coercion_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( {"a": 1, "b": 2} ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_fail4(): - with pytest.raises(TypeError, match="Cannot coerce {'a': 1} into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Any, ty.Any)])({"a": 1}) + assert exc_info_matches(exc_info, "Cannot coerce {'a': 1} into") def test_type_coercion_fail5(): - with pytest.raises(TypeError, match="as 1 is not iterable"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "as 1 is not iterable") def test_type_coercion_fail6(): - with pytest.raises(TypeError, match="is not a mapping type"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]], coercible=[(ty.Any, ty.Any)])((1, 2, 3)) + assert exc_info_matches(exc_info, "is not a mapping type") def test_type_coercion_realistic(): @@ -490,21 +530,29 @@ def f(x: ty.List[File], y: ty.Dict[str, ty.List[File]]): TypeParser(ty.List[str])(task.lzout.a) # pylint: disable=no-member with pytest.raises( TypeError, - match="Cannot coerce into ", - ): + ) as exc_info: TypeParser(ty.List[int])(task.lzout.a) # pylint: disable=no-member + assert exc_info_matches( + exc_info, + match=r"Cannot coerce into ", + regex=True, + ) - with pytest.raises( - TypeError, match="Cannot coerce 'bad-value' into " - ): + with pytest.raises(TypeError) as exc_info: task.inputs.x = "bad-value" + assert exc_info_matches( + exc_info, match="Cannot coerce 'bad-value' into " + ) def test_check_missing_type_args(): - with pytest.raises(TypeError, match="wasn't declared with type args required"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(list) - with pytest.raises(TypeError, match="doesn't match pattern"): + assert exc_info_matches(exc_info, "wasn't declared with type args required") + + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(dict) + assert exc_info_matches(exc_info, "doesn't match pattern") def test_matches_type_union(): @@ -610,6 +658,21 @@ def test_contains_type_in_dict(): ) +def test_any_union(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + # The Json type within the Union matches File as it is a subclass as `match_any_of_union` + # is set to True. Otherwise, all types within the Union would have to match + TypeParser(File, match_any_of_union=True).check_type(ty.Union[ty.List[File], Json]) + + +def test_union_superclass_check_type(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + # In this case, File matches Json due to the `superclass_auto_cast=True` flag being set + TypeParser(ty.Union[ty.List[File], Json], superclass_auto_cast=True)(lz(File)) + + def test_type_matches(): assert TypeParser.matches([1, 2, 3], ty.List[int]) assert TypeParser.matches((1, 2, 3), ty.Tuple[int, ...]) @@ -713,7 +776,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( # Generic task other_specific_task( @@ -721,6 +784,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="inner", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( # Generic task other_specific_task( @@ -729,7 +793,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( specific_task( @@ -737,6 +801,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="exit", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( specific_task( @@ -762,20 +827,42 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): assert out_file.header.parent != in_file.header.parent -def test_type_is_subclass1(): - assert TypeParser.is_subclass(ty.Type[File], type) - - -def test_type_is_subclass2(): - assert not TypeParser.is_subclass(ty.Type[File], ty.Type[Json]) - - -def test_type_is_subclass3(): - assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File]) - - -def test_union_is_subclass1(): - assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]) +@pytest.mark.parametrize( + ("sub", "super"), + [ + (ty.Type[File], type), + (ty.Type[Json], ty.Type[File]), + (ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]), + (Json, ty.Union[Json, Yaml]), + (ty.List[int], list), + (None, ty.Union[int, None]), + (ty.Tuple[int, None], ty.Tuple[int, None]), + (None, None), + (None, type(None)), + (type(None), None), + (type(None), type(None)), + (type(None), type(None)), + ], +) +def test_subclass(sub, super): + assert TypeParser.is_subclass(sub, super) + + +@pytest.mark.parametrize( + ("sub", "super"), + [ + (ty.Type[File], ty.Type[Json]), + (ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]), + (ty.Union[Json, Yaml], Json), + (list, ty.List[int]), + (ty.List[float], ty.List[int]), + (None, ty.Union[int, float]), + (None, int), + (int, None), + ], +) +def test_not_subclass(sub, super): + assert not TypeParser.is_subclass(sub, super) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -788,18 +875,11 @@ def test_union_is_subclass1b(): assert TypeParser.is_subclass(Json | Yaml, ty.Union[Json, Yaml, Xml]) -## Up to here! - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass1c(): assert TypeParser.is_subclass(ty.Union[Json, Yaml], Json | Yaml | Xml) -def test_union_is_subclass2(): - assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass2a(): assert not TypeParser.is_subclass(Json | Yaml | Xml, Json | Yaml) @@ -815,86 +895,26 @@ def test_union_is_subclass2c(): assert not TypeParser.is_subclass(Json | Yaml | Xml, ty.Union[Json, Yaml]) -def test_union_is_subclass3(): - assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass3a(): assert TypeParser.is_subclass(Json, Json | Yaml) -def test_union_is_subclass4(): - assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_union_is_subclass4a(): assert not TypeParser.is_subclass(Json | Yaml, Json) -def test_generic_is_subclass1(): - assert TypeParser.is_subclass(ty.List[int], list) - - -def test_generic_is_subclass2(): - assert not TypeParser.is_subclass(list, ty.List[int]) - - -def test_generic_is_subclass3(): - assert not TypeParser.is_subclass(ty.List[float], ty.List[int]) - - -def test_none_is_subclass1(): - assert TypeParser.is_subclass(None, ty.Union[int, None]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_none_is_subclass1a(): assert TypeParser.is_subclass(None, int | None) -def test_none_is_subclass2(): - assert not TypeParser.is_subclass(None, ty.Union[int, float]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_none_is_subclass2a(): assert not TypeParser.is_subclass(None, int | float) -def test_none_is_subclass3(): - assert TypeParser.is_subclass(ty.Tuple[int, None], ty.Tuple[int, None]) - - -def test_none_is_subclass4(): - assert TypeParser.is_subclass(None, None) - - -def test_none_is_subclass5(): - assert not TypeParser.is_subclass(None, int) - - -def test_none_is_subclass6(): - assert not TypeParser.is_subclass(int, None) - - -def test_none_is_subclass7(): - assert TypeParser.is_subclass(None, type(None)) - - -def test_none_is_subclass8(): - assert TypeParser.is_subclass(type(None), None) - - -def test_none_is_subclass9(): - assert TypeParser.is_subclass(type(None), type(None)) - - -def test_none_is_subclass10(): - assert TypeParser.is_subclass(type(None), type(None)) - - @pytest.mark.skipif( sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9" ) @@ -924,40 +944,33 @@ class B(A): assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int]) -def test_type_is_instance1(): - assert TypeParser.is_instance(File, ty.Type[File]) - - -def test_type_is_instance2(): - assert not TypeParser.is_instance(File, ty.Type[Json]) - - -def test_type_is_instance3(): - assert TypeParser.is_instance(Json, ty.Type[File]) - - -def test_type_is_instance4(): - assert TypeParser.is_instance(Json, type) - - -def test_type_is_instance5(): - assert TypeParser.is_instance(None, None) - - -def test_type_is_instance6(): - assert TypeParser.is_instance(None, type(None)) - - -def test_type_is_instance7(): - assert not TypeParser.is_instance(None, int) - - -def test_type_is_instance8(): - assert not TypeParser.is_instance(1, None) - - -def test_type_is_instance9(): - assert TypeParser.is_instance(None, ty.Union[int, None]) +@pytest.mark.parametrize( + ("tp", "obj"), + [ + (File, ty.Type[File]), + (Json, ty.Type[File]), + (Json, type), + (None, None), + (None, type(None)), + (None, ty.Union[int, None]), + (1, ty.Union[int, None]), + ], +) +def test_type_is_instance(tp, obj): + assert TypeParser.is_instance(tp, obj) + + +@pytest.mark.parametrize( + ("tp", "obj"), + [ + (File, ty.Type[Json]), + (None, int), + (1, None), + (None, ty.Union[int, str]), + ], +) +def test_type_is_not_instance(tp, obj): + assert not TypeParser.is_instance(tp, obj) @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") @@ -965,19 +978,57 @@ def test_type_is_instance9a(): assert TypeParser.is_instance(None, int | None) -def test_type_is_instance10(): - assert TypeParser.is_instance(1, ty.Union[int, None]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance10a(): assert TypeParser.is_instance(1, int | None) -def test_type_is_instance11(): - assert not TypeParser.is_instance(None, ty.Union[int, str]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") def test_type_is_instance11a(): assert not TypeParser.is_instance(None, int | str) + + +@pytest.mark.parametrize( + ("typ", "obj", "result"), + [ + (MultiInputObj[str], "a", ["a"]), + (MultiInputObj[str], ["a"], ["a"]), + (MultiInputObj[ty.List[str]], ["a"], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ["a"], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], [["a"]], [["a"]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], [1], [1]), + ], +) +def test_multi_input_obj_coerce(typ, obj, result): + assert TypeParser(typ)(obj) == result + + +def test_multi_input_obj_coerce4a(): + with pytest.raises(TypeError): + TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]]) + + +@pytest.mark.parametrize( + ("reference", "to_be_checked"), + [ + (MultiInputObj[str], str), + (MultiInputObj[str], ty.List[str]), + (MultiInputObj[ty.List[str]], ty.List[str]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[str]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[ty.List[str]]), + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[int]), + ], +) +def test_multi_input_obj_check_type(reference, to_be_checked): + TypeParser(reference)(lz(to_be_checked)) + + +@pytest.mark.parametrize( + ("reference", "to_be_checked"), + [ + (MultiInputObj[ty.Union[int, ty.List[str]]], ty.List[ty.List[int]]), + ], +) +def test_multi_input_obj_check_type_fail(reference, to_be_checked): + with pytest.raises(TypeError): + TypeParser(reference)(lz(to_be_checked)) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index c765b1339c..e40f928047 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -2,6 +2,7 @@ import inspect from pathlib import Path import os +from copy import copy import sys import types import typing as ty @@ -13,6 +14,7 @@ MultiInputObj, MultiOutputObj, ) +from ..utils import add_exc_note from fileformats import field try: @@ -70,6 +72,8 @@ class TypeParser(ty.Generic[T]): label : str the label to be used to identify the type parser in error messages. Especially useful when TypeParser is used as a converter in attrs.fields + match_any_of_union : bool + match if any of the options in the union are a subclass (but not necessarily all) """ tp: ty.Type[T] @@ -77,6 +81,7 @@ class TypeParser(ty.Generic[T]): not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] superclass_auto_cast: bool label: str + match_any_of_union: bool COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( ( @@ -121,6 +126,7 @@ def __init__( ] = NOT_COERCIBLE_DEFAULT, superclass_auto_cast: bool = False, label: str = "", + match_any_of_union: bool = False, ): def expand_pattern(t): """Recursively expand the type arguments of the target type in nested tuples""" @@ -151,6 +157,7 @@ def expand_pattern(t): self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) self.superclass_auto_cast = superclass_auto_cast + self.match_any_of_union = match_any_of_union def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -185,9 +192,15 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: # Check whether the type of the lazy field isn't a superclass of # the type to check against, and if so, allow it due to permissive # typing rules. - TypeParser(obj.type).check_type(self.tp) + TypeParser(obj.type, match_any_of_union=True).check_type( + self.tp + ) except TypeError: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass or superclass of {self.tp} (and will not " + "be able to be coerced to one that is)" + ) from e else: logger.info( "Connecting lazy field %s to %s%s via permissive typing that " @@ -197,12 +210,22 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: self.label_str, ) else: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass of {self.tp} (and will not be able to be " + "coerced to one that is)" + ) from e coerced = obj # type: ignore elif isinstance(obj, StateArray): coerced = StateArray(self(o) for o in obj) # type: ignore[assignment] else: - coerced = self.coerce(obj) + try: + coerced = self.coerce(obj) + except TypeError as e: + raise TypeError( + f"Incorrect type for field{self.label_str}: {obj!r} is not of type " + f"{self.tp} (and cannot be coerced to it)" + ) from e return coerced def coerce(self, object_: ty.Any) -> T: @@ -345,7 +368,26 @@ def coerce_obj(obj, type_): f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}" ) from e - return expand_and_coerce(object_, self.pattern) + try: + return expand_and_coerce(object_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj: + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + return [inner_type_parser.coerce(object_)] + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e + else: + raise e def check_type(self, type_: ty.Type[ty.Any]): """Checks the given type to see whether it matches or is a subtype of the @@ -392,7 +434,7 @@ def expand_and_check(tp, pattern: ty.Union[type, tuple]): f"{self.pattern}{self.label_str}" ) tp_args = get_args(tp) - self.check_coercible(tp_origin, pattern_origin) + self.check_type_coercible(tp_origin, pattern_origin) if issubclass(pattern_origin, ty.Mapping): return check_mapping(tp_args, pattern_args) if issubclass(pattern_origin, tuple): @@ -406,12 +448,31 @@ def check_basic(tp, target): # Note that we are deliberately more permissive than typical type-checking # here, allowing parents of the target type as well as children, # to avoid users having to cast from loosely typed tasks to strict ones + if self.match_any_of_union and get_origin(tp) is ty.Union: + reasons = [] + tp_args = get_args(tp) + for tp_arg in tp_args: + if self.is_subclass(tp_arg, target): + return + try: + self.check_coercible(tp_arg, target) + except TypeError as e: + reasons.append(e) + else: + return + if reasons: + raise TypeError( + f"Cannot coerce any union args {tp_arg} to {target}" + f"{self.label_str}:\n\n" + + "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons)) + ) if not self.is_subclass(tp, target): - self.check_coercible(tp, target) + self.check_type_coercible(tp, target) def check_union(tp, pattern_args): if get_origin(tp) in UNION_TYPES: - for tp_arg in get_args(tp): + tp_args = get_args(tp) + for tp_arg in tp_args: reasons = [] for pattern_arg in pattern_args: try: @@ -421,11 +482,15 @@ def check_union(tp, pattern_args): else: reasons = None break + if self.match_any_of_union and len(reasons) < len(tp_args): + # Just need one of the union args to match + return if reasons: + determiner = "any" if self.match_any_of_union else "all" raise TypeError( - f"Cannot coerce {tp} to " - f"ty.Union[{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " - f"because {tp_arg} cannot be coerced to any of its args:\n\n" + f"Cannot coerce {tp} to ty.Union[" + f"{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " + f"because {tp_arg} cannot be coerced to {determiner} of its args:\n\n" + "\n\n".join( f"{a} -> {e}" for a, e in zip(pattern_args, reasons) ) @@ -482,19 +547,59 @@ def check_sequence(tp_args, pattern_args): for arg in tp_args: expand_and_check(arg, pattern_args[0]) - return expand_and_check(type_, self.pattern) + try: + return expand_and_check(type_, self.pattern) + except TypeError as e: + # Special handling for MultiInputObjects (which are annoying) + if not isinstance(self.pattern, tuple) or self.pattern[0] != MultiInputObj: + raise e + # Attempt to coerce the object into arg type of the MultiInputObj first, + # and if that fails, try to coerce it into a list of the arg type + inner_type_parser = copy(self) + inner_type_parser.pattern = self.pattern[1][0] + try: + inner_type_parser.check_type(type_) + except TypeError: + add_exc_note( + e, + "Also failed to coerce to the arg-type of the MultiInputObj " + f"({self.pattern[1][0]})", + ) + raise e + + def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]): + """Checks whether the source object is coercible to the target type given the coercion + rules defined in the `coercible` and `not_coercible` attrs + + Parameters + ---------- + source : object + the object to be coerced + target : type or typing.Any + the target type for the object to be coerced to - def check_coercible( - self, source: ty.Union[object, type], target: ty.Union[type, ty.Any] + Raises + ------ + TypeError + If the object cannot be coerced into the target type depending on the explicit + inclusions and exclusions set in the `coercible` and `not_coercible` member attrs + """ + self.check_type_coercible(type(source), target, source_repr=repr(source)) + + def check_type_coercible( + self, + source: ty.Union[type, ty.Any], + target: ty.Union[type, ty.Any], + source_repr: ty.Optional[str] = None, ): - """Checks whether the source object or type is coercible to the target type + """Checks whether the source type is coercible to the target type given the coercion rules defined in the `coercible` and `not_coercible` attrs Parameters ---------- - source : object or type - source object or type to be coerced - target : type or ty.Any + source : type or typing.Any + source type to be coerced + target : type or typing.Any target type for the source to be coerced to Raises @@ -504,10 +609,12 @@ def check_coercible( explicit inclusions and exclusions set in the `coercible` and `not_coercible` member attrs """ + if source_repr is None: + source_repr = repr(source) # Short-circuit the basic cases where the source and target are the same if source is target: return - if self.superclass_auto_cast and self.is_subclass(target, type(source)): + if self.superclass_auto_cast and self.is_subclass(target, source): logger.info( "Attempting to coerce %s into %s due to super-to-sub class coercion " "being permitted", @@ -519,13 +626,11 @@ def check_coercible( if source_origin is not None: source = source_origin - source_check = self.is_subclass if inspect.isclass(source) else self.is_instance - def matches_criteria(criteria): return [ (src, tgt) for src, tgt in criteria - if source_check(source, src) and self.is_subclass(target, tgt) + if self.is_subclass(source, src) and self.is_subclass(target, tgt) ] def type_name(t): @@ -536,7 +641,7 @@ def type_name(t): if not matches_criteria(self.coercible): raise TypeError( - f"Cannot coerce {repr(source)} into {target}{self.label_str} as the " + f"Cannot coerce {source_repr} into {target}{self.label_str} as the " "coercion doesn't match any of the explicit inclusion criteria: " + ", ".join( f"{type_name(s)} -> {type_name(t)}" for s, t in self.coercible @@ -545,7 +650,7 @@ def type_name(t): matches_not_coercible = matches_criteria(self.not_coercible) if matches_not_coercible: raise TypeError( - f"Cannot coerce {repr(source)} into {target}{self.label_str} as it is explicitly " + f"Cannot coerce {source_repr} into {target}{self.label_str} as it is explicitly " "excluded by the following coercion criteria: " + ", ".join( f"{type_name(s)} -> {type_name(t)}" @@ -639,7 +744,7 @@ def is_instance( if inspect.isclass(obj): return candidate is type if issubtype(type(obj), candidate) or ( - type(obj) is dict and candidate is ty.Mapping + type(obj) is dict and candidate is ty.Mapping # noqa: E721 ): return True else: