diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index f61ff3738f..a3a1acf49b 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -552,6 +552,43 @@ def test_quantize(self):
         with torch.no_grad():
             m(x)
 
+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+
 
 class TestScaledMM:
     @unittest.skipIf(
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 29319f3814..643666f5e7 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -99,7 +99,10 @@ def tensor_to_amax(
     axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
     if scaling_granularity is ScalingGranularity.TENSORWISE:
-        amax = torch.max(torch.abs(x))
+        if x.numel() > 0:
+            amax = torch.max(torch.abs(x))
+        else:
+            amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
     else:
         assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
         assert axiswise_dim is not None, "unsupported"
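
For context (not part of the patch): a full reduction over a zero-element tensor raises in PyTorch, which is the failure mode the new x.numel() > 0 guard avoids. A minimal standalone sketch, assuming EPS is the small positive constant defined in torchao/float8/float8_utils.py (the 1e-12 value below is a stand-in):

import torch

EPS = 1e-12  # assumption: stand-in for the EPS constant in float8_utils.py

# an input with M == 0, matching the shape exercised by test_zero_dim above
x = torch.randn(0, 64)

try:
    torch.max(torch.abs(x))  # pre-patch behavior: raises on an empty tensor
except RuntimeError as e:
    print(f"unguarded amax fails: {e}")

# post-patch behavior: fall back to EPS when the tensor is empty, so the
# downstream scale computation still receives a finite, positive amax
amax = (
    torch.max(torch.abs(x))
    if x.numel() > 0
    else torch.tensor(EPS, device=x.device, dtype=x.dtype)
)
print(amax)  # tensor(1.0000e-12)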