Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zero dim support for tensorwise float8 training #1352

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,43 @@ def test_quantize(self):
with torch.no_grad():
m(x)

@unittest.skip(
"TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_89, "CUDA 8.9 not available")
@pytest.mark.parametrize(
"recipe_name",
[
Float8LinearRecipeName.ALL_TENSORWISE,
# TODO(future PR): enable axiswise recipes
],
)
def test_zero_dim(self, recipe_name):
# Note: we only test M == 0 because we can assume that K == 0 and N == 0
# are not important
M, K, N = 0, 64, 128

x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
x0_fp8 = copy.deepcopy(x0_ref)
config = recipe_name_to_linear_config(recipe_name)

m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
m_fp8 = copy.deepcopy(m_ref)
m_fp8 = convert_to_float8_training(m_fp8, config=config)

y_ref = m_ref(x0_ref)
y_ref.sum().backward()

y_fp8 = m_fp8(x0_fp8)
y_fp8.sum().backward()

assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
assert torch.allclose(
m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
)
assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)


class TestScaledMM:
@unittest.skipIf(
Expand Down
5 changes: 4 additions & 1 deletion torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def tensor_to_amax(
axiswise_dim: Optional[int] = None,
) -> torch.Tensor:
if scaling_granularity is ScalingGranularity.TENSORWISE:
amax = torch.max(torch.abs(x))
if x.numel() > 0:
amax = torch.max(torch.abs(x))
else:
amax = torch.tensor(EPS, device=x.device, dtype=x.dtype)
else:
assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
assert axiswise_dim is not None, "unsupported"
Expand Down
Loading