Commit 1488c34

formatting fixes.
cptspacemanspiff committed Dec 29, 2024
1 parent d5a7cd8 commit 1488c34
Showing 2 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/transformers/cache_utils.py
@@ -1198,7 +1198,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)

-        bz = key_states.shape[0]
+        bz = key_states.shape[0]

         if cache_position is None:
             k_out.copy_(key_states)
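For context on the `update` hunk above: a statically preallocated key/value cache either overwrites its buffer wholesale or writes into explicit positions, and the new states are first cast to the buffer's dtype. A minimal sketch of that pattern with made-up shapes; `k_out` here is a plain tensor standing in for the cache buffer, not the transformers implementation:

import torch

# Hypothetical preallocated cache buffer: (max_batch, heads, max_seq_len, head_dim)
max_batch, n_heads, max_seq, head_dim = 4, 8, 128, 64
k_out = torch.zeros(max_batch, n_heads, max_seq, head_dim)

# New key states for a smaller live batch and a shorter prompt.
key_states = torch.randn(2, n_heads, 16, head_dim, dtype=torch.float16)
key_states = key_states.to(k_out.dtype)  # match the cache dtype, as in the hunk above

cache_position = torch.arange(0, key_states.shape[2])

if cache_position is None:
    # Overwrite the whole buffer (assumes key_states already spans the full cache).
    k_out.copy_(key_states)
else:
    # Write only the given sequence positions, and only the live batch rows.
    bz = key_states.shape[0]
    k_out[:bz, :, cache_position] = key_states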
16 changes: 9 additions & 7 deletions src/transformers/models/t5/modeling_t5.py
@@ -504,8 +504,8 @@ def forward(
                 key_states_full = curr_past_key_value.key_cache[self.layer_idx]
                 value_states_full = curr_past_key_value.value_cache[self.layer_idx]
                 # slice into state b/c cache may be larger:
-                key_states = key_states_full[:batch_size,:,:cross_seq_length,:]
-                value_states = value_states_full[:batch_size,:,:cross_seq_length,:]
+                key_states = key_states_full[:batch_size, :, :cross_seq_length, :]
+                value_states = value_states_full[:batch_size, :, :cross_seq_length, :]
             else:
                 key_states = self.k(current_states)
                 value_states = self.v(current_states)
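The slicing in this hunk exists because a statically preallocated cross-attention cache can be larger than the live batch and encoder length, so only the populated region should reach the attention math. A tiny illustrative sketch; the shapes and sizes are assumptions, not values from the model:

import torch

# Assumed preallocated cross-attention cache: (max_batch, heads, max_source_len, head_dim)
key_cache = torch.zeros(4, 8, 512, 64)

# Live request: 2 sequences whose encoder output is only 37 tokens long.
batch_size, cross_seq_length = 2, 37

# Slice so downstream attention only sees the valid region of the buffer.
key_states = key_cache[:batch_size, :, :cross_seq_length, :]
print(key_states.shape)  # torch.Size([2, 8, 37, 64])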
@@ -516,22 +516,24 @@ def forward(
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
                 if is_cross_attention:
                     cross_seq_length = current_states.shape[1]
-                    cache_position = torch.arange(0,cross_seq_length) # Save into specific cells is there a alternative to passing a varying length position?
+                    cache_position = torch.arange(
+                        0, cross_seq_length
+                    ) # Save into specific cells is there a alternative to passing a varying length position?
                     key_states_full, value_states_full = curr_past_key_value.update(
                         key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                     )
                     # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                     past_key_value.is_updated[self.layer_idx] = True
                     # slice into state b/c cache may be larger:
-                    key_states = key_states_full[:batch_size,:,:cross_seq_length,:]
-                    value_states = value_states_full[:batch_size,:,:cross_seq_length,:]
+                    key_states = key_states_full[:batch_size, :, :cross_seq_length, :]
+                    value_states = value_states_full[:batch_size, :, :cross_seq_length, :]
                 else:
                     cache_position = cache_position
                     key_states_full, value_states_full = curr_past_key_value.update(
                         key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                     )
-                    key_states = key_states_full[:batch_size,:,:,:]
-                    value_states = value_states_full[:batch_size,:,:,:]
+                    key_states = key_states_full[:batch_size, :, :, :]
+                    value_states = value_states_full[:batch_size, :, :, :]

         # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
         scores = torch.matmul(query_states, key_states.transpose(3, 2))
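The final context lines compute attention scores with `matmul` plus a transpose rather than `einsum`, which keeps the graph limited to widely supported ONNX ops. A quick equivalence check under assumed shapes:

import torch

batch, heads, q_len, k_len, head_dim = 2, 8, 5, 7, 64
query_states = torch.randn(batch, heads, q_len, head_dim)
key_states = torch.randn(batch, heads, k_len, head_dim)

# einsum formulation: "bnqd,bnkd->bnqk"
scores_einsum = torch.einsum("bnqd,bnkd->bnqk", query_states, key_states)

# matmul formulation used in the diff, exporter/ONNX friendly
scores_matmul = torch.matmul(query_states, key_states.transpose(3, 2))

assert torch.allclose(scores_einsum, scores_matmul, atol=1e-5)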
