diff --git a/test/NnxTestClasses.py b/test/NnxTestClasses.py index 8d4eed1..7e0e3a0 100644 --- a/test/NnxTestClasses.py +++ b/test/NnxTestClasses.py @@ -213,7 +213,11 @@ def _calculate_global_shift( """Calculate global shift so that the output values are in the range of out_type""" s = tensor.type(torch.float64).std() target_s = 2 ** (out_type._bits - 1) - return torch.ceil(torch.log2(s / target_s)).type(torch.int32) + shift = torch.ceil(torch.log2(s / target_s)).type(torch.int32) + if shift < 1: + return torch.zeros((1,)).type(torch.int32) + else: + return shift @staticmethod def _random_data(_type: IntegerType, shape: Tuple, extremes: Tuple = None):