Commit

fix fully masked
zinccat committed Oct 24, 2024
1 parent eb229b9 commit 0f183a3
Showing 3 changed files with 124 additions and 2 deletions.
7 changes: 6 additions & 1 deletion flaxattention/core/attention.py
@@ -64,6 +64,11 @@ def _math_attention_inner(

return scores, post_mod_scores # type: ignore

def make_safe(x: Array, axis: int) -> Array:
masked = jnp.isnan(x) | jnp.isinf(x)
masked_rows = jnp.all(masked, axis=axis, keepdims=True)
zeros = jnp.zeros_like(x)
return jnp.where(masked_rows, zeros, x)

def math_attention(
query: Array,
@@ -103,7 +108,7 @@ def math_attention(
post_mod_scores = pl_softmax(post_mod_scores, axis=-1)
else:
post_mod_scores = jax.nn.softmax(post_mod_scores, axis=-1)

post_mod_scores = make_safe(post_mod_scores, axis=-1)
output = jnp.matmul(post_mod_scores.astype(query.dtype), value)
return output, logsumexp / jnp.log(2)
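
For context (not part of the diff): when every key of a query row is masked, all of its scores are -inf, so the softmax row evaluates to NaN and would propagate into the output. The new make_safe helper zeroes any row that is entirely NaN/Inf along the softmax axis, which is what test_fully_masked below asserts. A minimal sketch of the behavior, assuming make_safe is importable from flaxattention.core.attention as added in this commit:

import jax
import jax.numpy as jnp

from flaxattention.core.attention import make_safe  # helper added in this commit

# Row 1 stands in for a fully masked query row: every score is -inf,
# so softmax produces NaNs for that row.
scores = jnp.array([[0.0, 1.0, 2.0],
                    [-jnp.inf, -jnp.inf, -jnp.inf]])
probs = jax.nn.softmax(scores, axis=-1)  # probs[1] == [nan, nan, nan]

# make_safe zeroes rows that are entirely NaN/Inf, so the subsequent
# matmul with the value tensor yields a zero output row instead of NaNs.
safe = make_safe(probs, axis=-1)  # safe[1] == [0., 0., 0.]
print(safe)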

3 changes: 3 additions & 0 deletions flaxattention/kernel/attention.py
@@ -115,6 +115,8 @@ def body(start_k, carry):

m_curr = qk.max(axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
fully_masked = m_next == DEFAULT_MASK_VALUE  # rows whose running max is still the mask sentinel
m_next = jnp.where(fully_masked, 0.0, m_next)
correction = jnp.exp(m_prev - m_next)
l_prev_corr = correction * l_prev
s_curr = jnp.exp(
@@ -139,6 +141,7 @@ def body(start_k, carry):
# We keep an unscaled version of o during the scan over seq_len. Scaling it
# by the last l_i gives us the correct final output. See section 3.1.1 in the
# FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
l_i = jnp.where(l_i == 0.0, 1, l_i)  # avoid 0 / 0 for fully masked rows
o /= l_i[:, None]

if residual_refs:
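
Also for context (not part of the diff): in the online-softmax loop a fully masked row never sees a real score, so its running max stays at the DEFAULT_MASK_VALUE sentinel and its running denominator l_i stays at zero. Resetting the max to 0 makes the masked exponentials underflow to zero instead of evaluating to exp(0) = 1, and the l_i guard turns the final 0/0 normalization into 0/1, so the row comes out as exact zeros. A rough standalone sketch of that arithmetic, assuming DEFAULT_MASK_VALUE is the usual large negative float32 sentinel used by Pallas flash-attention kernels:

import numpy as np
import jax.numpy as jnp

# Assumed sentinel value, mirroring the Pallas flash-attention convention.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)

# One query row whose keys are all masked: every score is the sentinel.
qk = jnp.full((4,), DEFAULT_MASK_VALUE)
m_prev, l_prev = DEFAULT_MASK_VALUE, 0.0

m_next = jnp.maximum(m_prev, qk.max())  # still the sentinel
m_next = jnp.where(m_next == DEFAULT_MASK_VALUE, 0.0, m_next)  # reset to 0

s_curr = jnp.exp(qk - m_next)  # underflows to 0 instead of exp(0) == 1
l_curr = l_prev * jnp.exp(m_prev - m_next) + s_curr.sum()  # stays 0.0

# Final normalization: guard against 0 / 0 so the output row is exactly zero.
l_safe = jnp.where(l_curr == 0.0, 1.0, l_curr)
o_row = (s_curr @ jnp.ones((4, 8))) / l_safe  # a zero row, no NaNs
print(o_row)
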
116 changes: 115 additions & 1 deletion tests/attention_test.py
@@ -1,12 +1,13 @@
import jax
from jax import Array
import jax.numpy as jnp
import torch
import torch.utils.dlpack
from torch.nn.attention.flex_attention import flex_attention
import numpy as np
from absl.testing import absltest # pylint: disable=import-error

from flaxattention import flax_attention, flax_attention_pallas
from flaxattention import flax_attention, flax_attention_pallas, create_block_mask
from flaxattention.mods import generate_alibi_bias


@@ -61,6 +62,72 @@ def test_equivalence_with_torch(self):

np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_fully_masked(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

def fullmask(batch_idx, head_idx, q_idx, kv_idx):
return False

block_mask = create_block_mask(fullmask, batch_size, num_heads, seq_len_q, seq_len_kv)

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

output = flax_attention(
query,
key,
value,
block_mask=block_mask,
)

np.testing.assert_almost_equal(output, np.zeros_like(output), decimal=2)

def test_pallas_fully_masked(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

def fullmask(batch_idx, head_idx, q_idx, kv_idx):
return False

# block_mask = create_block_mask(fullmask, batch_size, num_heads, seq_len_q, seq_len_kv)

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

output = flax_attention_pallas(
query,
key,
value,
mask_mod=fullmask,
)

np.testing.assert_almost_equal(output, np.zeros_like(output), decimal=2)

def test_pallas_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
@@ -362,6 +429,53 @@ def fn(query, key, value):

np.testing.assert_almost_equal(grad_jax, grad_torch, decimal=2)

def test_pallas_autograd_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

def fn(query, key, value):
return flax_attention_pallas(
query,
key,
value,
).sum()

grad_fn = jax.grad(fn, 0)
grad_jax = grad_fn(query, key, value)

query_torch = jax2torch(query)
key_torch = jax2torch(key)
value_torch = jax2torch(value)

query_torch.requires_grad = True

output_torch = flex_attention(
query_torch,
key_torch,
value_torch,
).sum()

output_torch.backward()

grad_torch = query_torch.grad.cpu().numpy()

np.testing.assert_almost_equal(grad_jax, grad_torch, decimal=2)

def test_decoding_autograd_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
