Skip to content

Commit

Permalink
Add support for representing numpy.ndarray (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 1, 2019
1 parent 208f3b6 commit f322faa
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
setup_requires=['setuptools_scm'],
install_requires=[
'pysnooper>=0.1.0',
'numpy',
],
tests_require=[
'pytest',
Expand Down
24 changes: 24 additions & 0 deletions tests/test_torchsnooper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import torch
import numpy
import torchsnooper
from .utils import assert_output, VariableEntry, CallEntry, LineEntry, ReturnEntry, ReturnValueEntry

Expand Down Expand Up @@ -258,3 +259,26 @@ def my_function():
ReturnValueEntry("tensor<(3, 3), float32, cpu>"),
)
)


def test_numpy_ndarray():
string_io = io.StringIO()

@torchsnooper.snoop(string_io)
def my_function(x):
return x

a = numpy.random.randn(5, 6, 7)
my_function([a, a])

output = string_io.getvalue()
print(output)
assert_output(
output,
(
CallEntry(),
LineEntry(),
ReturnEntry(),
ReturnValueEntry("[ndarray<(5, 6, 7), float64>, ndarray<(5, 6, 7), float64>]"),
)
)
15 changes: 14 additions & 1 deletion torchsnooper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pysnooper
import pysnooper.utils
import warnings
import numpy
from pkg_resources import get_distribution, DistributionNotFound


Expand Down Expand Up @@ -58,14 +59,24 @@ def __call__(self, tensor):
default_format = TensorFormat()


class NumpyFormat:

def __call__(self, x):
return f'ndarray<{x.shape}, {x.dtype.name}>'


default_numpy_format = NumpyFormat()


class TorchSnooper(pysnooper.tracer.Tracer):

def __init__(self, *args, tensor_format=default_format, **kwargs):
def __init__(self, *args, tensor_format=default_format, numpy_format=default_numpy_format, **kwargs):
self.orig_custom_repr = kwargs['custom_repr'] if 'custom_repr' in kwargs else ()
custom_repr = (lambda x: True, self.compute_repr)
kwargs['custom_repr'] = (custom_repr,)
super(TorchSnooper, self).__init__(*args, **kwargs)
self.tensor_format = tensor_format
self.numpy_format = numpy_format

@staticmethod
def is_return_types(x):
Expand Down Expand Up @@ -96,6 +107,8 @@ def compute_repr(self, x):
orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
if torch.is_tensor(x):
return self.tensor_format(x)
elif isinstance(x, numpy.ndarray):
return self.numpy_format(x)
elif self.is_return_types(x):
return self.return_types_repr(x)
elif orig_repr_func is not repr:
Expand Down

0 comments on commit f322faa

Please sign in to comment.