diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py
index 7692294cd6df..099ae5cdcd28 100644
--- a/tests/pallas/gpu_ops_test.py
+++ b/tests/pallas/gpu_ops_test.py
@@ -284,8 +284,12 @@ def test_fused_attention_bwd(
     segment_ids = None
 
     def f(q, k, v):
-      return attention.mha(q, k, v, segment_ids, causal=causal,
-                           interpret=self.INTERPRET).sum()
+      if jtu.is_device_rocm():
+        return attention.mha(q, k, v, segment_ids, 1.0, causal, 64, 64,
+                             interpret=self.INTERPRET).sum()
+      else:
+        return attention.mha(q, k, v, segment_ids, causal=causal,
+                             interpret=self.INTERPRET).sum()
 
     def f_ref(q, k, v):
       return attention.mha_reference(q, k, v, segment_ids, causal=causal).sum()
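
Review note: the ROCm branch passes `1.0, causal, 64, 64` positionally to `attention.mha`, so the reader has to know the parameter order to see that these are the softmax scale and the block sizes. A minimal sketch of a keyword-argument form is below; it assumes the positional parameters after `segment_ids` are `sm_scale`, `causal`, `block_q`, `block_k`, which should be checked against the actual `mha` signature before adopting it.

    def f(q, k, v):
      if jtu.is_device_rocm():
        # Sketch only: assumes attention.mha accepts these keyword names
        # (sm_scale, causal, block_q, block_k); verify against the signature.
        return attention.mha(q, k, v, segment_ids, sm_scale=1.0, causal=causal,
                             block_q=64, block_k=64,
                             interpret=self.INTERPRET).sum()
      else:
        return attention.mha(q, k, v, segment_ids, causal=causal,
                             interpret=self.INTERPRET).sum()

If the keyword names match, this keeps the ROCm and non-ROCm branches visually parallel and makes the reduced 64x64 block sizes self-documenting.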