From 14b19c4ef365f90797e07b2a20caaaaf3901b2d2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 25 Apr 2024 09:02:30 +0000 Subject: [PATCH] working :D --- src/transformers/cache_utils.py | 74 +++++------ src/transformers/generation/utils.py | 41 +++--- .../models/llama/modeling_llama.py | 78 ++++-------- tests/models/llama/test_modeling_llama.py | 118 +++++++++++------- 4 files changed, 154 insertions(+), 157 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2ed663b26256ed..5e3535a9513615 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -6,7 +6,6 @@ from .configuration_utils import PretrainedConfig from .utils import logging - logger = logging.get_logger(__name__) @@ -61,6 +60,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - return max_length - new_seq_length return previous_seq_length + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + @property def seen_tokens(self): logger.warning_once( @@ -158,14 +165,6 @@ def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" legacy_cache = () @@ -332,14 +331,6 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - class StaticCache(Cache): """ @@ -347,8 +338,7 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig): - The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads` - required to initialize the static cache. + The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. 
max_cache_len (`int`): @@ -373,9 +363,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph + # breaks when updating the cache. + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.key_cache.append(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_key_cache) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + self.value_cache.append(new_layer_value_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) def update( self, @@ -394,42 +393,31 @@ def update( value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): - The index of the layer to cache the states for. Kept for backward compatibility + The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` - to know how much of the cache it should overwrite. + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know where to write in the cache. Return: A tuple containing the updated key and value states. """ - new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache - v_out = self.value_cache + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] - k_out[:, :, new_cache_positions] = key_states - v_out[:, :, new_cache_positions] = value_states + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after # https://github.com/pytorch/pytorch/issues/120248 is fixed - return (self.key_cache[0, 0].any(dim=-1)).sum() + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.
DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states.""" return self.max_cache_len - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self): - """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" - return None diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e6a58d3e5a560..62eebaf0b1f6c8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1514,19 +1514,30 @@ def generate( input_ids_length=input_ids_length, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not self._supports_cache_class: + raise ValueError( + "This model does not support the `cache_implementation` argument. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981." + ) if generation_config.cache_implementation == "static": - if model_kwargs.get("past_key_values", False) is not False: - raise ValueError( - "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." - ) cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + model_kwargs["past_key_values"] = cache_cls( + config=self.config, + max_batch_size=batch_size, + max_cache_len=generation_config.max_length, + device=self.device, + dtype=cache_dtype, + ) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -1844,14 +1855,6 @@ def typeerror(): **model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not callable(getattr(self, "_reset_cache", None)): - raise ValueError( - "A `static_cache` was used to generate but there was a failure when trying to release the cache. " - " Make sure this model implements a `_reset_cache` function." 
- ) - self._reset_cache() - return result def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a63d..4f607f315d5d64 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -428,6 +428,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation==flash_attention_2`. " + "Make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -809,27 +816,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -944,7 +930,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -973,23 +959,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask,
inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1042,7 +1023,7 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1058,7 +1039,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1070,9 +1051,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1081,9 +1065,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1164,7 +1148,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1262,13 +1246,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1327,9 +1304,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { 
"position_ids": position_ids, @@ -1388,7 +1362,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1505,7 +1479,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache | List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c8134..22d6759fe41173 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -684,21 +684,22 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu @require_read_token def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 NUM_TOKENS_TO_GENERATE = 40 - EXPECTED_TEXT_COMPLETION = { - 7: [ - "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - 8: [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], - } + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096). + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " + "of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my eggs, my " + "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] prompts = [ "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", + "My favorite all time favorite condiment is ketchup.", ] tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") model = LlamaForCausalLM.from_pretrained( @@ -706,39 +707,70 @@ def test_compile_static_cache(self): ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True - )[0] - new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - return new_token - - batch_size, seq_length = inputs["input_ids"].shape - with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - with CaptureLogger(logging.get_logger(__name__)) as cl: - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 - - text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + @slow + @require_torch_gpu + @require_read_token + def test_compile_repeated_calls(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. 
See https://github.com/pytorch/pytorch/issues/121943 + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory " + "of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my " + "fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="", padding_side="right") + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) @require_torch class CodeLlamaIntegrationTest(unittest.TestCase):
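For reference, a minimal end-to-end sketch of the user-facing flow this patch enables, mirroring the new `test_compile_static_cache` test above (checkpoint, dtype, device map, and `torch.compile` settings are taken from that test; the single prompt and the 40-token budget are illustrative):

    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
    )
    inputs = tokenizer(["Simply put, the theory of relativity states that "], return_tensors="pt").to(model.device)

    # `generate()` now builds the StaticCache itself when `cache_implementation="static"` is passed,
    # replacing the removed `_setup_cache`/`_reset_cache` hooks on the model.
    gen_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

    # Each layer's cache tensors are allocated once and tagged with `torch._dynamo.mark_static_address`,
    # so the decoding forward pass can be compiled without cuda graph breaks.
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    gen_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")
    print(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))

Since `reorder_cache` now lives on the base `Cache` class and iterates over the per-layer lists, the same static cache path also applies under beam search without model-specific hooks.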