From d08ac80f8ab48cc46e98d913673811bdc7cbdfc6 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 7 Dec 2023 10:32:44 +0000
Subject: [PATCH] add Cache support to Phi+FA2

---
 src/transformers/models/phi/modeling_phi.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 8044830fe0ddf0..4431c10eebcec4 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -417,7 +417,7 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -444,7 +444,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -464,11 +464,8 @@ def forward(
         key_states = torch.cat((key_rot, key_pass), dim=-1)
 
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         tgt_len = key_states.shape[2]
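
Note on the Cache API the diff moves to (not part of the patch): the sketch below, assuming transformers >= 4.36 where DynamicCache is the concrete Cache implementation, shows how update() and get_seq_length() behave. update() stores the new key/value states for a layer and returns the concatenation of everything cached so far, which is what replaces the manual torch.cat over a (key, value) tuple; the sin/cos/partial_rotation_size cache_kwargs passed in the diff are only consumed by cache types that re-rotate stored keys (such as SinkCache) and are ignored by DynamicCache.

# Sketch only, assuming transformers >= 4.36 (DynamicCache); not part of the patch.
import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0

# Prefill step: key/value states shaped (batch, num_heads, seq_len, head_dim).
key_states = torch.randn(1, 4, 3, 8)
value_states = torch.randn(1, 4, 3, 8)

# update() appends to the per-layer cache and returns the full cached tensors.
key_states, value_states = cache.update(key_states, value_states, layer_idx)
print(cache.get_seq_length(layer_idx))  # 3 cached tokens

# Decoding step with one new token: the returned tensors now cover 4 positions.
key_states, value_states = cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx)
print(key_states.shape[-2])  # 4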