Skip to content

Commit

Permalink
Recursively repr containers (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jun 28, 2019
1 parent a5a8d90 commit 208f3b6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 57 deletions.
22 changes: 22 additions & 0 deletions tests/test_torchsnooper.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,28 @@ def my_function():
)


def test_recursive_containers():
string_io = io.StringIO()

@torchsnooper.snoop(string_io)
def my_function():
return [{'key': torch.zeros(5, 6, 7)}]

my_function()

output = string_io.getvalue()
print(output)
assert_output(
output,
(
CallEntry(),
LineEntry(),
ReturnEntry(),
ReturnValueEntry("[{'key': tensor<(5, 6, 7), float32, cpu>}]"),
)
)


def test_return_types():
string_io = io.StringIO()

Expand Down
83 changes: 26 additions & 57 deletions torchsnooper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import pysnooper
import pysnooper.utils
import warnings
from pkg_resources import get_distribution, DistributionNotFound

Expand Down Expand Up @@ -60,11 +61,9 @@ def __call__(self, tensor):
class TorchSnooper(pysnooper.tracer.Tracer):

def __init__(self, *args, tensor_format=default_format, **kwargs):
custom_repr = (self.condition, self.compute_repr)
if 'custom_repr' in kwargs:
kwargs['custom_repr'] = (custom_repr, *kwargs['custom_repr'])
else:
kwargs['custom_repr'] = (custom_repr,)
self.orig_custom_repr = kwargs['custom_repr'] if 'custom_repr' in kwargs else ()
custom_repr = (lambda x: True, self.compute_repr)
kwargs['custom_repr'] = (custom_repr,)
super(TorchSnooper, self).__init__(*args, **kwargs)
self.tensor_format = tensor_format

Expand Down Expand Up @@ -93,63 +92,33 @@ def return_types_repr(self, x):
return 'gels(solution=' + self.tensor_format(x.solution) + ', QR=' + self.tensor_format(x.QR) + ')'
warnings.warn('Unknown return_types encountered, open a bug report!')

def is_list_of_tensors(self, x):
if not isinstance(x, list):
return False
return all([torch.is_tensor(i) for i in x])

def list_of_tensors_repr(self, x):
content = ''
for i in x:
if content != '':
content += ', '
content += self.tensor_format(i)
return '[' + content + ']'

def is_tuple_of_tensors(self, x):
if not isinstance(x, tuple):
return False
return all([torch.is_tensor(i) for i in x])

def tuple_of_tensors_repr(self, x):
content = ''
for i in x:
if content != '':
content += ', '
content += self.tensor_format(i)
if len(x) == 1:
content += ','
return '(' + content + ')'

def is_dict_of_tensors(self, x):
if not isinstance(x, dict):
return False
return all([torch.is_tensor(i) for i in x.values()])

def dict_of_tensors_repr(self, x):
content = ''
for k, v in x.items():
if content != '':
content += ', '
content += repr(k) + ': ' + self.tensor_format(v)
return '{' + content + '}'

def condition(self, x):
return torch.is_tensor(x) or self.is_return_types(x) or \
self.is_list_of_tensors(x) or self.is_tuple_of_tensors(x) or self.is_dict_of_tensors(x)

def compute_repr(self, x):
orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
if torch.is_tensor(x):
return self.tensor_format(x)
elif self.is_return_types(x):
return self.return_types_repr(x)
elif self.is_list_of_tensors(x):
return self.list_of_tensors_repr(x)
elif self.is_tuple_of_tensors(x):
return self.tuple_of_tensors_repr(x)
elif self.is_dict_of_tensors(x):
return self.dict_of_tensors_repr(x)
raise RuntimeError('Control flow should not reach here, open a bug report!')
elif orig_repr_func is not repr:
return orig_repr_func(x)
elif isinstance(x, (list, tuple)):
content = ''
for i in x:
if content != '':
content += ', '
content += self.compute_repr(i)
if isinstance(x, tuple) and len(x) == 1:
content += ','
if isinstance(x, tuple):
return '(' + content + ')'
return '[' + content + ']'
elif isinstance(x, dict):
content = ''
for k, v in x.items():
if content != '':
content += ', '
content += self.compute_repr(k) + ': ' + self.compute_repr(v)
return '{' + content + '}'
return repr(x)


snoop = TorchSnooper

0 comments on commit 208f3b6

Please sign in to comment.