enable StaticCache for assisted generation #34797

Open
yao-matrix wants to merge 51 commits into base: main
Changes from 22 commits (51 commits total)
c205b2e
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
30021dd
update
yao-matrix Nov 19, 2024
620c861
remove warnings import
yao-matrix Nov 19, 2024
b5283e9
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
71b7d22
update
yao-matrix Nov 19, 2024
c967bbe
remove warnings import
yao-matrix Nov 19, 2024
980aa08
done
yao-matrix Nov 20, 2024
c79411d
done
yao-matrix Nov 20, 2024
c717652
fix review comments
yao-matrix Nov 22, 2024
67618e5
Merge branch 'main' of https://github.com/yao-matrix/transformers
yao-matrix Nov 22, 2024
c8e2428
Merge branch 'main' into main
yao-matrix Nov 26, 2024
fde7ebd
Merge branch 'main' into main
yao-matrix Nov 26, 2024
e1169a3
Merge branch 'main' into main
yao-matrix Nov 29, 2024
8a9a753
add static cache ci
Nov 29, 2024
b74a7fe
Merge branch 'main' of https://github.com/yao-matrix/transformers
Nov 29, 2024
e67a3fd
skip Gemma2 StaticCache CI since it uses HybridCache
Nov 29, 2024
177634c
ruff format
Nov 29, 2024
45d0410
fix phimoe ci
yao-matrix Nov 29, 2024
a33f660
fix mixtral ci
yao-matrix Nov 29, 2024
ff07e47
fix ci
yao-matrix Nov 29, 2024
fbef806
cont.
yao-matrix Nov 29, 2024
3564a87
ci
yao-matrix Nov 29, 2024
b9cf597
fix ci
yao-matrix Dec 2, 2024
803166d
ci
yao-matrix Dec 2, 2024
df1594c
ci
yao-matrix Dec 2, 2024
87b7f15
ci
yao-matrix Dec 2, 2024
e60b1fe
Merge branch 'main' into main
yao-matrix Dec 2, 2024
5e195c2
ci
yao-matrix Dec 2, 2024
759da36
ci
yao-matrix Dec 2, 2024
093b647
ci
yao-matrix Dec 2, 2024
3775dc2
ci
yao-matrix Dec 2, 2024
817d303
ci
yao-matrix Dec 2, 2024
6e2ad2a
add # Ignore copy
yao-matrix Dec 2, 2024
99b6bc2
using a smarter way, ignore in test_utils
yao-matrix Dec 2, 2024
af33391
ci
yao-matrix Dec 2, 2024
587b55f
skip Gemma2, it declares support static cache, but it's hybrid cache a…
yao-matrix Dec 2, 2024
0a49d6f
refine error message
yao-matrix Dec 2, 2024
1deeb55
Merge branch 'main' into main
yao-matrix Dec 4, 2024
62b70e4
Merge branch 'main' into main
yao-matrix Dec 5, 2024
210c2e0
Merge branch 'main' into main
yao-matrix Dec 10, 2024
7b97aa4
add test case test_assisted_decoding_compile
yao-matrix Dec 11, 2024
9cb45da
Merge branch 'main' into main
yao-matrix Dec 11, 2024
93cd7bf
fix bug
yao-matrix Dec 11, 2024
3cc23d7
Merge branch 'main' into main
yao-matrix Dec 13, 2024
b45336c
Merge branch 'main' into main
yao-matrix Dec 15, 2024
04f2ea1
Merge branch 'main' into main
yao-matrix Dec 18, 2024
b08d1fc
Merge branch 'main' into main
yao-matrix Dec 19, 2024
0904268
Merge branch 'main' into main
yao-matrix Dec 20, 2024
dd148a8
Merge branch 'main' into main
yao-matrix Dec 22, 2024
3171476
cohere2 is HybridCache
Dec 23, 2024
4e064ec
Merge branch 'main' into main
yao-matrix Jan 2, 2025
25 changes: 25 additions & 0 deletions src/transformers/cache_utils.py
@@ -1237,6 +1237,31 @@ def update(

return k_out, v_out

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        seq_length = self.get_seq_length()
        # In case it is negative
        if max_length < 0:
            max_length = seq_length - abs(max_length)

        if seq_length <= max_length:
            return

        begin = max_length
        end = seq_length + 1
        index = torch.arange(begin, end, device=self.key_cache[0].device)

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            try:
                self.key_cache[idx].index_fill_(2, index, 0)
                self.value_cache[idx].index_fill_(2, index, 0)
            except NotImplementedError:
                # The operator 'aten::index_fill' is not currently implemented for the MPS device.
                self.key_cache[idx][:, :, index] = 0
                self.value_cache[idx][:, :, index] = 0

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
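For reference, a minimal, illustrative sketch of what the new `StaticCache.crop()` does, exercised directly rather than through `generate()`. Nothing below comes from the PR itself; the tiny `LlamaConfig` and the exact `StaticCache` constructor arguments are assumptions and may differ slightly across library versions.

```python
# Illustrative only: fill a StaticCache, then roll it back with the new crop().
import torch

from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Tiny config purely for demonstration (head_dim = 64 / 4 = 16).
config = LlamaConfig(num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, hidden_size=64)
cache = StaticCache(config=config, batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# Pretend the model already wrote 8 tokens into every layer of the cache.
states = torch.ones(1, 4, 8, 16)  # (batch, kv_heads, seq_len, head_dim)
for layer_idx in range(config.num_hidden_layers):
    cache.update(states, states, layer_idx, cache_kwargs={"cache_position": torch.arange(8)})
print(int(cache.get_seq_length()))  # 8

# Assisted decoding accepted only the first 5 tokens: zero out the rejected tail slots.
cache.crop(5)
print(int(cache.get_seq_length()))  # 5
```

Unlike `DynamicCache.crop()`, which truncates the tensors, the static variant keeps the preallocated shape (a requirement for `torch.compile`) and only zeroes the slots past `max_length`, which is what `get_seq_length()` keys off.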
8 changes: 3 additions & 5 deletions src/transformers/generation/candidate_generator.py
@@ -19,7 +19,7 @@
import numpy as np
import torch

from ..cache_utils import DynamicCache
from ..cache_utils import Cache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor

@@ -177,9 +177,6 @@ def __init__(
"Please pass in `min_length` into `.generate()` instead"
)

# We need to roll back the cache in assisted generation, only DynamicCache is supported
self.generation_config.cache_implementation = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -229,6 +226,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.generation_config.cache_implementation = None

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
@@ -748,7 +746,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
elif isinstance(past_key_values, DynamicCache):
elif isinstance(past_key_values, Cache):
past_key_values.crop(max_length)
elif past_key_values is not None:
for idx in range(len(past_key_values)):
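For context, a condensed sketch of the rollback that `_crop_past_key_values` performs after the target model rejects some drafted tokens. This is illustrative, not the PR's exact code; widening the `isinstance` check from `DynamicCache` to `Cache` is what lets a `StaticCache` take the same `crop()` path.

```python
# Condensed, illustrative sketch of the rollback dispatch in assisted generation.
from transformers.cache_utils import Cache


def rollback_cache(past_key_values, max_length: int):
    """Keep only the first `max_length` tokens of the draft cache."""
    if isinstance(past_key_values, Cache):
        # DynamicCache truncates its tensors in place; StaticCache (with this PR) zeroes
        # the tail slots and resets its token counter. Both sit behind the same crop() API.
        past_key_values.crop(max_length)
        return past_key_values
    if past_key_values is not None:
        # Legacy per-layer tensor format: slice along the sequence dimension instead.
        return tuple(layer[:, :, :max_length, :] for layer in past_key_values)
    return None
```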
24 changes: 12 additions & 12 deletions src/transformers/generation/utils.py
@@ -430,7 +430,7 @@ def prepare_inputs_for_generation(
model_input = model_input.clone(memory_format=torch.contiguous_format)
model_inputs[model_input_name] = model_input

# 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
# 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
@@ -1749,16 +1749,6 @@ def _prepare_cache_for_generation(
return

# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`

# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
if assistant_model is not None and generation_config.cache_implementation is not None:
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None

if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
@@ -1773,6 +1763,15 @@
device=device,
model_kwargs=model_kwargs,
)
if assistant_model is not None:
assistant_model._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences)
* batch_size,
max_cache_len=max_cache_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
@@ -2119,6 +2118,7 @@ def generate(
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]

self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)
@@ -2172,7 +2172,7 @@
raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
if generation_config.cache_implementation in ["hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with Static cache classes`")
if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
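With the forced fallback to `DynamicCache` removed, a call along the following lines should now keep a static cache on both the target and the assistant model. The checkpoint names are placeholders rather than anything from the PR; any pair of compatible models that declare static-cache support should behave the same way.

```python
# Illustrative end-to-end call; checkpoint names are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

target_id, assistant_id = "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(target_id)
model = AutoModelForCausalLM.from_pretrained(target_id, torch_dtype="auto", device_map="auto")
assistant = AutoModelForCausalLM.from_pretrained(assistant_id, torch_dtype="auto", device_map="auto")

inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to(model.device)

# Before this change, providing an assistant model silently replaced cache_implementation="static"
# with a DynamicCache; now both models get their own StaticCache. Assisted generation still
# requires batch_size = 1.
output = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",
    max_new_tokens=64,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```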
5 changes: 4 additions & 1 deletion tests/generation/test_utils.py
@@ -2088,8 +2088,9 @@ def test_generate_methods_with_num_logits_to_keep(self):
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

@parameterized.expand([(None, True), ("static", False)])
@pytest.mark.generate
def test_assisted_decoding_with_num_logits_to_keep(self):
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
@@ -2114,6 +2115,8 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": cache_implementation,
"return_legacy_cache": return_legacy_cache,
}

# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
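For orientation, a rough sketch of what the two parameter sets exercise in this test: `(None, True)` keeps the existing dynamic cache returned in the legacy tuple format, while `("static", False)` runs assisted decoding on top of a `StaticCache` returned as a `Cache` object. The model name below is a placeholder, not something the test pins down.

```python
# Rough, illustrative sketch of the two parameterizations (model name is a placeholder).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "HuggingFaceTB/SmolLM-135M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
assistant = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
common = {"assistant_model": assistant, "max_new_tokens": 8, "return_dict_in_generate": True}

# (None, True): dynamic cache, converted back to the legacy tuple format on return.
out = model.generate(**inputs, cache_implementation=None, return_legacy_cache=True, **common)
print(type(out.past_key_values))  # tuple of per-layer (key, value) pairs

# ("static", False): StaticCache for both models, returned as a Cache object.
out = model.generate(**inputs, cache_implementation="static", return_legacy_cache=False, **common)
print(isinstance(out.past_key_values, StaticCache))  # True
```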
6 changes: 6 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -113,6 +113,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
def test_assisted_decoding_sample(self):
pass

@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest("Gemma2 has HybridCache which is not compatible with assisted decoding StaticCache")
pass

Reviewer comment (Member):
let's not skip entirely, but only the static_cache test, as we still need to check if assisted generation works in Gemma2 :)

Maybe it will be skipped by the model's `_supports_static_cache` as I've commented above, but if not we can skip only `test_assisted_decoding_with_num_logits_to_keep_1_static` (maybe it's called a bit differently)

Author reply (yao-matrix):
I switched to `_supports_static_cache` to skip the case. Gemma2 is a bit different: it uses HybridCache but still declares `_supports_static_cache = True`, so I keep the skip in the model test file. I will remove that skip once HybridCache is enabled for assisted decoding, which I plan to do after this PR (pure StaticCache) is merged, thanks.
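To make the two-level skip described in this thread concrete, a small self-contained sketch follows. It is not the PR's code: `_FakeGemma2Model` is invented for illustration, the common test's real guard lives in `tests/generation/test_utils.py`, and the Gemma2-specific skip is the one shown in the diff above.

```python
# Illustrative sketch of the skip logic discussed above, runnable on its own.
import unittest

from parameterized import parameterized


class _FakeGemma2Model:
    # Gemma2 declares static-cache support, but generation actually builds a HybridCache.
    _supports_static_cache = True


class Gemma2SkipDemo(unittest.TestCase):
    all_generative_model_classes = (_FakeGemma2Model,)

    @parameterized.expand([(None, True), ("static", False)])
    def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
        # Generic guard (common test): skip models that cannot build a StaticCache at all.
        for model_class in self.all_generative_model_classes:
            if cache_implementation == "static" and not model_class._supports_static_cache:
                self.skipTest(f"{model_class.__name__} does not support StaticCache")
        # Gemma2-specific guard: HybridCache is not yet supported for assisted decoding.
        if cache_implementation == "static":
            self.skipTest("Gemma2 uses HybridCache, which assisted decoding does not support yet")
        # The shared assisted-decoding checks would run here for the remaining case.


if __name__ == "__main__":
    unittest.main()
```

With parameterized's default naming, the expanded static case should end up as `test_assisted_decoding_with_num_logits_to_keep_1_static`, the name the review comment above guesses at.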

@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
9 changes: 9 additions & 0 deletions tests/models/jetmoe/test_modeling_jetmoe.py
@@ -383,6 +383,15 @@ def test_past_key_values_format(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="JetMoe flash attention does not support right padding")

# Copied from tests.models.phimoe.test_modeling_phimoe.PhimoeModelTest.test_assisted_decoding_with_num_logits_to_keep with phimoe->jetmoe, Phimoe->JetMoe
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest(
"JetMoe doesn't support StaticCache, please check the following issue -> https://github.com/huggingface/transformers/issues/28981."
)
pass


@require_torch
class JetMoeIntegrationTest(unittest.TestCase):
10 changes: 10 additions & 0 deletions tests/models/mixtral/test_modeling_mixtral.py
@@ -17,6 +17,7 @@
import unittest

import pytest
from parameterized import parameterized

from transformers import MixtralConfig, is_torch_available
from transformers.testing_utils import (
@@ -421,6 +422,15 @@ def test_past_key_values_format(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Mixtral flash attention does not support right padding")

# Copied from tests.models.phimoe.test_modeling_phimoe.PhimoeModelTest.test_assisted_decoding_with_num_logits_to_keep with phimoe->mixtral, Phimoe->Mixtral
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest(
"Mixtral doesn't support StaticCache, please check the following issue -> https://github.com/huggingface/transformers/issues/28981."
)
pass

# Ignore copy
def test_load_balancing_loss(self):
r"""
9 changes: 9 additions & 0 deletions tests/models/moshi/test_modeling_moshi.py
@@ -362,6 +362,15 @@ def test_disk_offload_safetensors(self):
def test_save_load(self):
super().test_save_load()

# Copied from tests.models.phimoe.test_modeling_phimoe.PhimoeModelTest.test_assisted_decoding_with_num_logits_to_keep with phimoe->moshi, Phimoe->Moshi
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest(
"Moshi decoder doesn't support StaticCache, please check the following issue -> https://github.com/huggingface/transformers/issues/28981."
)
pass


class MoshiTester:
def __init__(
9 changes: 9 additions & 0 deletions tests/models/phi3/test_modeling_phi3.py
@@ -490,6 +490,15 @@ def test_model_rope_scaling_short_long_factor(self, scaling_type):
# Last token generated using long factor
self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2))

# Copied from tests.models.phimoe.test_modeling_phimoe.PhimoeModelTest.test_assisted_decoding_with_num_logits_to_keep with phimoe->phi3, Phimoe->Phi3
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest(
"Phi3 doesn't support StaticCache, please check the following issue -> https://github.com/huggingface/transformers/issues/28981."
)
pass


@slow
@require_torch
8 changes: 8 additions & 0 deletions tests/models/phimoe/test_modeling_phimoe.py
@@ -493,6 +493,14 @@ def test_model_rope_scaling_short_long_factor(self, scaling_type):
# Last token generated using long factor
self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2))

@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
if cache_implementation == "static":
self.skipTest(
"Phimoe doesn't support StaticCache, please check the following issue -> https://github.com/huggingface/transformers/issues/28981."
)
pass


@slow
@require_torch