[wip] zero dim support for float8 training
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Nov 26, 2024
1 parent b65e513 commit a130365
Showing 2 changed files with 41 additions and 1 deletion.
37 changes: 37 additions & 0 deletions test/float8/test_base.py
@@ -531,6 +531,43 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+
 
 class TestScaledMM:
     @unittest.skipIf(
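For anyone wanting to exercise the zero-dim path outside the test class, a minimal standalone sketch follows. It mirrors the new test; the import paths for recipe_name_to_linear_config and Float8LinearRecipeName are assumptions, a float8-capable CUDA device is required, and, as the skip decorator above indicates, the end-to-end path still depends on the upstream fix in pytorch/pytorch#140967.

# Hypothetical standalone repro; import paths below are assumptions, not verified against this repo.
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# M == 0 models an empty batch reaching the float8 linear layer.
M, K, N = 0, 64, 128
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)

config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_TENSORWISE)
m = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
m = convert_to_float8_training(m, config=config)

# Forward and backward should both run on the empty batch and produce empty outputs/grads.
y = m(x)
y.sum().backward()
print(y.shape, x.grad.shape)  # torch.Size([0, 128]) torch.Size([0, 64])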
5 changes: 4 additions & 1 deletion torchao/float8/float8_utils.py
@@ -99,7 +99,10 @@ def tensor_to_amax(
     axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
     if scaling_granularity is ScalingGranularity.TENSORWISE:
-        amax = torch.max(torch.abs(x))
+        if x.numel() > 0:
+            amax = torch.max(torch.abs(x))
+        else:
+            amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
     else:
         assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
         assert axiswise_dim is not None, "unsupported"
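Context for the float8_utils.py change: torch.max over a tensor with zero elements raises a RuntimeError, so a zero-row activation would crash the tensorwise amax computation without this fallback. A minimal CPU-only sketch of the behavior; the EPS value below is a placeholder standing in for the module's constant.

import torch

EPS = 1e-12  # placeholder; assumed to mirror the EPS constant in torchao/float8/float8_utils.py

x = torch.randn(0, 64)  # zero-row tensor, e.g. an empty batch

# Without the guard, a tensorwise amax over an empty tensor raises:
try:
    torch.max(torch.abs(x))
except RuntimeError as err:
    print("empty-tensor max fails:", err)

# With the guard from this commit, amax falls back to a small positive value,
# which keeps the downstream amax -> scale conversion finite.
amax = torch.max(torch.abs(x)) if x.numel() > 0 else torch.tensor(EPS, dtype=x.dtype)
print(amax)  # tensor(1.0000e-12)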
