Commit

fix shape check
zinccat committed Oct 22, 2024
1 parent 039058a commit df73d2c
Showing 4 changed files with 155 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -100,7 +100,7 @@ Float16:
|------------------------|----------------------------|----------------------------|
| FlaxAttention (Pure JAX) | 0.5692746052518487 | 0.8823547409847379 |
| FlaxAttention (Pallas) | **0.13677988620474935** | **0.5575501238927245** |
-| Jax Attention (no score_mod) | 0.2551286369562149 | 0.04072062578052282 |
+| Jax Attention (no score_mod) | 1.6788566000759602 | 1.0905949068255723 |
| FlexAttention (Torch)| **0.11708855209872127** | **0.5104729640297592** |

We can see that the forward pass is about 20% slower than the original implementation, while the backward pass is about 8% slower. There are still some optimizations to be done.
19 changes: 11 additions & 8 deletions examples/benchmark.py
@@ -155,20 +155,23 @@ def fn(query, key, value):

# try jax attention

+query_transposed = jnp.moveaxis(query, 1, 2)
+key_transposed = jnp.moveaxis(key, 1, 2)
+value_transposed = jnp.moveaxis(value, 1, 2)
# warm up
output = dot_product_attention(
-query,
-key,
-value,
+query_transposed,
+key_transposed,
+value_transposed,
)
output.block_until_ready()

start = timer()
for _ in range(100):
output = dot_product_attention(
-query,
-key,
-value,
+query_transposed,
+key_transposed,
+value_transposed,
)
output.block_until_ready()
end = timer()
@@ -184,14 +187,14 @@ def fn1(query, key, value):
grad_fn1 = jax.jit(grad_fn1)

# warm up
-grad = grad_fn1(query, key, value)
+grad = grad_fn1(query_transposed, key_transposed, value_transposed)
grad.block_until_ready()

# print(grad[0, 0, 0])

start = timer()
for _ in range(100):
-grad = grad_fn1(query, key, value)
+grad = grad_fn1(query_transposed, key_transposed, value_transposed)
grad.block_until_ready()
end = timer()
print("Jax dot product attention gradient time taken (no score_mod):", end - start)
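Note on the benchmark change above: the query, key, and value tensors are now moved from the (batch, num_heads, seq_len, head_dim) layout to (batch, seq_len, num_heads, head_dim) before the call, which is the layout `jax.nn.dot_product_attention` expects. The earlier, untransposed call still runs but attends over the wrong axis, which likely explains the much lower pre-fix numbers in the README table. A minimal timing sketch of the same pattern, assuming `dot_product_attention` here is `jax.nn.dot_product_attention` and using hypothetical shapes:

```python
import jax
from jax import numpy as jnp
from timeit import default_timer as timer

# Hypothetical shapes; the library-side layout is (batch, num_heads, seq_len, head_dim).
batch, num_heads, seq_len, head_dim = 8, 8, 1024, 64
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (batch, num_heads, seq_len, head_dim), dtype=jnp.float16)
key = jax.random.normal(k_rng, (batch, num_heads, seq_len, head_dim), dtype=jnp.float16)
value = jax.random.normal(v_rng, (batch, num_heads, seq_len, head_dim), dtype=jnp.float16)

# jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim),
# so swap the head and sequence axes first.
query_t = jnp.moveaxis(query, 1, 2)
key_t = jnp.moveaxis(key, 1, 2)
value_t = jnp.moveaxis(value, 1, 2)

attn = jax.jit(jax.nn.dot_product_attention)
out = attn(query_t, key_t, value_t)  # warm up (triggers compilation)
out.block_until_ready()

start = timer()
for _ in range(100):
    out = attn(query_t, key_t, value_t)
out.block_until_ready()
print("Jax dot product attention time taken:", timer() - start)
```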
8 changes: 4 additions & 4 deletions flaxattention/core/attention.py
@@ -252,15 +252,15 @@ def flax_attention_pallas(
_validate_embed_dim(query, key, value)
if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
-if (not enable_gqa) and query.shape[-3] != key.shape[-3]:
+if (not enable_gqa) and query.shape[-2] != key.shape[-2]:
raise ValueError(
f"Expect query and key/value to have the same number of heads "
f"but got Hq={query.shape[-3]} and Hkv={key.shape[-3]}. "
f"but got Hq={query.shape[-2]} and Hkv={key.shape[-2]}. "
f"Try setting enable_gqa=True for GQA."
)
if enable_gqa:
-Hq = query.shape[1]
-Hkv = key.shape[1]
+Hq = query.shape[2]
+Hkv = key.shape[2]
if Hq % Hkv != 0:
raise ValueError(
f"Expect number of query heads to be a multiple of kv heads for GQA "
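The corrected check reads the head count from the second-to-last axis (axis -2, i.e. axis 2 of a 4D tensor), consistent with a (batch, seq_len, num_heads, head_dim) layout, rather than from axis -3. A standalone sketch of the same validation logic, with an illustrative helper name and shapes:

```python
import jax
from jax import numpy as jnp

def check_heads(query: jax.Array, key: jax.Array, enable_gqa: bool = False) -> None:
    """Illustrative restatement of the corrected shape check.

    Assumes 4D inputs laid out as (batch, seq_len, num_heads, head_dim),
    so the number of heads lives on axis -2.
    """
    if query.ndim != 4 or key.ndim != 4:
        raise NotImplementedError("NYI: query and key must be 4D tensors")
    Hq, Hkv = query.shape[-2], key.shape[-2]
    if not enable_gqa and Hq != Hkv:
        raise ValueError(
            f"Expect query and key/value to have the same number of heads "
            f"but got Hq={Hq} and Hkv={Hkv}. Try setting enable_gqa=True for GQA."
        )
    if enable_gqa and Hq % Hkv != 0:
        raise ValueError(
            f"Expect number of query heads to be a multiple of kv heads for GQA, "
            f"got Hq={Hq} and Hkv={Hkv}."
        )

# Example: 8 query heads sharing 2 kv heads is valid only with enable_gqa=True.
check_heads(jnp.zeros((1, 128, 8, 64)), jnp.zeros((1, 128, 2, 64)), enable_gqa=True)
```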
139 changes: 139 additions & 0 deletions flaxattention/utils.py
@@ -0,0 +1,139 @@
import jax
from jax import numpy as jnp
from typing import Optional
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import math
from .core.common import (
_score_mod_signature,
_mask_mod_signature,
_vmap_for_bhqkv,
_ModificationType,
)

Array = jax.Array

def create_score_mod(
query: Array,
key: Array,
score_mod: Optional[_score_mod_signature],
mask_mod: Optional[_mask_mod_signature],
scale: Optional[float] = None,
batch_idx: int = 0,
head_idx: int = 0,
) -> Array:
B = 1
H = 1
M = query.shape[0]
N = key.shape[0]

b = jnp.arange(0, B) + batch_idx
h = jnp.arange(0, H) + head_idx
m = jnp.arange(0, M)
n = jnp.arange(0, N)

scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
type = _ModificationType.SCORE_MOD if score_mod is not None else _ModificationType.MASK_MOD
mod_fn = score_mod if type == _ModificationType.SCORE_MOD else mask_mod
prefix = (0,) if type == _ModificationType.SCORE_MOD else ()
mod = _vmap_for_bhqkv(mod_fn, prefix=prefix)
scores = query @ jnp.moveaxis(key, -1, -2)
scores *= scale_factor
scores = scores.reshape(1, 1, M, N)
if type == _ModificationType.SCORE_MOD:
out = mod(scores, b, h, m, n)
else:
out = mod(b, h, m, n)

return out


def _name_to_title(name: str) -> str:
title = name.replace("_", " ")
title = " ".join(word.capitalize() for word in title.split())
return title


def visualize_attention_scores(
query: Array,
key: Array,
score_mod: Optional[_score_mod_signature] = None,
mask_mod: Optional[_mask_mod_signature] = None,
device: str = "cuda",
name: str = "attention_scores",
path: Optional[Path] = None,
batch_idx: int = 0,
head_idx: int = 0,
scale: Optional[float] = None,
):
"""
Generate and save a visualization of attention scores.
Args:
query (Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim).
key (Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim).
score_mod (Optional[Callable]): If set, this takes precedence over mask_mod.
mask_mod (Optional[Callable]): The mask_mod function used to create block_mask.
device (str): Device to run computations on (default: "cuda").
name (str): Base name for the file and title (default: 'attention_scores').
path (Path): Path to save the visualization. If None, will be saved to the current working directory.
batch_idx (int): Index of the batch to visualize (default: 0).
head_idx (int): Index of the head to visualize (default: 0).
scale (float): Scale factor to apply to the attention scores. If None, will be set to 1 / sqrt(head_dim).
Returns:
None
"""
assert (
score_mod is not None or mask_mod is not None
), "Must provide either score_mod or mask_mod"
query = query[batch_idx, head_idx, :, :]
key = key[batch_idx, head_idx, :, :]
scores_viz = create_score_mod(
query,
key,
score_mod=score_mod,
mask_mod=mask_mod,
scale=scale,
batch_idx=batch_idx,
head_idx=head_idx,
)

suffix_title = f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""

fig, ax = plt.subplots(figsize=(12, 10))
color = "viridis" if score_mod is not None else "cividis"
im = ax.imshow(scores_viz[0, 0, :, :], aspect="auto", cmap=color)
fig.colorbar(im)

title = _name_to_title(name)
file_path = Path(name).with_suffix(".png") if path is None else path.with_suffix(".png")
ax.set_title(f"{title}\n{suffix_title}", fontsize=20)

ax.set_xlabel("Key Tokens", fontsize=18)
ax.set_ylabel("Query Tokens", fontsize=18)

# Move y-axis ticks and labels to the top
ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)

# Add tick labels if the number of tokens is manageable
num_query_tokens, num_kv_tokens = scores_viz.shape[-2:]
if num_query_tokens <= 32 and num_kv_tokens <= 32:
ax.set_xticks(range(num_kv_tokens))
rotation = 45 if num_kv_tokens > 12 else 0
ax.set_xticklabels(
[f"KV{i}" for i in range(num_kv_tokens)], fontsize=16, rotation=rotation
)
ax.set_yticks(range(num_query_tokens))
ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16)
# Align grid with pixel boundaries
ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
ax.grid(which="minor", color="black", linestyle="-", linewidth=2)

plt.tight_layout()
plt.savefig(file_path, dpi=300, bbox_inches="tight")
plt.close(fig) # Close the figure to free up memory

print(f"Visualization saved as {file_path}")
