Skip to content

Commit

Permalink
Support bool dtype (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Aug 8, 2019
1 parent f503aab commit 95e179a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
22 changes: 22 additions & 0 deletions tests/test_torchsnooper.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,25 @@ def my_function():
ReturnEntry(),
)
)


def test_bool_tensor():
string_io = io.StringIO()

@torchsnooper.snoop(string_io)
def my_function():
x = torch.zeros(5, 5, dtype=torch.bool) # noqa: F841

my_function()

output = string_io.getvalue()
print(output)
assert_output(
output,
(
CallEntry(),
LineEntry(),
VariableEntry('x', "tensor<(), bool, cpu>"),
ReturnEntry(),
)
)
10 changes: 8 additions & 2 deletions torchsnooper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from pkg_resources import get_distribution, DistributionNotFound


FLOATING_POINTS = set()
for i in ['float', 'double', 'half', 'complex128', 'complex32', 'complex64']:
if hasattr(torch, i): # older version of PyTorch do not have complex dtypes
FLOATING_POINTS.add(getattr(torch, i))


try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
Expand Down Expand Up @@ -47,15 +53,15 @@ def __call__(self, tensor):
if tensor.requires_grad:
new += 'grad'
elif p == 'has_nan':
result = bool(torch.isnan(tensor).any())
result = tensor.dtype in FLOATING_POINTS and bool(torch.isnan(tensor).any())
if self.properties_name:
new += 'has_nan='
new += str(result)
else:
if result:
new += 'has_nan'
elif p == 'has_inf':
result = bool(torch.isinf(tensor).any())
result = tensor.dtype in FLOATING_POINTS and bool(torch.isinf(tensor).any())
if self.properties_name:
new += 'has_inf='
new += str(result)
Expand Down

0 comments on commit 95e179a

Please sign in to comment.