From f503aabdbcceb481b1ef054c1c0114596728847f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 10 Jul 2019 01:26:11 -0400 Subject: [PATCH] Improve snoop verbose mode (#18) --- tests/test_snoop.py | 55 +++++++++++++++------------------------- torchsnooper/__init__.py | 35 ++++++++++++++----------- 2 files changed, 40 insertions(+), 50 deletions(-) diff --git a/tests/test_snoop.py b/tests/test_snoop.py index 431a0df..b82e748 100644 --- a/tests/test_snoop.py +++ b/tests/test_snoop.py @@ -23,41 +23,26 @@ def func(): verbose_expect = ''' -21:43:42.09 >>> Call to func in File "test_snoop.py", line 16 -21:43:42.09 16 | def func(): -21:43:42.09 17 | x = torch.tensor(math.inf) -21:43:42.10 .......... x = tensor(inf) -21:43:42.10 .......... x.shape = () -21:43:42.10 .......... x.dtype = torch.float32 -21:43:42.10 .......... x.device = device(type='cpu') -21:43:42.10 .......... x.requires_grad = False -21:43:42.10 .......... x.has_nan = False -21:43:42.10 .......... x.has_inf = True -21:43:42.10 18 | x = torch.tensor(math.nan) -21:43:42.10 .......... x = tensor(nan) -21:43:42.10 .......... x.has_nan = True -21:43:42.10 .......... x.has_inf = False -21:43:42.10 19 | x = torch.tensor(1.0, requires_grad=True) -21:43:42.10 .......... x = tensor(1., requires_grad=True) -21:43:42.10 .......... x.requires_grad = True -21:43:42.10 .......... x.has_nan = False -21:43:42.10 20 | x = torch.tensor([1.0, math.nan, math.inf]) -21:43:42.10 .......... x = tensor([1., nan, inf]) -21:43:42.10 .......... x.shape = (3,) -21:43:42.10 .......... x.requires_grad = False -21:43:42.10 .......... x.has_nan = True -21:43:42.10 .......... x.has_inf = True -21:43:42.10 21 | x = numpy.zeros((2, 2)) -21:43:42.10 .......... x = array([[0., 0.], -21:43:42.10 [0., 0.]]) -21:43:42.10 .......... x.shape = (2, 2) -21:43:42.10 .......... x.dtype = dtype('float64') -21:43:42.10 22 | x = (x, x) -21:43:42.10 .......... x = (array([[0., 0.], -21:43:42.10 [0., 0.]]), array([[0., 0.], -21:43:42.10 [0., 0.]])) -21:43:42.10 .......... len(x) = 2 -21:43:42.10 <<< Return value from func: None +01:24:31.56 >>> Call to func in File "test_snoop.py", line 16 +01:24:31.56 16 | def func(): +01:24:31.56 17 | x = torch.tensor(math.inf) +01:24:31.56 .......... x = tensor<(), float32, cpu, has_inf> +01:24:31.56 .......... x.data = tensor(inf) +01:24:31.56 18 | x = torch.tensor(math.nan) +01:24:31.56 .......... x = tensor<(), float32, cpu, has_nan> +01:24:31.56 .......... x.data = tensor(nan) +01:24:31.56 19 | x = torch.tensor(1.0, requires_grad=True) +01:24:31.56 .......... x = tensor<(), float32, cpu, grad> +01:24:31.56 .......... x.data = tensor(1.) +01:24:31.56 20 | x = torch.tensor([1.0, math.nan, math.inf]) +01:24:31.56 .......... x = tensor<(3,), float32, cpu, has_nan, has_inf> +01:24:31.56 .......... x.data = tensor([1., nan, inf]) +01:24:31.56 21 | x = numpy.zeros((2, 2)) +01:24:31.56 .......... x = ndarray<(2, 2), float64> +01:24:31.56 .......... x.data = +01:24:31.56 22 | x = (x, x) +01:24:31.56 .......... x = (ndarray<(2, 2), float64>, ndarray<(2, 2), float64>) +01:24:31.56 <<< Return value from func: None '''.strip() terse_expect = ''' diff --git a/torchsnooper/__init__.py b/torchsnooper/__init__.py index 9e08b2c..13acbc1 100644 --- a/torchsnooper/__init__.py +++ b/torchsnooper/__init__.py @@ -155,21 +155,26 @@ def compute_repr(self, x): def register_snoop(verbose=False, tensor_format=default_format, numpy_format=default_numpy_format): import snoop + import cheap_repr + import snoop.configuration + cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x)) + cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x)) + cheap_repr.cheap_repr(torch.zeros(6)) + unwanted = { + snoop.configuration.len_shape_watch, + snoop.configuration.dtype_watch, + } + snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted) if verbose: + + class TensorWrap: + + def __init__(self, tensor): + self.tensor = tensor + + def __repr__(self): + return self.tensor.__repr__() + snoop.config.watch_extras += ( - lambda source, value: ('{}.device'.format(source), value.device), - lambda source, value: ('{}.requires_grad'.format(source), value.requires_grad), - lambda source, value: ('{}.has_nan'.format(source), bool(torch.isnan(value).any())), - lambda source, value: ('{}.has_inf'.format(source), bool(torch.isinf(value).any())), + lambda source, value: ('{}.data'.format(source), TensorWrap(value.data)), ) - else: - import cheap_repr - import snoop.configuration - cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x)) - cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x)) - cheap_repr.cheap_repr(torch.zeros(6)) - unwanted = { - snoop.configuration.len_shape_watch, - snoop.configuration.dtype_watch, - } - snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)