add Cache support to Phi+FA2
gante authored and ydshieh committed Dec 7, 2023
1 parent 58618a7 commit d08ac80
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions src/transformers/models/phi/modeling_phi.py
@@ -417,7 +417,7 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -444,7 +444,7 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         # Partial rotary embedding
@@ -464,11 +464,8 @@

         key_states = torch.cat((key_rot, key_pass), dim=-1)

         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         tgt_len = key_states.shape[2]

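For readers following along, here is a minimal sketch of the Cache API this commit adopts. This is an illustration rather than part of the commit: it assumes transformers >= 4.36 (where the Cache/DynamicCache classes landed) and uses made-up tensor shapes. DynamicCache.update appends the new key/value states for a layer and returns the full cached tensors, replacing the manual torch.cat bookkeeping removed above; the cache_kwargs ("sin", "cos", "partial_rotation_size") are only consumed by cache variants that re-rotate cached keys, and are ignored by DynamicCache.

    import torch
    from transformers.cache_utils import DynamicCache

    # New key/value states for one decoder layer:
    # shape (batch, num_heads, seq_len, head_dim) -- illustrative values only
    key_states = torch.randn(1, 32, 4, 64)
    value_states = torch.randn(1, 32, 4, 64)

    cache = DynamicCache()
    layer_idx = 0  # hypothetical single-layer example

    # update() stores the new states for this layer and returns the full
    # cached tensors, so the caller no longer concatenates by hand.
    key_states, value_states = cache.update(key_states, value_states, layer_idx)

    print(cache.get_seq_length(layer_idx))  # 4

Since get_seq_length reports how many tokens are already cached for a given layer, it is what the kv_seq_len computation in the second hunk relies on in place of indexing into a key/value tuple.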
