From 9b6ae9c2ff4918efa94637d397b9f94e7b865e8a Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Sat, 22 Jun 2024 10:24:21 -0700
Subject: [PATCH] compute only query vector when reusing kv

---
 llmfoundry/models/layers/attention.py | 79 ++++++++++++++++++---------
 1 file changed, 52 insertions(+), 27 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 9bfbcffce3..01832fc823 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -460,18 +460,29 @@ def __init__(
         self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        self.Wqkv = build_fc(
-            name=fc_type_name,
-            in_features=self.d_model,
-            out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            fc_kwargs=fc_type,
-        )
-        # for param init fn; enables shape based init of fused layers
-        fuse_splits = [
-            i * self.head_dim
-            for i in range(1, self.n_heads + 2 * self.kv_n_heads)
-        ]
-        self.Wqkv._fused = (0, fuse_splits)
+        if self.reuse_kv_layer_idx is None:
+            self.Wqkv = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [
+                i * self.head_dim
+                for i in range(1, self.n_heads + 2 * self.kv_n_heads)
+            ]
+            self.Wqkv._fused = (0, fuse_splits)
+        else:
+            self.Wq = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            self.Wq._fused = (0, fuse_splits)
 
         if self.qk_ln or self.qk_gn:
             norm_size = self.head_dim if qk_gn else d_model
@@ -480,13 +491,14 @@ def __init__(
                 normalized_shape=norm_size,
                 device=device,
             )
-            if qk_ln:
-                norm_size = self.head_dim * kv_n_heads
-                self.k_ln = build_norm(
-                    name=norm_type.lower(),
-                    normalized_shape=norm_size,
-                    device=device,
-                )
+            if self.reuse_kv_layer_idx is None:
+                if qk_ln:
+                    norm_size = self.head_dim * kv_n_heads
+                    self.k_ln = build_norm(
+                        name=norm_type.lower(),
+                        normalized_shape=norm_size,
+                        device=device,
+                    )
 
         self.attn_fn = attention_implementations.get(self.attn_impl)
 
@@ -564,6 +576,27 @@ def get_qkv(
             key (torch.Tensor): The key tensor.
             value (torch.Tensor): The value tensor.
         """
+        if self.reuse_kv_layer_idx is not None:
+            if prev_layer_key_value is None:
+                raise ValueError(
+                    'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
+                )
+            key, value = prev_layer_key_value
+
+            query = self.Wq(x)
+            if self.clip_qkv:
+                query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+            if self.qk_ln or self.qk_gn:
+                # Applying layernorm to qk
+                q_shape = query.shape
+                if self.qk_gn:
+                    b, s = query.shape[:2]
+                    query = query.view(b, s, self.n_heads, -1)
+                dtype = query.dtype
+                query = self.q_ln(query).to(dtype).view(q_shape)
+            return query, key, value
+
         qkv = self.Wqkv(x)
 
         if self.clip_qkv:
@@ -589,14 +622,6 @@ def get_qkv(
             query = self.q_ln(query).to(dtype).view(q_shape)
             key = self.k_ln(key).to(dtype).view(k_shape)
 
-        if self.reuse_kv_layer_idx is not None:
-            # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation.
-            if prev_layer_key_value is None:
-                raise ValueError(
-                    'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
-                )
-            key, value = prev_layer_key_value
-
         return query, key, value
 
     def _apply_rotary_embeddings(
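
The patch makes a layer that reuses a previous layer's KV cache build only a query projection (`Wq`) instead of the fused `Wqkv`, and return early from `get_qkv` with the reused `(key, value)` pair, instead of computing K/V and then discarding them. Below is a minimal sketch of that pattern in plain PyTorch, not the llm-foundry implementation: it drops grouped-query heads, rotary embeddings, qk layernorm, clipping, and the `build_fc`/`_fused` init machinery, and the class and argument names (`SimpleReuseKVAttention`, `reuse_kv`) are invented for illustration; only `Wq`, `Wqkv`, and `prev_layer_key_value` mirror the patch.

```python
# Illustrative sketch only: a toy attention layer that skips the K/V
# projections entirely when it reuses another layer's (key, value) pair.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleReuseKVAttention(nn.Module):

    def __init__(self, d_model: int, n_heads: int, reuse_kv: bool = False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.reuse_kv = reuse_kv
        if reuse_kv:
            # Only the query projection is needed; K/V come from a prior layer.
            self.Wq = nn.Linear(d_model, d_model)
        else:
            # Fused projection producing query, key, and value.
            self.Wqkv = nn.Linear(d_model, 3 * d_model)

    def forward(self, x, prev_layer_key_value=None):
        b, s, _ = x.shape
        if self.reuse_kv:
            if prev_layer_key_value is None:
                raise ValueError('prev_layer_key_value is required when reusing KV.')
            key, value = prev_layer_key_value
            query = self.Wq(x)
        else:
            query, key, value = self.Wqkv(x).chunk(3, dim=-1)

        # Reshape to (batch, heads, seq, head_dim) and run causal attention.
        def split_heads(t):
            return t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(query), split_heads(key), split_heads(value)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(b, s, self.d_model)
        # Return (key, value) so a later layer can reuse them.
        return out, (key, value)


if __name__ == '__main__':
    x = torch.randn(2, 4, 64)
    layer0 = SimpleReuseKVAttention(d_model=64, n_heads=4, reuse_kv=False)
    layer1 = SimpleReuseKVAttention(d_model=64, n_heads=4, reuse_kv=True)
    out0, kv0 = layer0(x)
    out1, _ = layer1(out0, prev_layer_key_value=kv0)
    print(out1.shape)  # torch.Size([2, 4, 64])
```

As in the patch, the reusing layer allocates no K/V weights at all, so the saving shows up both in parameters and in the per-token projection FLOPs, which is what the removed TODO comment was asking for.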