
Commit: update attention cache
yikangshen committed May 13, 2024
1 parent 71a2939 commit 9e8b759
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -518,7 +518,6 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

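This commit drops the instance-attribute cache fallback from all three attention code paths (the eager forward above, plus the flash-attention and SDPA forwards in the next two hunks): a layer no longer checks itself for a `past_key_value` attribute and only uses the cache object threaded in from the model's forward. A minimal standalone sketch of the before/after pattern (the class names here are illustrative, not the actual JetMoe code):

# Hypothetical sketch contrasting the removed fallback with the new flow.

class AttentionOldStyle:
    def forward(self, hidden_states, past_key_value=None):
        # Removed pattern: a static cache attached to the layer as an
        # attribute silently overrides whatever the caller passed in.
        past_key_value = getattr(self, "past_key_value", past_key_value)
        return hidden_states, past_key_value

class AttentionNewStyle:
    def forward(self, hidden_states, past_key_value=None):
        # New pattern: the cache is owned by the model's forward and always
        # arrives as an argument, whether it is dynamic or static.
        return hidden_states, past_key_value
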
@@ -606,9 +605,6 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -699,8 +695,6 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -1087,7 +1081,9 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
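
For reference, the "legacy" format handled here is the pre-`Cache`-API representation of past key/values: a tuple with one `(key_states, value_states)` pair per decoder layer. A hedged sketch of the wrapping step, with illustrative shapes:

# Sketch: wrapping a legacy tuple-of-tuples cache in a DynamicCache.
import torch
from transformers.cache_utils import DynamicCache

batch, kv_heads, seq_len, head_dim = 1, 8, 5, 64  # illustrative sizes
legacy = tuple(
    (
        torch.randn(batch, kv_heads, seq_len, head_dim),  # key_states
        torch.randn(batch, kv_heads, seq_len, head_dim),  # value_states
    )
    for _ in range(2)  # one entry per decoder layer
)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 5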
@@ -1161,13 +1157,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return MoeModelOutputWithPast(
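Together, this hunk and the previous one make the cache handling symmetric: the model wraps a legacy input once on entry and records that fact in `return_legacy_cache`, rather than re-inspecting the cache type on exit. A self-contained sketch of the resulting round-trip guarantee (the helper name is ours, not from the file):

# Sketch: callers get back the cache format they passed in.
import torch
from transformers.cache_utils import Cache, DynamicCache

def cache_bookkeeping(past_key_values, use_cache=True):
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

    # ... decoder layers would call past_key_values.update(...) here ...

    next_cache = past_key_values if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()
    return next_cache

legacy = tuple((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)) for _ in range(2))
assert isinstance(cache_bookkeeping(legacy), tuple)                 # legacy in, legacy out
assert isinstance(cache_bookkeeping(DynamicCache()), DynamicCache)  # Cache in, Cache out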
