diff --git a/setup.py b/setup.py index cd5425a..60d9d11 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ setup_requires=['setuptools_scm'], install_requires=[ 'pysnooper>=0.1.0', + 'numpy', ], tests_require=[ 'pytest', diff --git a/tests/test_torchsnooper.py b/tests/test_torchsnooper.py index 73b1bfe..bdcf532 100644 --- a/tests/test_torchsnooper.py +++ b/tests/test_torchsnooper.py @@ -1,5 +1,6 @@ import io import torch +import numpy import torchsnooper from .utils import assert_output, VariableEntry, CallEntry, LineEntry, ReturnEntry, ReturnValueEntry @@ -258,3 +259,26 @@ def my_function(): ReturnValueEntry("tensor<(3, 3), float32, cpu>"), ) ) + + +def test_numpy_ndarray(): + string_io = io.StringIO() + + @torchsnooper.snoop(string_io) + def my_function(x): + return x + + a = numpy.random.randn(5, 6, 7) + my_function([a, a]) + + output = string_io.getvalue() + print(output) + assert_output( + output, + ( + CallEntry(), + LineEntry(), + ReturnEntry(), + ReturnValueEntry("[ndarray<(5, 6, 7), float64>, ndarray<(5, 6, 7), float64>]"), + ) + ) diff --git a/torchsnooper/__init__.py b/torchsnooper/__init__.py index d5b8a7c..adde1b1 100644 --- a/torchsnooper/__init__.py +++ b/torchsnooper/__init__.py @@ -2,6 +2,7 @@ import pysnooper import pysnooper.utils import warnings +import numpy from pkg_resources import get_distribution, DistributionNotFound @@ -58,14 +59,24 @@ def __call__(self, tensor): default_format = TensorFormat() +class NumpyFormat: + + def __call__(self, x): + return f'ndarray<{x.shape}, {x.dtype.name}>' + + +default_numpy_format = NumpyFormat() + + class TorchSnooper(pysnooper.tracer.Tracer): - def __init__(self, *args, tensor_format=default_format, **kwargs): + def __init__(self, *args, tensor_format=default_format, numpy_format=default_numpy_format, **kwargs): self.orig_custom_repr = kwargs['custom_repr'] if 'custom_repr' in kwargs else () custom_repr = (lambda x: True, self.compute_repr) kwargs['custom_repr'] = (custom_repr,) super(TorchSnooper, self).__init__(*args, **kwargs) self.tensor_format = tensor_format + self.numpy_format = numpy_format @staticmethod def is_return_types(x): @@ -96,6 +107,8 @@ def compute_repr(self, x): orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr) if torch.is_tensor(x): return self.tensor_format(x) + elif isinstance(x, numpy.ndarray): + return self.numpy_format(x) elif self.is_return_types(x): return self.return_types_repr(x) elif orig_repr_func is not repr: