diff --git a/tests/test_torchsnooper.py b/tests/test_torchsnooper.py index 4a69c1b..73b1bfe 100644 --- a/tests/test_torchsnooper.py +++ b/tests/test_torchsnooper.py @@ -165,6 +165,28 @@ def my_function(): ) +def test_recursive_containers(): + string_io = io.StringIO() + + @torchsnooper.snoop(string_io) + def my_function(): + return [{'key': torch.zeros(5, 6, 7)}] + + my_function() + + output = string_io.getvalue() + print(output) + assert_output( + output, + ( + CallEntry(), + LineEntry(), + ReturnEntry(), + ReturnValueEntry("[{'key': tensor<(5, 6, 7), float32, cpu>}]"), + ) + ) + + def test_return_types(): string_io = io.StringIO() diff --git a/torchsnooper/__init__.py b/torchsnooper/__init__.py index 346f61b..d5b8a7c 100644 --- a/torchsnooper/__init__.py +++ b/torchsnooper/__init__.py @@ -1,5 +1,6 @@ import torch import pysnooper +import pysnooper.utils import warnings from pkg_resources import get_distribution, DistributionNotFound @@ -60,11 +61,9 @@ def __call__(self, tensor): class TorchSnooper(pysnooper.tracer.Tracer): def __init__(self, *args, tensor_format=default_format, **kwargs): - custom_repr = (self.condition, self.compute_repr) - if 'custom_repr' in kwargs: - kwargs['custom_repr'] = (custom_repr, *kwargs['custom_repr']) - else: - kwargs['custom_repr'] = (custom_repr,) + self.orig_custom_repr = kwargs['custom_repr'] if 'custom_repr' in kwargs else () + custom_repr = (lambda x: True, self.compute_repr) + kwargs['custom_repr'] = (custom_repr,) super(TorchSnooper, self).__init__(*args, **kwargs) self.tensor_format = tensor_format @@ -93,63 +92,33 @@ def return_types_repr(self, x): return 'gels(solution=' + self.tensor_format(x.solution) + ', QR=' + self.tensor_format(x.QR) + ')' warnings.warn('Unknown return_types encountered, open a bug report!') - def is_list_of_tensors(self, x): - if not isinstance(x, list): - return False - return all([torch.is_tensor(i) for i in x]) - - def list_of_tensors_repr(self, x): - content = '' - for i in x: - if content != '': - content += ', ' - content += self.tensor_format(i) - return '[' + content + ']' - - def is_tuple_of_tensors(self, x): - if not isinstance(x, tuple): - return False - return all([torch.is_tensor(i) for i in x]) - - def tuple_of_tensors_repr(self, x): - content = '' - for i in x: - if content != '': - content += ', ' - content += self.tensor_format(i) - if len(x) == 1: - content += ',' - return '(' + content + ')' - - def is_dict_of_tensors(self, x): - if not isinstance(x, dict): - return False - return all([torch.is_tensor(i) for i in x.values()]) - - def dict_of_tensors_repr(self, x): - content = '' - for k, v in x.items(): - if content != '': - content += ', ' - content += repr(k) + ': ' + self.tensor_format(v) - return '{' + content + '}' - - def condition(self, x): - return torch.is_tensor(x) or self.is_return_types(x) or \ - self.is_list_of_tensors(x) or self.is_tuple_of_tensors(x) or self.is_dict_of_tensors(x) - def compute_repr(self, x): + orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr) if torch.is_tensor(x): return self.tensor_format(x) elif self.is_return_types(x): return self.return_types_repr(x) - elif self.is_list_of_tensors(x): - return self.list_of_tensors_repr(x) - elif self.is_tuple_of_tensors(x): - return self.tuple_of_tensors_repr(x) - elif self.is_dict_of_tensors(x): - return self.dict_of_tensors_repr(x) - raise RuntimeError('Control flow should not reach here, open a bug report!') + elif orig_repr_func is not repr: + return orig_repr_func(x) + elif isinstance(x, (list, tuple)): + content = '' + for i in x: + if content != '': + content += ', ' + content += self.compute_repr(i) + if isinstance(x, tuple) and len(x) == 1: + content += ',' + if isinstance(x, tuple): + return '(' + content + ')' + return '[' + content + ']' + elif isinstance(x, dict): + content = '' + for k, v in x.items(): + if content != '': + content += ', ' + content += self.compute_repr(k) + ': ' + self.compute_repr(v) + return '{' + content + '}' + return repr(x) snoop = TorchSnooper