diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 58df3a343c..36ea40eb81 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -730,14 +730,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): emulated_config, GemmInputRole.WEIGHT, ) - out_emualted = a_fp8 @ b_fp8 - out_emualted.to(compare_type) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol) + out_emulated = a_fp8 @ b_fp8 + out_emulated.to(compare_type) + sqnr = compute_error(out_padded, out_emulated) + assert sqnr > 50.0 class TestNumerics: