Fix SDPA forward pass with softcap + HLFB
PiperOrigin-RevId: 676096028
talumbau authored and copybara-github committed Sep 18, 2024
1 parent 406ac47 commit c0cabe1
Showing 2 changed files with 20 additions and 10 deletions.
28 changes: 19 additions & 9 deletions ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
     # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
     k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
     v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
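For reference, the new else branch mirrors the standalone sketch of tanh soft-capped attention below. The function name, the assumed (batch, heads, seq_len, head_dim) tensor layout, and the dense additive mask are illustrative assumptions, not part of the library API. Soft-capping keeps the pre-softmax logits within (-softcap, softcap); Gemma 2, which the test change below exercises, uses this attention logit soft-capping.

import math

import torch
import torch.nn.functional as F


def softcap_attention(q, k, v, mask, softcap, scale=None):
  """Sketch: attention with tanh soft-capping applied to the logits."""
  if scale is None:
    scale = 1.0 / math.sqrt(q.shape[-1])
  # Raw attention logits, with the scale applied before the cap as above.
  scores = (q * scale) @ k.transpose(-1, -2)
  # Squash the logits smoothly into (-softcap, softcap).
  scores = torch.tanh(scores / softcap) * softcap
  scores = scores + mask  # dense additive mask (large negatives where masked)
  # Softmax in float32 for stability, then cast back to the input dtype.
  probs = F.softmax(scores.float(), dim=-1).type_as(q)
  return probs @ v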
@@ -96,7 +96,7 @@ def test_gemma(self):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
