Commit
better safe softmax for pure jax
zinccat committed Oct 25, 2024
1 parent b8e50f2 commit 7a63d02
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions flaxattention/core/attention.py
@@ -64,12 +64,6 @@ def _math_attention_inner(

     return scores, post_mod_scores  # type: ignore

-def make_safe(x: Array, axis: int) -> Array:
-    masked = jnp.isnan(x) | jnp.isinf(x)
-    masked_rows = jnp.all(masked, axis=axis, keepdims=True)
-    zeros = jnp.zeros_like(x)
-    return jnp.where(masked_rows, zeros, x)
-
 def math_attention(
     query: Array,
     key: Array,
@@ -107,8 +101,7 @@ def math_attention(
     if use_pallas:
         post_mod_scores = pl_softmax(post_mod_scores, axis=-1)
     else:
-        post_mod_scores = jax.nn.softmax(post_mod_scores, axis=-1)
-        post_mod_scores = make_safe(post_mod_scores, axis=-1)
+        post_mod_scores = jax.nn.softmax(post_mod_scores, axis=-1, where=(post_mod_scores != -jnp.inf))
     output = jnp.matmul(post_mod_scores.astype(query.dtype), value)
     return output, logsumexp / jnp.log(2)
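The change drops the post-hoc make_safe pass and instead hands a mask to jax.nn.softmax via its where argument: positions whose score is -inf are excluded from the normalization, and a row that is masked everywhere comes out as zeros rather than NaN. A minimal sketch of the behavior difference follows; the toy scores array and variable names are illustrative, not from the repository, and it assumes a JAX version whose jax.nn.softmax supports where and zeroes fully-masked rows.

import jax
import jax.numpy as jnp

# Toy attention scores; the second row is fully masked to -inf,
# as happens when a mask blocks every key for a given query.
scores = jnp.array([
    [1.0, 2.0, -jnp.inf],
    [-jnp.inf, -jnp.inf, -jnp.inf],
])

# Plain softmax: the all -inf row normalizes 0/0 and turns into NaNs.
plain = jax.nn.softmax(scores, axis=-1)

# Masked softmax as in this commit: -inf positions are excluded from
# the normalization, and a row with no valid positions becomes all zeros.
safe = jax.nn.softmax(scores, axis=-1, where=(scores != -jnp.inf))

print(plain)  # second row: [nan, nan, nan]
print(safe)   # second row: [0., 0., 0.]

Compared with the removed make_safe helper, this keeps the masking inside the softmax itself and avoids a second pass over the scores, which is presumably what the commit title means by a "better safe softmax".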
