MultiHeadAttention memory usage reduction via tiling #679

Open · wants to merge 1 commit into base: main
47 changes: 35 additions & 12 deletions haiku/_src/attention.py
@@ -67,6 +67,7 @@ def __init__(
       b_init: Optional[hk.initializers.Initializer] = None,
       value_size: Optional[int] = None,
       model_size: Optional[int] = None,
+      max_chunk_size_mb: Optional[int] = 1000,
       name: Optional[str] = None,
   ):
     """Initialises the module.
@@ -85,13 +86,16 @@
         to the key size (K).
       model_size: Optional size of the output embedding (D'). If None, defaults
         to the key size multiplied by the number of heads (K * H).
+      max_chunk_size_mb: Optional maximum size in MB of the internal logits
+        tensor, used to cap memory usage. Defaults to 1000. If None, never chunk.
       name: Optional name for this module.
     """
     super().__init__(name=name)
     self.num_heads = num_heads
     self.key_size = key_size
     self.value_size = value_size or key_size
     self.model_size = model_size or key_size * num_heads
+    self.max_chunk_size_mb = max_chunk_size_mb
 
     # Backwards-compatibility for w_init_scale.
     if w_init_scale is not None:
@@ -142,20 +146,39 @@ def __call__(
     key_heads = projection(key, self.key_size, "key")  # [T, H, K]
     value_heads = projection(value, self.value_size, "value")  # [T, H, V]
 
-    # Compute attention weights.
-    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
-    attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
-    if mask is not None:
-      if mask.ndim != attn_logits.ndim:
-        raise ValueError(
-            f"Mask dimensionality {mask.ndim} must match logits dimensionality "
-            f"{attn_logits.ndim}."
-        )
-      attn_logits = jnp.where(mask, attn_logits, -1e30)
-    attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]
-
-    # Weight the values by the attention and flatten the head vectors.
-    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
+    # Calculate the size of the full logits tensor in memory.
+    logits_itemsize = max(query_heads.dtype.itemsize, key_heads.dtype.itemsize)
+    logits_numel = query_heads.size // query_heads.shape[-1] * key_heads.shape[-3]
+    logits_size_bytes = logits_numel * logits_itemsize
+
+    if self.max_chunk_size_mb is not None:
+      max_chunk_size_bytes = self.max_chunk_size_mb * 1_000_000
+      num_chunks = logits_size_bytes // max_chunk_size_bytes + 1
+    else:
+      num_chunks = 1
+    t_chunk_size = max(1, query_heads.shape[-3] // num_chunks)  # avoid a zero step
+
+    # Compute attention weights, chunk by chunk.
+    attns = []
+    for i in range(0, query_heads.shape[-3], t_chunk_size):
+      query_chunk = query_heads[..., i:i + t_chunk_size, :, :]
+      attn_logits = jnp.einsum("...thd,...Thd->...htT", query_chunk, key_heads)
+      attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
+      if mask is not None:
+        if mask.ndim != attn_logits.ndim:
+          raise ValueError(
+              f"Mask dimensionality {mask.ndim} must match logits dimensionality "
+              f"{attn_logits.ndim}."
+          )
+        mask_chunk = mask[..., i:i + t_chunk_size, :]
+        attn_logits = jnp.where(mask_chunk, attn_logits, -1e30)
+      attn_weights = jax.nn.softmax(attn_logits)  # [H, T' chunk, T]
+      # Weight the values by the attention and flatten the head vectors.
+      attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
+      attns.append(attn)
+
+    # Join the chunks back together, and reshape.
+    attn = jnp.concatenate(attns, axis=-3)
     attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1))  # [T', H*V]
 
     # Apply another projection to get the final embeddings.
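For a concrete sense of the chunk-count arithmetic in the diff, here is a worked example with hypothetical shapes (the variable names mirror the diff; batch size, sequence length, and head count are illustrative only):

import numpy as np

# Hypothetical shapes: batch=4, T'=2048 query positions, H=8 heads, K=64, float32.
query_heads_shape = (4, 2048, 8, 64)  # [..., T', H, K]
key_seq_len = 2048                    # T, i.e. key_heads.shape[-3]

# Number of elements in the full [..., H, T', T] logits tensor, as in the diff.
logits_numel = int(np.prod(query_heads_shape[:-1])) * key_seq_len  # 134_217_728
logits_size_bytes = logits_numel * 4  # float32 itemsize -> 536_870_912 (~537 MB)

# With max_chunk_size_mb=100:
max_chunk_size_bytes = 100 * 1_000_000
num_chunks = logits_size_bytes // max_chunk_size_bytes + 1  # 5 + 1 = 6
t_chunk_size = max(1, 2048 // num_chunks)                   # 341 query positions

# The loop then takes ceil(2048 / 341) = 7 slices; the last one covers the
# final 2 leftover positions, since slicing clamps at the sequence end.

With the default max_chunk_size_mb=1000, this example stays in a single chunk. Note that chunking along the query axis is exact rather than approximate: the softmax is taken over the key axis, which is never split, so each chunk's attention rows match the unchunked result.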
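A minimal usage sketch, assuming this branch is installed: the only API change is the new max_chunk_size_mb constructor argument, and everything else follows the existing hk.MultiHeadAttention interface.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mha = hk.MultiHeadAttention(
      num_heads=8,
      key_size=64,
      w_init=hk.initializers.VarianceScaling(),
      max_chunk_size_mb=100,  # cap each logits chunk at ~100 MB; None disables chunking
  )
  return mha(query=x, key=x, value=x)  # self-attention

model = hk.transform(forward)
x = jnp.zeros((4, 2048, 512))  # [batch, T, embedding]
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, None, x)  # [4, 2048, 512], same values as the unchunked path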