Cache: init empty cache when use_cache (#34274)

* fix

* fix tests

* fix copies

* add docs

* Revert "add docs"

This reverts commit 32d3563.

* qwen move deltas

* mllama can potentially fullgraph compile

* enable mllama compile and fix tests

* remove mllama fixes
zucchini-nlp authored Nov 25, 2024
1 parent 1339a14 commit c1a8520
Showing 7 changed files with 57 additions and 64 deletions.
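
The change repeated across the model files below is to lazily create an empty DynamicCache inside forward() when the caller requests caching but passes no cache object, skipping the initialization while torch.jit.trace() is running because tracing cannot return cache objects. A minimal standalone sketch of that pattern, with the surrounding method abridged and argument handling simplified relative to the real models:

import torch
from transformers.cache_utils import DynamicCache

def forward(self, input_ids, past_key_values=None, use_cache=True, cache_position=None):
    # Lazily create an empty cache when caching is requested but none was supplied.
    # torch.jit.trace() doesn't support cache objects in the output, so skip it while tracing.
    if use_cache and past_key_values is None and not torch.jit.is_tracing():
        past_key_values = DynamicCache()

    # cache_position can then be derived from the current cache length, as in the hunks below.
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
        )
    ...
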
src/transformers/models/chameleon/modeling_chameleon.py (5 additions, 1 deletion)
@@ -25,7 +25,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -1300,6 +1300,10 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
src/transformers/models/mllama/modeling_mllama.py (5 additions, 2 deletions)
@@ -24,7 +24,7 @@

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -1618,6 +1618,9 @@ def forward(

hidden_states = inputs_embeds

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
@@ -1845,7 +1848,7 @@ def __init__(self, config):
super().__init__(config.get_text_config())
self.text_config = config.get_text_config()
self.vocab_size = self.text_config.vocab_size
self.model = MllamaTextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
self.model = MllamaTextModel._from_config(self.text_config)
self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)

self.post_init()
src/transformers/models/nemotron/modeling_nemotron.py (3 additions, 0 deletions)
@@ -780,6 +780,9 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (30 additions, 60 deletions)
@@ -21,7 +21,7 @@

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -30,7 +30,7 @@
from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...cache_utils import Cache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
@@ -549,10 +549,6 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += cache_position[0] + 1

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -646,16 +642,6 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
logger.warning_once(
@@ -784,9 +770,6 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -1116,6 +1099,10 @@ def forward(
)
use_cache = False

# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@@ -1428,7 +1415,7 @@ def __init__(self, config):
self.model = Qwen2VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.rope_deltas = None # cache rope_deltas here

# Initialize weights and apply final processing
self.post_init()
@@ -1507,7 +1494,7 @@ def get_rope_index(
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None:
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
@@ -1600,25 +1587,6 @@ def get_rope_index(

return position_ids, mrope_position_deltas

def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)

if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas

return model_kwargs

@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1638,6 +1606,7 @@ def forward(
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
@@ -1726,8 +1695,24 @@ def forward(
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)

if position_ids is None and input_ids is not None:
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
# if we get a 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the previously calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None:  # otherwise `delta` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

outputs = self.model(
input_ids=None,
@@ -1739,6 +1724,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

hidden_states = outputs[0]
@@ -1769,7 +1755,7 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
rope_deltas=self.rope_deltas,
)

def prepare_inputs_for_generation(
@@ -1798,22 +1784,6 @@ def prepare_inputs_for_generation(
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

rope_deltas = kwargs.get("rope_deltas", None)
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
else:
batch_size, seq_length = input_ids.shape
delta = (
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
)
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

if cache_position[0] != 0:
pixel_values = None
pixel_values_videos = None
@@ -1854,7 +1824,7 @@ def prepare_inputs_for_generation(
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw,
"rope_deltas": rope_deltas,
"cache_position": cache_position,
}
)
return model_inputs
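
With rope_deltas now cached on the module and the RoPE index computed only at pre-fill, decode steps rebuild the 3D M-RoPE position ids from cache_position alone, as in the forward() hunk above. A self-contained sketch of that arithmetic using made-up batch values (the tensors here are illustrative, not taken from a real run):

import torch

batch_size, seq_length = 2, 1               # one new token per sequence while decoding
cache_position = torch.tensor([37])         # index of the current step in the KV cache
rope_deltas = torch.tensor([[-3], [-5]])    # per-sequence delta cached at pre-fill, shape (batch, 1)

delta = cache_position[0] + rope_deltas     # shift each sequence's positions by its own delta
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids + delta                          # (batch, seq_length)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)   # (3, batch, seq_length): temporal/height/width
print(position_ids.shape)                   # torch.Size([3, 2, 1])
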
tests/generation/test_utils.py (8 additions, 0 deletions)
@@ -1531,6 +1531,14 @@ def test_past_key_values_format(self):
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads

# some models have a different number of heads for query vs key/value so we need to assign the correct value
# BUT only after `per_head_embed_dim` is set
num_attention_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_attention_heads
)

past_kv = outputs["past_key_values"]
self.assertEqual(len(past_kv), num_hidden_layers)

tests/models/qwen2_vl/test_modeling_qwen2_vl.py (4 additions, 0 deletions)
@@ -333,6 +333,10 @@ def test_beam_search_low_memory(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generation`")
def test_generate_compile_fullgraph(self):
pass


@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):
tests/test_modeling_common.py (2 additions, 1 deletion)
@@ -2343,7 +2343,8 @@ def recursive_check(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
# model might return non-tensor objects (e.g. Cache class)
elif isinstance(tuple_object, torch.Tensor):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
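
The guarded comparison above means the tuple/dict output-equivalence check now compares only tensor entries and skips other returned objects such as cache instances. A condensed sketch of that logic, with a simplified stand-in for the test's own NaN-handling helper:

import torch

def check_entry(tuple_object, dict_object, atol=1e-5):
    # Only tensors are compared numerically; None and non-tensor outputs
    # (e.g. a DynamicCache in past_key_values) are skipped.
    if tuple_object is None:
        return True
    if isinstance(tuple_object, torch.Tensor):
        return torch.allclose(torch.nan_to_num(tuple_object), torch.nan_to_num(dict_object), atol=atol)
    return True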
