Skip to content

Commit

Permalink
Make recursive apply on native Python data structures public (#187)
Browse files Browse the repository at this point in the history
Properly document the function's API to make it public
  • Loading branch information
nathanpainchaud authored Nov 16, 2023
1 parent f1d641d commit 1fc1b25
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
22 changes: 22 additions & 0 deletions vital/utils/format/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@
from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union


def apply(obj, func):
"""Applies a function recursively to all elements inside a Python collection composed of the supported types.
References:
- This function was inspired by a similar function from the 'poutyne' framework:
https://github.com/GRAAL-Research/poutyne/blob/aeb78c2b26edaa30663a88522d39a187baeec9cd/poutyne/utils.py#L104-L113
Args:
obj: The Python object to convert.
func: The function to apply.
Returns:
A new Python collection with the same structure as `obj` but where the elements have been applied the function
`func`. Not supported types are left as reference in the new object.
"""
if isinstance(obj, (list, tuple)):
return type(obj)(apply(el, func) for el in obj)
if isinstance(obj, dict):
return {k: apply(el, func) for k, el in obj.items()}
return func(obj)


def prefix(map: Mapping[str, Any], prefix: str, exclude: Union[str, Sequence[str]] = None) -> Dict[str, Any]:
"""Prepends a prefix to the keys of a mapping with string keys.
Expand Down
17 changes: 3 additions & 14 deletions vital/utils/format/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,8 @@

import numpy as np
import torch
from torch.nn.utils.rnn import PackedSequence


def _apply(obj, func):
if isinstance(obj, (list, tuple)):
if isinstance(obj, PackedSequence):
return type(obj)(
*(_apply(getattr(obj, el), func) if el != "batch_sizes" else getattr(obj, el) for el in obj._fields)
)
return type(obj)(_apply(el, func) for el in obj)
if isinstance(obj, dict):
return {k: _apply(el, func) for k, el in obj.items()}
return func(obj)
from vital.utils.format.native import apply


def torch_apply(obj: Union[Tuple, List, Dict], func: Callable) -> Union[Tuple, List, Dict]:
Expand All @@ -36,7 +25,7 @@ def torch_apply(obj: Union[Tuple, List, Dict], func: Callable) -> Union[Tuple, L
def fn(t):
return func(t) if torch.is_tensor(t) else t

return _apply(obj, fn)
return apply(obj, fn)


def torch_to_numpy(obj: Union[Tuple, List, Dict], copy: bool = False) -> Union[Tuple, List, Dict]:
Expand Down Expand Up @@ -114,4 +103,4 @@ def numpy_to_torch(obj: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
def fn(a):
return torch.from_numpy(a) if isinstance(a, np.ndarray) else a

return _apply(obj, fn)
return apply(obj, fn)

0 comments on commit 1fc1b25

Please sign in to comment.