Skip to content

Commit

Permalink
add visualization, attention sink
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat committed Oct 22, 2024
1 parent c423b78 commit 62298b5
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 2 deletions.
2 changes: 2 additions & 0 deletions flaxattention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
or_masks,
)
from flaxattention.core.common import _mask_mod_signature, _score_mod_signature
from .utils import visualize_attention_scores

__all__ = [
"math_attention",
Expand All @@ -17,4 +18,5 @@
"or_masks",
"_mask_mod_signature",
"_score_mod_signature",
"visualize_attention_scores",
]
49 changes: 49 additions & 0 deletions flaxattention/masks/attention_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Attention Sink in Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/abs/2309.17453)"""

from flaxattention import _mask_mod_signature, or_masks, and_masks
from flaxattention.masks import causal_mask

def generate_attention_sink(window_size: int, sink_size: int = 4) -> _mask_mod_signature:
"""Generates an attention sink mask with a given window size and sink size.
Args:
window_size: The size of the sliding window.
sink_size: The size of the attention sink.
Note:
We assume that the window size represents the lookback size and we mask out all future tokens
similar to causal masking.
"""
def sliding_window(b, h, q_idx, kv_idx):
return q_idx - kv_idx <= window_size

def attention_sink(b, h, q_idx, kv_idx):
return kv_idx <= sink_size

attention_sink_mask = and_masks(or_masks(attention_sink, sliding_window), causal_mask)
attention_sink_mask.__name__ = f"attention_sink_{window_size}_{sink_size}"
return attention_sink_mask

def main(device: str = "cpu"):
"""Visualize the attention scores of causal masking.
Args:
device (str): Device to use for computation. Defaults
"""
from flaxattention.utils import visualize_attention_scores
import jax.numpy as jnp

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 8

def make_tensor():
return jnp.ones((B, H, SEQ_LEN, HEAD_DIM))

query, key = make_tensor(), make_tensor()

visualize_attention_scores(query, key, mask_mod=generate_attention_sink(32, 4), name="attention_sink")

if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
25 changes: 25 additions & 0 deletions flaxattention/masks/causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,28 @@

def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx

def main(device: str = "cpu"):
"""Visualize the attention scores of causal masking.
Args:
device (str): Device to use for computation. Defaults
"""
from flaxattention.utils import visualize_attention_scores
import jax.numpy as jnp

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 8

def make_tensor():
return jnp.ones((B, H, SEQ_LEN, HEAD_DIM))

query, key = make_tensor(), make_tensor()

visualize_attention_scores(query, key, mask_mod=causal_mask, name="causal_mask")

if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
25 changes: 25 additions & 0 deletions flaxattention/masks/sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,28 @@ def sliding_window(b, h, q_idx, kv_idx):
sliding_window_mask = and_masks(sliding_window, causal_mask)
sliding_window_mask.__name__ = f"sliding_window_{window_size}"
return sliding_window_mask

def main(device: str = "cpu"):
"""Visualize the attention scores of causal masking.
Args:
device (str): Device to use for computation. Defaults
"""
from flaxattention.utils import visualize_attention_scores
import jax.numpy as jnp

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 8

def make_tensor():
return jnp.ones((B, H, SEQ_LEN, HEAD_DIM))

query, key = make_tensor(), make_tensor()

visualize_attention_scores(query, key, mask_mod=generate_sliding_window(32), name="sliding_window")

if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
25 changes: 25 additions & 0 deletions flaxattention/mods/alibi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,28 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
return score + bias

return alibi_mod

def main(device: str = "cpu"):
"""Visualize the attention scores of causal masking.
Args:
device (str): Device to use for computation. Defaults
"""
from flaxattention.utils import visualize_attention_scores
import jax.numpy as jnp

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 128, 8

def make_tensor():
return jnp.ones((B, H, SEQ_LEN, HEAD_DIM))

query, key = make_tensor(), make_tensor()

visualize_attention_scores(query, key, score_mod=generate_alibi_bias(H), name="alibi_bias")

if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
4 changes: 2 additions & 2 deletions tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_equivalence_with_torch(self):
.numpy()
)

np.testing.assert_almost_equal(output_jax, output_torch, decimal=3)
np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_gqa(self):
# Prepare inputs
Expand Down Expand Up @@ -196,4 +196,4 @@ def fn(query, key, value):

grad_torch = query_torch.grad.cpu().numpy()

np.testing.assert_almost_equal(grad_jax, grad_torch, decimal=3)
np.testing.assert_almost_equal(grad_jax, grad_torch, decimal=2)
Binary file added visualizations/alibi_bias.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added visualizations/attention_sink.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added visualizations/causal_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added visualizations/sliding_window.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 62298b5

Please sign in to comment.