From 38c79d46e82ede124bb378937d3520ad9424e162 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 18 Dec 2024 16:15:03 -0800 Subject: [PATCH] fix local float8 tests on H100 (#1438) Summary: float8 tests were failing on my machine. I bisected the failure to https://github.com/pytorch/ao/pull/1344. Further investigation found that that PR was fine, but the tolerance for one of the tests was too tight; adding a test changed the random seed of the data, and things started failing. Switching to SQNR for a more robust measurement. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_base.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 58df3a343c..36ea40eb81 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -730,14 +730,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): emulated_config, GemmInputRole.WEIGHT, ) - out_emualted = a_fp8 @ b_fp8 - out_emualted.to(compare_type) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol) + out_emulated = a_fp8 @ b_fp8 + out_emulated.to(compare_type) + sqnr = compute_error(out_padded, out_emulated) + assert sqnr > 50.0 class TestNumerics: