From c0cabe1a7d9844013adc7f12c59e86c503a13676 Mon Sep 17 00:00:00 2001
From: "T.J. Alumbaugh"
Date: Wed, 18 Sep 2024 12:49:20 -0700
Subject: [PATCH] Fix SDPA forward pass with softcap + HLFB

PiperOrigin-RevId: 676096028
---
 .../layers/scaled_dot_product_attention.py | 28 +++++++++++++------
 .../test/test_model_conversion_large.py    |  2 +-
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
index 3bcd26c7..ddb508ee 100644
--- a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
+++ b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
   # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
   k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
   v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)

   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py
index 880fdb32..7bd73e29 100644
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -96,7 +96,7 @@ def test_gemma(self):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
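
For reference, below is a minimal standalone sketch (not part of the patch) of the softcapped attention path added in the first hunk: scale the queries, compute logits, apply the tanh soft-cap softcap * tanh(scores / softcap), add the mask, then softmax and weight the values. Tensor shapes are assumed to be [batch, num_heads, seq_len, head_dim]; the function name, signature, and the optional-mask handling are illustrative assumptions, not part of the library API.

    import math
    import torch
    import torch.nn.functional as F


    def softcap_attention(q, k, v, mask=None, softcap=30.0, scale=None):
      # Illustrative sketch only; mirrors the math of the else-branch above.
      if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
      q = q * scale
      scores = q @ k.transpose(-1, -2)
      # Soft-cap the logits: softcap * tanh(scores / softcap).
      scores = softcap * torch.tanh(scores / softcap)
      if mask is not None:
        scores = scores + mask
      probs = F.softmax(scores.float(), dim=-1).type_as(q)
      return probs @ v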