diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py
index 8ebc4a6e27a1..7e4c88d9872a 100644
--- a/deepspeed/model_implementations/transformers/ds_transformer.py
+++ b/deepspeed/model_implementations/transformers/ds_transformer.py
@@ -126,6 +126,7 @@ def forward(
             # This needs to be redesigned later!
             layer_head_mask=None,
             past_key_value=None,
+            cache_position=None,
             **kwargs):
 
         if x is not None:
@@ -170,6 +171,7 @@ def forward(
                                                    self.norm_w,
                                                    self.norm_b,
                                                    alibi,
+                                                   cache_position,
                                                    **kwargs)
 
         #presents = (key, value)
diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index 35965795a93b..db249277003f 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -253,21 +253,20 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
         value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
         #import pdb; pdb.set_trace()
         if layer_past is not None:
-            #past_key, past_value = layer_past
-            # NEW 1
+            # # OLD
+            # past_key, past_value = layer_past
+            # # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
+            # key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
+            # value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
+
+            # NEW
+            print(f"cache_position = {cache_position}")
             cache_kwargs = {"cache_position": cache_position}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+            print(key_layer)
+            print(value_layer)
             #import pdb; pdb.set_trace()
-            # NEW 2
-            #past_key, past_value = DynamicCache.from_legacy_cache(layer_past)
-            #cache_out = DynamicCache.from_legacy_cache(layer_past)
-
-            # OLD
-            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-            #key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
-            #value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
-
-            #presents = (key_layer, value_layer)
 
         # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         matmul_result = torch.matmul(query_layer, key_layer)
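
For reference, the change above threads `cache_position` from the transformer layer's `forward` down into `compute_attention` and replaces the legacy `(past_key, past_value)` tuple concatenation with a call into Hugging Face's `Cache` API, where `layer_past.update(...)` owns the key/value concatenation. The sketch below only illustrates that API in isolation, assuming a `transformers` release that exports `DynamicCache`; the tensors use the standard `[batch, heads, seq, head_dim]` layout rather than DeepSpeed's internal layout, so it shows the call signature, not the kernel path.

```python
# Illustrative only: how Hugging Face's Cache API consumes the
# (key, value, layer_idx, cache_kwargs) arguments used in the diff above.
# Assumes a transformers release that exports DynamicCache (>= 4.36).
import torch
from transformers import DynamicCache

batch, heads, head_dim = 1, 2, 4
cache = DynamicCache()

# Prompt step: a 5-token prefix enters the cache for layer 0.
k = torch.randn(batch, heads, 5, head_dim)
v = torch.randn(batch, heads, 5, head_dim)
k_all, v_all = cache.update(k, v, layer_idx=0,
                            cache_kwargs={"cache_position": torch.arange(5)})
print(k_all.shape)  # torch.Size([1, 2, 5, 4])

# Decode step: one new token is appended after position 4.
k_new = torch.randn(batch, heads, 1, head_dim)
v_new = torch.randn(batch, heads, 1, head_dim)
k_all, v_all = cache.update(k_new, v_new, layer_idx=0,
                            cache_kwargs={"cache_position": torch.tensor([5])})
print(k_all.shape)  # torch.Size([1, 2, 6, 4])
```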