From 76a1e17da3b71742ee62d510a95b04457f4a179b Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 7 Dec 2023 13:02:58 +0000 Subject: [PATCH] fix parameterized.expand order --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 82c4fb6814551c..16f26fc1e3e2b0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3130,9 +3130,9 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) def test_eager_matches_sdpa_inference(self, torch_dtype: str): if not self.all_model_classes[0]._supports_sdpa: self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")