Skip to content

Commit

Permalink
fix local float8 tests on H100 (#1438)
Browse files Browse the repository at this point in the history
Summary:

float8 tests were failing my machine. I bisected the failure to
#1344. Further investigation found
that that PR was fine, but the tolerance for one of the tests was too
tight, adding a test changed the random seed of the data, and things
started failing.

Switching to SQNR for a more robust measurement.

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo authored Dec 19, 2024
1 parent 2e032c6 commit 38c79d4
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,14 +730,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
emulated_config,
GemmInputRole.WEIGHT,
)
out_emualted = a_fp8 @ b_fp8
out_emualted.to(compare_type)

if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol)
out_emulated = a_fp8 @ b_fp8
out_emulated.to(compare_type)
sqnr = compute_error(out_padded, out_emulated)
assert sqnr > 50.0


class TestNumerics:
Expand Down

0 comments on commit 38c79d4

Please sign in to comment.