pipe cache_position, additional debugging
lekurile committed Nov 14, 2024
1 parent 519e493 commit e511fb0
Showing 2 changed files with 12 additions and 11 deletions.
@@ -126,6 +126,7 @@ def forward(
# This needs to be redesigned later!
layer_head_mask=None,
past_key_value=None,
cache_position=None,
**kwargs):

if x is not None:
@@ -170,6 +171,7 @@ def forward(
self.norm_w,
self.norm_b,
alibi,
cache_position,
**kwargs)

#presents = (key, value)
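The hunks above thread a new cache_position argument through the layer's forward and pass it along to the attention call. A rough sketch of that plumbing follows; the class and attribute names are hypothetical stand-ins, not DeepSpeed's actual module, and only the forwarding of cache_position mirrors the diff:

    import torch

    class InferenceLayerSketch(torch.nn.Module):  # hypothetical stand-in for the DeepSpeed layer
        def __init__(self, attention):
            super().__init__()
            self.attention = attention

        def forward(self, x, input_mask=None, layer_past=None, cache_position=None, **kwargs):
            # The new cache_position argument is simply piped through to the attention
            # module, which can use it (e.g. inside cache_kwargs) when updating the KV cache.
            return self.attention(x,
                                  input_mask=input_mask,
                                  layer_past=layer_past,
                                  cache_position=cache_position,
                                  **kwargs)

Callers typically build cache_position as torch.arange(past_len, past_len + new_len), i.e. the absolute positions of the tokens processed in the current step.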
21 changes: 10 additions & 11 deletions deepspeed/ops/transformer/inference/ds_attention.py
@@ -253,21 +253,20 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
#import pdb; pdb.set_trace()
if layer_past is not None:
#past_key, past_value = layer_past
# NEW 1
# # OLD
# past_key, past_value = layer_past
# # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
# key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
# value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

# NEW
print(f"cache_position = {cache_position}")
cache_kwargs = {"cache_position": cache_position}
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
print(key_layer)
print(value_layer)
#import pdb; pdb.set_trace()

# NEW 2
#past_key, past_value = DynamicCache.from_legacy_cache(layer_past)
#cache_out = DynamicCache.from_legacy_cache(layer_past)

# OLD
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
#key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
#value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

#presents = (key_layer, value_layer)
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
matmul_result = torch.matmul(query_layer, key_layer)
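For context, the layer_past.update(...) call committed above follows the Hugging Face Cache API, replacing the legacy (past_key, past_value) tuple handling that is commented out as OLD. Below is a minimal side-by-side sketch, assuming the transformers library's DynamicCache and standard [batch, num_heads, seq_len, head_dim] tensors; the diff itself concatenates keys along dim=-1 because DeepSpeed stores them transposed, and the standard layout is used here only for clarity:

    import torch
    from transformers.cache_utils import DynamicCache

    batch, num_heads, head_dim, layer_idx = 1, 4, 8, 0
    past_key = torch.randn(batch, num_heads, 5, head_dim)
    past_value = torch.randn(batch, num_heads, 5, head_dim)
    key_layer = torch.randn(batch, num_heads, 1, head_dim)    # K for the new token
    value_layer = torch.randn(batch, num_heads, 1, head_dim)  # V for the new token

    # OLD: layer_past is a (past_key, past_value) tuple; new states are concatenated manually.
    key_legacy = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
    value_legacy = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

    # NEW: layer_past is a Cache object; update() appends this layer's new states and
    # returns the full concatenated key/value tensors.
    layer_past = DynamicCache.from_legacy_cache(((past_key, past_value),))
    cache_position = torch.arange(5, 6)  # absolute positions of the tokens being appended
    key_cached, value_cached = layer_past.update(key_layer, value_layer, layer_idx,
                                                 {"cache_position": cache_position})

    assert key_cached.shape == key_legacy.shape == (batch, num_heads, 6, head_dim)

DynamicCache itself ignores cache_position inside update(); it matters for fixed-size caches such as StaticCache, which is presumably why the commit pipes it through and prints it for debugging.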
