Skip to content

Commit

Permalink
Improve snoop verbose mode (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 10, 2019
1 parent 2ec0ec8 commit f503aab
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 50 deletions.
55 changes: 20 additions & 35 deletions tests/test_snoop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,26 @@ def func():


verbose_expect = '''
21:43:42.09 >>> Call to func in File "test_snoop.py", line 16
21:43:42.09 16 | def func():
21:43:42.09 17 | x = torch.tensor(math.inf)
21:43:42.10 .......... x = tensor(inf)
21:43:42.10 .......... x.shape = ()
21:43:42.10 .......... x.dtype = torch.float32
21:43:42.10 .......... x.device = device(type='cpu')
21:43:42.10 .......... x.requires_grad = False
21:43:42.10 .......... x.has_nan = False
21:43:42.10 .......... x.has_inf = True
21:43:42.10 18 | x = torch.tensor(math.nan)
21:43:42.10 .......... x = tensor(nan)
21:43:42.10 .......... x.has_nan = True
21:43:42.10 .......... x.has_inf = False
21:43:42.10 19 | x = torch.tensor(1.0, requires_grad=True)
21:43:42.10 .......... x = tensor(1., requires_grad=True)
21:43:42.10 .......... x.requires_grad = True
21:43:42.10 .......... x.has_nan = False
21:43:42.10 20 | x = torch.tensor([1.0, math.nan, math.inf])
21:43:42.10 .......... x = tensor([1., nan, inf])
21:43:42.10 .......... x.shape = (3,)
21:43:42.10 .......... x.requires_grad = False
21:43:42.10 .......... x.has_nan = True
21:43:42.10 .......... x.has_inf = True
21:43:42.10 21 | x = numpy.zeros((2, 2))
21:43:42.10 .......... x = array([[0., 0.],
21:43:42.10 [0., 0.]])
21:43:42.10 .......... x.shape = (2, 2)
21:43:42.10 .......... x.dtype = dtype('float64')
21:43:42.10 22 | x = (x, x)
21:43:42.10 .......... x = (array([[0., 0.],
21:43:42.10 [0., 0.]]), array([[0., 0.],
21:43:42.10 [0., 0.]]))
21:43:42.10 .......... len(x) = 2
21:43:42.10 <<< Return value from func: None
01:24:31.56 >>> Call to func in File "test_snoop.py", line 16
01:24:31.56 16 | def func():
01:24:31.56 17 | x = torch.tensor(math.inf)
01:24:31.56 .......... x = tensor<(), float32, cpu, has_inf>
01:24:31.56 .......... x.data = tensor(inf)
01:24:31.56 18 | x = torch.tensor(math.nan)
01:24:31.56 .......... x = tensor<(), float32, cpu, has_nan>
01:24:31.56 .......... x.data = tensor(nan)
01:24:31.56 19 | x = torch.tensor(1.0, requires_grad=True)
01:24:31.56 .......... x = tensor<(), float32, cpu, grad>
01:24:31.56 .......... x.data = tensor(1.)
01:24:31.56 20 | x = torch.tensor([1.0, math.nan, math.inf])
01:24:31.56 .......... x = tensor<(3,), float32, cpu, has_nan, has_inf>
01:24:31.56 .......... x.data = tensor([1., nan, inf])
01:24:31.56 21 | x = numpy.zeros((2, 2))
01:24:31.56 .......... x = ndarray<(2, 2), float64>
01:24:31.56 .......... x.data = <memory at 0x7efc261e9480>
01:24:31.56 22 | x = (x, x)
01:24:31.56 .......... x = (ndarray<(2, 2), float64>, ndarray<(2, 2), float64>)
01:24:31.56 <<< Return value from func: None
'''.strip()

terse_expect = '''
Expand Down
35 changes: 20 additions & 15 deletions torchsnooper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,26 @@ def compute_repr(self, x):

def register_snoop(verbose=False, tensor_format=default_format, numpy_format=default_numpy_format):
import snoop
import cheap_repr
import snoop.configuration
cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x))
cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x))
cheap_repr.cheap_repr(torch.zeros(6))
unwanted = {
snoop.configuration.len_shape_watch,
snoop.configuration.dtype_watch,
}
snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)
if verbose:

class TensorWrap:

def __init__(self, tensor):
self.tensor = tensor

def __repr__(self):
return self.tensor.__repr__()

snoop.config.watch_extras += (
lambda source, value: ('{}.device'.format(source), value.device),
lambda source, value: ('{}.requires_grad'.format(source), value.requires_grad),
lambda source, value: ('{}.has_nan'.format(source), bool(torch.isnan(value).any())),
lambda source, value: ('{}.has_inf'.format(source), bool(torch.isinf(value).any())),
lambda source, value: ('{}.data'.format(source), TensorWrap(value.data)),
)
else:
import cheap_repr
import snoop.configuration
cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x))
cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x))
cheap_repr.cheap_repr(torch.zeros(6))
unwanted = {
snoop.configuration.len_shape_watch,
snoop.configuration.dtype_watch,
}
snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)

0 comments on commit f503aab

Please sign in to comment.