Commit

compute only query vector when reusing kv
ShashankMosaicML committed Jun 22, 2024
1 parent 13802cb commit 9b6ae9c
Showing 1 changed file with 52 additions and 27 deletions.
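
In short: when a GroupedQueryAttention layer is configured to reuse a previous layer's KV cache (reuse_kv_layer_idx is set), it no longer builds the fused Wqkv projection or the k_ln norm at all; it builds only a query projection Wq, and get_qkv returns early with the reused key/value instead of recomputing them. The following is a minimal sketch of that idea under simplifying assumptions — the class name QueryOnlyAttentionSketch and the plain nn.Linear projections are illustrative stand-ins, not llmfoundry code; only reuse_kv_layer_idx, prev_layer_key_value, and get_qkv mirror names from the diff below.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class QueryOnlyAttentionSketch(nn.Module):
    """Toy stand-in for the projections touched by this commit."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: int,
        reuse_kv_layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.head_dim = d_model // n_heads
        self.kv_dim = kv_n_heads * self.head_dim
        self.reuse_kv_layer_idx = reuse_kv_layer_idx
        if reuse_kv_layer_idx is None:
            # Normal layer: one fused projection produces Q, K, and V together.
            self.Wqkv = nn.Linear(d_model, d_model + 2 * self.kv_dim)
        else:
            # KV-reusing layer: only the query projection is built at all.
            self.Wq = nn.Linear(d_model, d_model)

    def get_qkv(
        self,
        x: torch.Tensor,
        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        if self.reuse_kv_layer_idx is not None:
            if prev_layer_key_value is None:
                raise ValueError('prev_layer_key_value is required when reusing KV.')
            key, value = prev_layer_key_value
            # Compute only the query; key/value come from the earlier layer's cache.
            return self.Wq(x), key, value
        qkv = self.Wqkv(x)
        query, key, value = qkv.split([self.d_model, self.kv_dim, self.kv_dim], dim=-1)
        return query, key, value
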
79 changes: 52 additions & 27 deletions llmfoundry/models/layers/attention.py
@@ -460,18 +460,29 @@ def __init__(
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop

-        self.Wqkv = build_fc(
-            name=fc_type_name,
-            in_features=self.d_model,
-            out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            fc_kwargs=fc_type,
-        )
-        # for param init fn; enables shape based init of fused layers
-        fuse_splits = [
-            i * self.head_dim
-            for i in range(1, self.n_heads + 2 * self.kv_n_heads)
-        ]
-        self.Wqkv._fused = (0, fuse_splits)
+        if self.reuse_kv_layer_idx is None:
+            self.Wqkv = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [
+                i * self.head_dim
+                for i in range(1, self.n_heads + 2 * self.kv_n_heads)
+            ]
+            self.Wqkv._fused = (0, fuse_splits)
+        else:
+            self.Wq = build_fc(
+                name=fc_type_name,
+                in_features=self.d_model,
+                out_features=self.d_model,
+                fc_kwargs=fc_type,
+            )
+            # for param init fn; enables shape based init of fused layers
+            fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
+            self.Wq._fused = (0, fuse_splits)

         if self.qk_ln or self.qk_gn:
             norm_size = self.head_dim if qk_gn else d_model
@@ -480,13 +491,14 @@ def __init__(
                 normalized_shape=norm_size,
                 device=device,
             )
-            if qk_ln:
-                norm_size = self.head_dim * kv_n_heads
-                self.k_ln = build_norm(
-                    name=norm_type.lower(),
-                    normalized_shape=norm_size,
-                    device=device,
-                )
+            if self.reuse_kv_layer_idx is None:
+                if qk_ln:
+                    norm_size = self.head_dim * kv_n_heads
+                    self.k_ln = build_norm(
+                        name=norm_type.lower(),
+                        normalized_shape=norm_size,
+                        device=device,
+                    )

         self.attn_fn = attention_implementations.get(self.attn_impl)

@@ -564,6 +576,27 @@ def get_qkv(
             key (torch.Tensor): The key tensor.
             value (torch.Tensor): The value tensor.
         """
+        if self.reuse_kv_layer_idx is not None:
+            if prev_layer_key_value is None:
+                raise ValueError(
+                    'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
+                )
+            key, value = prev_layer_key_value
+
+            query = self.Wq(x)
+            if self.clip_qkv:
+                query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+            if self.qk_ln or self.qk_gn:
+                # Applying layernorm to qk
+                q_shape = query.shape
+                if self.qk_gn:
+                    b, s = query.shape[:2]
+                    query = query.view(b, s, self.n_heads, -1)
+                dtype = query.dtype
+                query = self.q_ln(query).to(dtype).view(q_shape)
+            return query, key, value
+
         qkv = self.Wqkv(x)

         if self.clip_qkv:
@@ -589,14 +622,6 @@
             query = self.q_ln(query).to(dtype).view(q_shape)
             key = self.k_ln(key).to(dtype).view(k_shape)

-        if self.reuse_kv_layer_idx is not None:
-            # TODO: We still compute key and values in the code above, even if we end up reusing previous layer's kv cache. We should avoid this wasteful computation.
-            if prev_layer_key_value is None:
-                raise ValueError(
-                    'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
-                )
-            key, value = prev_layer_key_value
-
         return query, key, value

     def _apply_rotary_embeddings(
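
Continuing the illustrative sketch above (hypothetical wiring, not the llmfoundry block API): the first layer computes Q, K, and V as usual, while the second layer projects only Q and reuses the first layer's key/value tensors.

import torch

layer_0 = QueryOnlyAttentionSketch(d_model=256, n_heads=8, kv_n_heads=8)
layer_1 = QueryOnlyAttentionSketch(
    d_model=256, n_heads=8, kv_n_heads=8, reuse_kv_layer_idx=0,
)

x = torch.randn(2, 16, 256)                 # (batch, seq_len, d_model)
q0, k0, v0 = layer_0.get_qkv(x)             # normal layer: fused Wqkv projection
q1, k1, v1 = layer_1.get_qkv(x, prev_layer_key_value=(k0, v0))
assert k1 is k0 and v1 is v0                # K/V are reused; only Q was projected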
