diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index cccd272a9a..c41ec377a9 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -694,7 +694,8 @@ def __getattr__(self, name): raise AttributeError(f"{name} hasn't been set yet") if name not in self._field_names: raise AttributeError( - f"Task {self._task.name} has no {self._attr_type} attribute {name}" + f"Task '{self._task.name}' has no {self._attr_type} attribute '{name}', " + "available: '" + "', '".join(self._field_names) + "'" ) type_ = self._get_type(name) splits = self._get_task_splits() diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 665d79327d..46c085be33 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -1,6 +1,7 @@ import os import itertools import sys +import re import typing as ty from pathlib import Path import tempfile @@ -28,6 +29,17 @@ def lz(tp: ty.Type): return LazyOutField(name="foo", field="boo", type=tp) +def exc_info_matches(exc_info, match, regex=False): + if exc_info.value.__cause__ is not None: + msg = str(exc_info.value.__cause__) + else: + msg = str(exc_info.value) + if regex: + return re.match(".*" + match, msg) + else: + return match in msg + + PathTypes = ty.Union[str, os.PathLike] @@ -36,8 +48,9 @@ def test_type_check_basic1(): def test_type_check_basic2(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(int, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic3(): @@ -45,8 +58,9 @@ def test_type_check_basic3(): def test_type_check_basic4(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, float)])(lz(float)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic5(): @@ -54,8 +68,9 @@ def test_type_check_basic5(): def test_type_check_basic6(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=None, not_coercible=[(float, int)])(lz(float)) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_basic7(): @@ -63,9 +78,11 @@ def test_type_check_basic7(): path_coercer(lz(Path)) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer(lz(str)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") + def test_type_check_basic8(): TypeParser(Path, coercible=[(PathTypes, PathTypes)])(lz(str)) @@ -74,7 +91,6 @@ def test_type_check_basic8(): def test_type_check_basic9(): file_coercer = TypeParser(File, coercible=[(PathTypes, File)]) - file_coercer(lz(Path)) file_coercer(lz(str)) @@ -82,8 +98,9 @@ def test_type_check_basic9(): def test_type_check_basic10(): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(lz(File)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_basic11(): @@ -108,12 +125,13 @@ def test_type_check_basic13(): def test_type_check_basic14(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )(lz(str)) + assert exc_info_matches(exc_info, match="explicitly excluded") def test_type_check_basic15(): @@ -121,10 +139,11 @@ def test_type_check_basic15(): def test_type_check_basic16(): - with pytest.raises( - TypeError, match="Cannot coerce to any of the union types" - ): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File, bool, int])(lz(float)) + assert exc_info_matches( + exc_info, match="Cannot coerce to any of the union types" + ) def test_type_check_basic17(): @@ -160,16 +179,18 @@ def test_type_check_nested7(): def test_type_check_nested7a(): - with pytest.raises(TypeError, match="Wrong number of type arguments"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int])) + assert exc_info_matches(exc_info, "Wrong number of type arguments") def test_type_check_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], not_coercible=[(ty.Sequence, ty.Tuple)], )(lz(ty.List[float])) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_check_permit_superclass(): @@ -177,43 +198,51 @@ def test_type_check_permit_superclass(): TypeParser(ty.List[File])(lz(ty.List[Json])) # Permissive super class, as File is superclass of Json TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File])) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File])) + assert exc_info_matches(exc_info, "Cannot coerce") # Fails because Yaml is neither sub or super class of Json - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml])) + assert exc_info_matches(exc_info, "Cannot coerce") def test_type_check_fail1(): - with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float])) + assert exc_info_matches(exc_info, "Wrong number of type arguments in tuple") def test_type_check_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File])(lz(int)) + assert exc_info_matches(exc_info, "to any of the union types") def test_type_check_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( lz(ty.Dict[str, int]) ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_check_fail4(): - with pytest.raises(TypeError, match="Cannot coerce into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence)(lz(ty.Dict[str, int])) + assert exc_info_matches(exc_info, "Cannot coerce into") def test_type_check_fail5(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int])(lz(int)) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_check_fail6(): - with pytest.raises(TypeError, match=" doesn't match pattern"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]])(lz(ty.Tuple[int, int, int])) + assert exc_info_matches(exc_info, " doesn't match pattern") def test_type_coercion_basic(): @@ -221,8 +250,9 @@ def test_type_coercion_basic(): def test_type_coercion_basic1(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(float, coercible=[(ty.Any, int)])(1) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic2(): @@ -235,8 +265,9 @@ def test_type_coercion_basic2(): def test_type_coercion_basic3(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser(int, coercible=[(ty.Any, ty.Any)], not_coercible=[(float, int)])(1.0) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_basic4(): @@ -244,8 +275,9 @@ def test_type_coercion_basic4(): assert path_coercer(Path("/a/path")) == Path("/a/path") - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: path_coercer("/a/path") + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic5(): @@ -277,8 +309,9 @@ def test_type_coercion_basic7(a_file): def test_type_coercion_basic8(a_file): impotent_str_coercer = TypeParser(str, coercible=[(PathTypes, File)]) - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: impotent_str_coercer(File(a_file)) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_basic9(a_file): @@ -302,13 +335,13 @@ def test_type_coercion_basic11(): def test_type_coercion_basic12(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( list, coercible=[(ty.Sequence, ty.Sequence)], not_coercible=[(str, ty.Sequence)], )("a-string") - + assert exc_info_matches(exc_info, "explicitly excluded") assert TypeParser(ty.Union[Path, File, int], coercible=[(ty.Any, ty.Any)])(1.0) == 1 @@ -384,46 +417,53 @@ def test_type_coercion_nested7(): def test_type_coercion_nested8(): - with pytest.raises(TypeError, match="explicitly excluded"): + with pytest.raises(TypeError) as exc_info: TypeParser( ty.Tuple[int, ...], coercible=[(ty.Any, ty.Any)], not_coercible=[(ty.Sequence, ty.Tuple)], )([1.0, 2.0, 3.0]) + assert exc_info_matches(exc_info, "explicitly excluded") def test_type_coercion_fail1(): - with pytest.raises(TypeError, match="Incorrect number of items"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Tuple[int, int, int], coercible=[(ty.Any, ty.Any)])( [1.0, 2.0, 3.0, 4.0] ) + assert exc_info_matches(exc_info, "Incorrect number of items") def test_type_coercion_fail2(): - with pytest.raises(TypeError, match="to any of the union types"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Union[Path, File], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "to any of the union types") def test_type_coercion_fail3(): - with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( {"a": 1, "b": 2} ) + assert exc_info_matches(exc_info, "doesn't match any of the explicit inclusion") def test_type_coercion_fail4(): - with pytest.raises(TypeError, match="Cannot coerce {'a': 1} into"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.Sequence, coercible=[(ty.Any, ty.Any)])({"a": 1}) + assert exc_info_matches(exc_info, "Cannot coerce {'a': 1} into") def test_type_coercion_fail5(): - with pytest.raises(TypeError, match="as 1 is not iterable"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int], coercible=[(ty.Any, ty.Any)])(1) + assert exc_info_matches(exc_info, "as 1 is not iterable") def test_type_coercion_fail6(): - with pytest.raises(TypeError, match="is not a mapping type"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[ty.Dict[str, str]], coercible=[(ty.Any, ty.Any)])((1, 2, 3)) + assert exc_info_matches(exc_info, "is not a mapping type") def test_type_coercion_realistic(): @@ -446,21 +486,29 @@ def f(x: ty.List[File], y: ty.Dict[str, ty.List[File]]): TypeParser(ty.List[str])(task.lzout.a) # pylint: disable=no-member with pytest.raises( TypeError, - match="Cannot coerce into ", - ): + ) as exc_info: TypeParser(ty.List[int])(task.lzout.a) # pylint: disable=no-member + assert exc_info_matches( + exc_info, + match=r"Cannot coerce into ", + regex=True, + ) - with pytest.raises( - TypeError, match="Cannot coerce 'bad-value' into " - ): + with pytest.raises(TypeError) as exc_info: task.inputs.x = "bad-value" + assert exc_info_matches( + exc_info, match="Cannot coerce 'bad-value' into " + ) def test_check_missing_type_args(): - with pytest.raises(TypeError, match="wasn't declared with type args required"): + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(list) - with pytest.raises(TypeError, match="doesn't match pattern"): + assert exc_info_matches(exc_info, "wasn't declared with type args required") + + with pytest.raises(TypeError) as exc_info: TypeParser(ty.List[int]).check_type(dict) + assert exc_info_matches(exc_info, "doesn't match pattern") def test_matches_type_union(): @@ -545,6 +593,18 @@ def test_contains_type_in_dict(): ) +def test_any_union(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + TypeParser(File, match_any_of_union=True).check_type(ty.Union[ty.List[File], Json]) + + +def test_union_superclass_check_type(): + """Check that the superclass auto-cast matches if any of the union args match instead + of all""" + TypeParser(ty.Union[ty.List[File], Json], superclass_auto_cast=True)(lz(File)) + + def test_type_matches(): assert TypeParser.matches([1, 2, 3], ty.List[int]) assert TypeParser.matches((1, 2, 3), ty.Tuple[int, ...]) @@ -648,7 +708,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( # Generic task other_specific_task( @@ -656,6 +716,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="inner", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( # Generic task other_specific_task( @@ -664,7 +725,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): ) ) - with pytest.raises(TypeError, match="Cannot coerce"): + with pytest.raises(TypeError) as exc_info: # No cast of generic task output to MyFormatX wf.add( specific_task( @@ -672,6 +733,7 @@ def test_typing_cast(tmp_path, specific_task, other_specific_task): name="exit", ) ) + assert exc_info_matches(exc_info, "Cannot coerce") wf.add( specific_task( diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index ee8e733e44..4b5f2c4d87 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -64,6 +64,8 @@ class TypeParser(ty.Generic[T]): label : str the label to be used to identify the type parser in error messages. Especially useful when TypeParser is used as a converter in attrs.fields + match_any_of_union : bool + match if any of the options in the union are a subclass (but not necessarily all) """ tp: ty.Type[T] @@ -71,6 +73,7 @@ class TypeParser(ty.Generic[T]): not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]] superclass_auto_cast: bool label: str + match_any_of_union: bool COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = ( ( @@ -115,6 +118,7 @@ def __init__( ] = NOT_COERCIBLE_DEFAULT, superclass_auto_cast: bool = False, label: str = "", + match_any_of_union: bool = False, ): def expand_pattern(t): """Recursively expand the type arguments of the target type in nested tuples""" @@ -143,6 +147,7 @@ def expand_pattern(t): self.not_coercible = list(not_coercible) if not_coercible is not None else [] self.pattern = expand_pattern(tp) self.superclass_auto_cast = superclass_auto_cast + self.match_any_of_union = match_any_of_union def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: """Attempts to coerce the object to the specified type, unless the value is @@ -177,9 +182,15 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: # Check whether the type of the lazy field isn't a superclass of # the type to check against, and if so, allow it due to permissive # typing rules. - TypeParser(obj.type).check_type(self.tp) + TypeParser(obj.type, match_any_of_union=True).check_type( + self.tp + ) except TypeError: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass or superclass of {self.tp} (and will not " + "be able to be coerced to one that is)" + ) from e else: logger.info( "Connecting lazy field %s to %s%s via permissive typing that " @@ -189,12 +200,22 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]: self.label_str, ) else: - raise e + raise TypeError( + f"Incorrect type for lazy field{self.label_str}: {obj.type!r} " + f"is not a subclass of {self.tp} (and will not be able to be " + "coerced to one that is)" + ) from e coerced = obj # type: ignore elif isinstance(obj, StateArray): coerced = StateArray(self(o) for o in obj) # type: ignore[assignment] else: - coerced = self.coerce(obj) + try: + coerced = self.coerce(obj) + except TypeError as e: + raise TypeError( + f"Incorrect type for field{self.label_str}: {obj!r} is not of type " + f"{self.tp} (and cannot be coerced to it)" + ) from e return coerced def coerce(self, object_: ty.Any) -> T: @@ -398,12 +419,31 @@ def check_basic(tp, target): # Note that we are deliberately more permissive than typical type-checking # here, allowing parents of the target type as well as children, # to avoid users having to cast from loosely typed tasks to strict ones + if self.match_any_of_union and get_origin(tp) is ty.Union: + reasons = [] + tp_args = get_args(tp) + for tp_arg in tp_args: + if self.is_subclass(tp_arg, target): + return + try: + self.check_coercible(tp_arg, target) + except TypeError as e: + reasons.append(e) + else: + return + if reasons: + raise TypeError( + f"Cannot coerce any union args {tp_arg} to {target}" + f"{self.label_str}:\n\n" + + "\n\n".join(f"{a} -> {e}" for a, e in zip(tp_args, reasons)) + ) if not self.is_subclass(tp, target): self.check_coercible(tp, target) def check_union(tp, pattern_args): if get_origin(tp) is ty.Union: - for tp_arg in get_args(tp): + tp_args = get_args(tp) + for tp_arg in tp_args: reasons = [] for pattern_arg in pattern_args: try: @@ -413,11 +453,15 @@ def check_union(tp, pattern_args): else: reasons = None break + if self.match_any_of_union and len(reasons) < len(tp_args): + # Just need one of the union args to match + return if reasons: + determiner = "any" if self.match_any_of_union else "all" raise TypeError( - f"Cannot coerce {tp} to " - f"ty.Union[{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " - f"because {tp_arg} cannot be coerced to any of its args:\n\n" + f"Cannot coerce {tp} to ty.Union[" + f"{', '.join(str(a) for a in pattern_args)}]{self.label_str}, " + f"because {tp_arg} cannot be coerced to {determiner} of its args:\n\n" + "\n\n".join( f"{a} -> {e}" for a, e in zip(pattern_args, reasons) )