diff --git a/tests/test_torchsnooper.py b/tests/test_torchsnooper.py index c3efb0a..0c5370d 100644 --- a/tests/test_torchsnooper.py +++ b/tests/test_torchsnooper.py @@ -314,3 +314,25 @@ def my_function(): ReturnEntry(), ) ) + + +def test_bool_tensor(): + string_io = io.StringIO() + + @torchsnooper.snoop(string_io) + def my_function(): + x = torch.zeros(5, 5, dtype=torch.bool) # noqa: F841 + + my_function() + + output = string_io.getvalue() + print(output) + assert_output( + output, + ( + CallEntry(), + LineEntry(), + VariableEntry('x', "tensor<(5, 5), bool, cpu>"), + ReturnEntry(), + ) + ) diff --git a/torchsnooper/__init__.py b/torchsnooper/__init__.py index 13acbc1..6f4383b 100644 --- a/torchsnooper/__init__.py +++ b/torchsnooper/__init__.py @@ -6,6 +6,12 @@ from pkg_resources import get_distribution, DistributionNotFound +FLOATING_POINTS = set() +for i in ['float', 'double', 'half', 'complex128', 'complex32', 'complex64']: + if hasattr(torch, i): # older version of PyTorch do not have complex dtypes + FLOATING_POINTS.add(getattr(torch, i)) + + try: __version__ = get_distribution(__name__).version except DistributionNotFound: @@ -47,7 +53,7 @@ def __call__(self, tensor): if tensor.requires_grad: new += 'grad' elif p == 'has_nan': - result = bool(torch.isnan(tensor).any()) + result = tensor.dtype in FLOATING_POINTS and bool(torch.isnan(tensor).any()) if self.properties_name: new += 'has_nan=' new += str(result) @@ -55,7 +61,7 @@ def __call__(self, tensor): if result: new += 'has_nan' elif p == 'has_inf': - result = bool(torch.isinf(tensor).any()) + result = tensor.dtype in FLOATING_POINTS and bool(torch.isinf(tensor).any()) if self.properties_name: new += 'has_inf=' new += str(result)