From 44362f567bf333e81578fefe5407fa23aa8bb604 Mon Sep 17 00:00:00 2001
From: Max Conway
Date: Thu, 22 Jun 2023 20:53:38 +0100
Subject: [PATCH] Add memory tiling to MultiHeadAttention

When the attn_logits tensor is expected to be very large (by default
above 1GB), the attention calculation is split into chunks sized below
this limit. This reduces memory consumption considerably and, crucially,
allows input tensor shapes that would otherwise not fit in memory. There
also seems to be a small speed improvement, presumably due to better
caching.

This is similar in concept to FlashAttention, though the implementation
is much simpler because it only tiles along one dimension.

As an example, calculating self-attention with:

    key_size = 64
    num_heads = 16
    h = jax.random.normal(jax.random.PRNGKey(42), [1, 16000, 1024])

this change resulted in a ~30% speedup (6s to 4.3s) and a 5x reduction
in memory usage (19.3GB to 3.9GB).
---
 haiku/_src/attention.py | 47 +++++++++++++++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 12 deletions(-)

diff --git a/haiku/_src/attention.py b/haiku/_src/attention.py
index a03949261..9f2dd1627 100644
--- a/haiku/_src/attention.py
+++ b/haiku/_src/attention.py
@@ -67,6 +67,7 @@ def __init__(
       b_init: Optional[hk.initializers.Initializer] = None,
       value_size: Optional[int] = None,
       model_size: Optional[int] = None,
+      max_chunk_size_mb: Optional[int] = 1000,
       name: Optional[str] = None,
   ):
     """Initialises the module.
@@ -85,6 +86,8 @@ def __init__(
         to the key size (K).
       model_size: Optional size of the output embedding (D'). If None,
         defaults to the key size multiplied by the number of heads (K * H).
+      max_chunk_size_mb: Optional maximum size in MB of the internal logits
+        tensor, to limit RAM usage. Defaults to 1000. If None, never chunk.
       name: Optional name for this module.
     """
     super().__init__(name=name)
@@ -92,6 +95,7 @@
     self.key_size = key_size
     self.value_size = value_size or key_size
     self.model_size = model_size or key_size * num_heads
+    self.max_chunk_size_mb = max_chunk_size_mb
 
     # Backwards-compatibility for w_init_scale.
     if w_init_scale is not None:
@@ -142,20 +146,39 @@
     key_heads = projection(key, self.key_size, "key")  # [T, H, K]
     value_heads = projection(value, self.value_size, "value")  # [T, H, V]
 
-    # Compute attention weights.
-    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
-    attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
-    if mask is not None:
-      if mask.ndim != attn_logits.ndim:
-        raise ValueError(
+    # Calculate the size of the full logits tensor in memory.
+    logits_itemsize = max(query_heads.dtype.itemsize, key_heads.dtype.itemsize)
+    logits_shape = query_heads.size // query_heads.shape[-1] * key_heads.shape[-3]
+    logits_size_bytes = logits_shape * logits_itemsize
+
+    if self.max_chunk_size_mb is not None:
+      max_chunk_size_bytes = self.max_chunk_size_mb * 1000000
+      num_chunks = logits_size_bytes // max_chunk_size_bytes + 1
+    else:
+      num_chunks = 1
+    t_chunk_size = query_heads.shape[-3] // num_chunks
+
+    # Compute attention weights, chunk by chunk.
+    attns = []
+    for i in range(0, query_heads.shape[-3], t_chunk_size):
+      query_chunk = query_heads[..., i:i + t_chunk_size, :, :]
+      attn_logits = jnp.einsum("...thd,...Thd->...htT", query_chunk, key_heads)
+      attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
+      if mask is not None:
+        if mask.ndim != attn_logits.ndim:
+          raise ValueError(
             f"Mask dimensionality {mask.ndim} must match logits dimensionality "
             f"{attn_logits.ndim}."
-        )
-      attn_logits = jnp.where(mask, attn_logits, -1e30)
-    attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]
-
-    # Weight the values by the attention and flatten the head vectors.
-    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
+          )
+        mask_chunk = mask[..., i:i + t_chunk_size, :]
+        attn_logits = jnp.where(mask_chunk, attn_logits, -1e30)
+      attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]
+      # Weight the values by the attention and flatten the head vectors.
+      attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
+      attns.append(attn)
+
+    # Join the chunks back together and reshape.
+    attn = jnp.concatenate(attns, axis=-3)
     attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1))  # [T', H*V]
 
     # Apply another projection to get the final embeddings.
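
Example usage (not part of the patch): a minimal sketch of the self-attention
benchmark described in the commit message above, assuming the patch is applied
and the module is reached through the public hk.MultiHeadAttention alias. The
w_init value and the hk.transform boilerplate are illustrative assumptions, not
taken from the patch; timings and exact memory figures will vary by hardware.

    import jax
    import haiku as hk

    def self_attention(h):
      mha = hk.MultiHeadAttention(
          num_heads=16,
          key_size=64,
          w_init=hk.initializers.VarianceScaling(1.0),  # Illustrative choice.
          max_chunk_size_mb=1000,  # New argument; pass None to disable chunking.
      )
      return mha(h, h, h)  # Self-attention: query, key and value all use h.

    f = hk.transform(self_attention)
    rng = jax.random.PRNGKey(0)
    h = jax.random.normal(jax.random.PRNGKey(42), [1, 16000, 1024])

    # The full attn_logits tensor would be [1, 16, 16000, 16000] in float32,
    # i.e. 16 * 16000 * 16000 * 4 bytes ~= 16.4 GB. With the 1000 MB default,
    # num_chunks = 16_384_000_000 // 1_000_000_000 + 1 = 17, so each chunk
    # covers 16000 // 17 = 941 query positions and its logits stay below 1 GB.
    params = f.init(rng, h)
    out = f.apply(params, rng, h)
    print(out.shape)  # (1, 16000, 1024): model_size defaults to key_size * num_heads.

Because the tiling is only over the query dimension and the softmax is still
taken over the full key axis, each chunk's rows are computed exactly as in the
unchunked path, so the concatenated result matches the original implementation
up to floating-point reordering.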