diff --git a/vital/utils/format/native.py b/vital/utils/format/native.py index 76e53691..36d65cae 100644 --- a/vital/utils/format/native.py +++ b/vital/utils/format/native.py @@ -2,6 +2,28 @@ from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union +def apply(obj, func): + """Applies a function recursively to all elements inside a Python collection composed of the supported types. + + References: + - This function was inspired by a similar function from the 'poutyne' framework: + https://github.com/GRAAL-Research/poutyne/blob/aeb78c2b26edaa30663a88522d39a187baeec9cd/poutyne/utils.py#L104-L113 + + Args: + obj: The Python object to convert. + func: The function to apply. + + Returns: + A new Python collection with the same structure as `obj` but where the elements have been applied the function + `func`. Not supported types are left as reference in the new object. + """ + if isinstance(obj, (list, tuple)): + return type(obj)(apply(el, func) for el in obj) + if isinstance(obj, dict): + return {k: apply(el, func) for k, el in obj.items()} + return func(obj) + + def prefix(map: Mapping[str, Any], prefix: str, exclude: Union[str, Sequence[str]] = None) -> Dict[str, Any]: """Prepends a prefix to the keys of a mapping with string keys. diff --git a/vital/utils/format/torch.py b/vital/utils/format/torch.py index 5403bb0a..ad4cebd8 100644 --- a/vital/utils/format/torch.py +++ b/vital/utils/format/torch.py @@ -2,19 +2,8 @@ import numpy as np import torch -from torch.nn.utils.rnn import PackedSequence - -def _apply(obj, func): - if isinstance(obj, (list, tuple)): - if isinstance(obj, PackedSequence): - return type(obj)( - *(_apply(getattr(obj, el), func) if el != "batch_sizes" else getattr(obj, el) for el in obj._fields) - ) - return type(obj)(_apply(el, func) for el in obj) - if isinstance(obj, dict): - return {k: _apply(el, func) for k, el in obj.items()} - return func(obj) +from vital.utils.format.native import apply def torch_apply(obj: Union[Tuple, List, Dict], func: Callable) -> Union[Tuple, List, Dict]: @@ -36,7 +25,7 @@ def torch_apply(obj: Union[Tuple, List, Dict], func: Callable) -> Union[Tuple, L def fn(t): return func(t) if torch.is_tensor(t) else t - return _apply(obj, fn) + return apply(obj, fn) def torch_to_numpy(obj: Union[Tuple, List, Dict], copy: bool = False) -> Union[Tuple, List, Dict]: @@ -114,4 +103,4 @@ def numpy_to_torch(obj: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]: def fn(a): return torch.from_numpy(a) if isinstance(a, np.ndarray) else a - return _apply(obj, fn) + return apply(obj, fn)