
Commit

Update for DynamicCache() use
lekurile committed Nov 14, 2024
1 parent 9a2c209 commit 519e493
Showing 2 changed files with 35 additions and 15 deletions.
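The commit moves DeepSpeed's inference kv-cache from legacy per-layer (key, value) tuples to the DynamicCache class from Hugging Face transformers. For orientation, here is a minimal sketch of the DynamicCache calls the new code relies on (update, to_legacy_cache, from_legacy_cache); the tensor shapes are illustrative assumptions, not values from this diff:

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()

    # One prompt step for layer 0; shapes are [batch, num_heads, seq_len, head_dim].
    key = torch.randn(1, 8, 4, 64)
    value = torch.randn(1, 8, 4, 64)

    # update() appends the new states for the given layer and returns the
    # full accumulated key/value tensors for that layer.
    full_key, full_value = cache.update(key, value, layer_idx=0)

    # Conversions between DynamicCache and the legacy tuple-of-(key, value) format.
    legacy = cache.to_legacy_cache()
    cache2 = DynamicCache.from_legacy_cache(legacy)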
@@ -158,7 +158,7 @@ def forward(
            input = input.to(target_dtype)

        with torch.no_grad():
-            attention_output, key, value, context_outputtn_ctx, inp_norm = \
+            attention_output, kv, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               head_mask,
@@ -172,7 +172,7 @@ def forward(
                               alibi,
                               **kwargs)

-            presents = (key, value)
+            #presents = (key, value)
            self.layer_past = presents if layer_past is None else None
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

@@ -181,8 +181,10 @@ def forward(

            output = output.to(input_type)
        if get_present:
-            output = (output, presents)
+            output = (output, kv)

+        #import pdb; pdb.set_trace()
+        print(f"layer_id = {self.config.layer_id}")
        if self.config.return_single_tuple:
            return (output, )
        elif self.config.return_tuple:
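A DynamicCache holds one key/value slot per layer, which is why the second file below stores self.layer_idx and passes it to every update() call. A small sketch of that per-layer bookkeeping, with assumed shapes:

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()

    # Each attention layer writes to its own slot, keyed by its layer index.
    for layer_idx in range(3):
        k = torch.randn(1, 8, 1, 64)
        v = torch.randn(1, 8, 1, 64)
        cache.update(k, v, layer_idx)

    print(len(cache))               # 3 layers tracked
    print(cache.get_seq_length(0))  # 1 cached position for layer 0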
deepspeed/ops/transformer/inference/ds_attention.py (42 changes: 30 additions & 12 deletions)
@@ -9,6 +9,7 @@
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from .op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp, SoftmaxOp
+from transformers import DynamicCache

minus_inf = -10000.0

@@ -23,6 +24,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
        data_type = self.config.dtype
        data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
+        self.layer_idx = self.config.layer_id
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
@@ -89,7 +91,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
            torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
        ]

-    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids):
+    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids, cache_position):
        if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
            qkv_out = qkv_out[0]
Expand Down Expand Up @@ -140,6 +142,7 @@ def forward(self,
norm_w=None,
norm_b=None,
alibi=None,
cache_position=None, # TODO: Lev, optional cache position tensor
**kwargs):
if self.attn_qkvw is None:
self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
@@ -165,20 +168,22 @@ def forward(self,
        token_idx = kwargs.get("token_idx", None)
        position_ids = kwargs.get("position_ids", None)

-        context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
+        #context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
+        context_layer, kv_layer = self.compute_attention(qkv_out=qkv_out,
                                                         input_mask=input_mask,
                                                         layer_past=layer_past,
                                                         alibi=alibi,
                                                         is_prompt=is_prompt,
                                                         token_idx=token_idx,
-                                                         position_ids=position_ids)
+                                                         position_ids=position_ids,
+                                                         cache_position=cache_position)

        output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
        inp_norm = qkv_out[-1]

        if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
            dist.all_reduce(output, group=self.mp_group)
-        return (output, key_layer, value_layer, context_layer, inp_norm)
+        return (output, kv_layer, context_layer, inp_norm)


class BloomSelfAttention(DeepSpeedSelfAttention):
@@ -221,7 +226,7 @@ def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_

        return tensor_list

-    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids):
+    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids, cache_position):
        if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
            qkv_out = qkv_out[0]

@@ -246,13 +251,24 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
        key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3],
                                                      -1).transpose(-1, -2)
        value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
+        #import pdb; pdb.set_trace()
        if layer_past is not None:
-            past_key, past_value = layer_past
+            #past_key, past_value = layer_past
+            # NEW 1
+            cache_kwargs = {"cache_position": cache_position}
+            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+            #import pdb; pdb.set_trace()
+
+            # NEW 2
+            #past_key, past_value = DynamicCache.from_legacy_cache(layer_past)
+            #cache_out = DynamicCache.from_legacy_cache(layer_past)
+
+            # OLD
            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
+            #key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
+            #value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

-        presents = (key_layer, value_layer)
+        #presents = (key_layer, value_layer)
        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer, key_layer)
        # change view to [batch_size, num_heads, q_length, k_length]
@@ -293,9 +309,11 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
                                                       context_layer.size(1), context_layer.shape[-1])

        context_layer = self._transpose_for_context(context_layer)
-        key_layer = presents[0]
-        value_layer = presents[1]
+        #key_layer = presents[0]
+        #value_layer = presents[1]

-        return context_layer, key_layer, value_layer
+        outputs = (context_layer, layer_past)
+
+        return outputs

    ###################### End of HF modeling_bloom addition ########################
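Taken together, the new path replaces the manual torch.cat concatenation with a single layer_past.update(...) call per step. Below is a toy decode loop in the same style as the cache_kwargs / layer_past.update usage above; the batch and head dimensions, and deriving cache_position from the cached length, are assumptions for illustration:

    import torch
    from transformers import DynamicCache

    layer_past = DynamicCache()
    layer_idx = 0

    # A 4-token prompt step followed by two single-token decode steps.
    for seq_len in (4, 1, 1):
        past_len = layer_past.get_seq_length(layer_idx)
        cache_position = torch.arange(past_len, past_len + seq_len)

        key = torch.randn(1, 8, seq_len, 64)
        value = torch.randn(1, 8, seq_len, 64)

        cache_kwargs = {"cache_position": cache_position}
        key_layer, value_layer = layer_past.update(key, value, layer_idx, cache_kwargs)

        # The returned tensors hold the full history: 4, 5, then 6 positions.
        assert key_layer.shape[-2] == past_len + seq_len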
