add sliding window cache
zhenglongjiepheonix committed May 9, 2024
1 parent 5064b84 commit 19bd6cd
Showing 20 changed files with 338 additions and 59 deletions.
19 changes: 11 additions & 8 deletions setup.py
@@ -260,7 +260,15 @@ def run(self):
extras["sklearn"] = deps_list("scikit-learn")

extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
extras["tf-cpu"] = deps_list("keras", "tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp", "tensorflow-probability")
extras["tf-cpu"] = deps_list(
"keras",
"tensorflow-cpu",
"onnxconverter-common",
"tf2onnx",
"tensorflow-text",
"keras-nlp",
"tensorflow-probability",
)

extras["torch"] = deps_list("torch", "accelerate")
extras["accelerate"] = deps_list("accelerate")
@@ -380,12 +388,7 @@ def run(self):
    + extras["tf-speech"]
)
extras["dev"] = (
extras["all"]
+ extras["testing"]
+ extras["quality"]
+ extras["ja"]
+ extras["sklearn"]
+ extras["modelcreation"]
extras["all"] + extras["testing"] + extras["quality"] + extras["ja"] + extras["sklearn"] + extras["modelcreation"]
)

extras["torchhub"] = deps_list(
@@ -470,4 +473,4 @@ def run(self):
extras["tests_examples_tf"] = deps_list()
extras["tests_custom_tokenizers"] = deps_list()
extras["tests_exotic_models"] = deps_list()
extras["consistency"] = deps_list()
extras["consistency"] = deps_list()
112 changes: 112 additions & 0 deletions src/transformers/cache_utils.py
@@ -448,3 +448,115 @@ def reset(self):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()


class SlidingWindowCache(Cache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
        # when we do short-sentence generation
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.model_sliding_window_size = config.sliding_window
        self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        cache_shape = (
            config.num_hidden_layers,
            max_batch_size,
            self.num_key_value_heads,
            self.sliding_window_size,
            self.head_dim,
        )

        self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)

        torch._dynamo.mark_static_address(self.key_cache)
        torch._dynamo.mark_static_address(self.value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # assume this only happens in the prefill phase, when the prompt length > sliding_window_size
        if cache_position.shape[0] > self.sliding_window_size:
            k_out = key_states[:, :, -self.sliding_window_size :, :]
            v_out = value_states[:, :, -self.sliding_window_size :, :]
            self.key_cache[layer_idx] = k_out
            self.value_cache[layer_idx] = v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building the kv cache, instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
        to_shift = cache_position >= self.sliding_window_size - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # assume this will be called only in the first generation step;
        # `cache_position` will be used in other cases
        return 0

    def get_max_length(self) -> Optional[int]:
        # in theory there is no limit because the sliding window size is fixed
        # no matter how long the sentence is
        return None

    def need_new_cache(self, max_batch_size: int, new_max_cache_len: int) -> bool:
        # this is used by model.generate; when we reuse the model between generations,
        # we need to be careful because the new `max_cache_len` may become
        # larger and `self.sliding_window_size` might change accordingly
        return max_batch_size > self.max_batch_size or (
            self.sliding_window_size < self.model_sliding_window_size and new_max_cache_len > self.max_cache_len
        )

    def reset(self):
        self.key_cache.zero_()
        self.value_cache.zero_()
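
To illustrate the rolling-buffer arithmetic in `update` above, here is a minimal standalone sketch (not part of the patch; the window size and tensor values are made up): once the window is full, `indices` rotates the cached slots so the oldest entry lands at the end and is overwritten by the new token.

import torch

window = 4  # toy sliding window size, not a real config value
# pretend per-layer key cache of shape (batch, heads, window, head_dim)
k_cache = torch.arange(window, dtype=torch.float32).view(1, 1, window, 1)

# decoding one new token whose absolute position (7) is already past the window
cache_position = torch.tensor([7]).clamp(0, window - 1)       # -> tensor([3])
slicing = torch.ones(window, dtype=torch.long).cumsum(0)      # tensor([1, 2, 3, 4])
to_shift = cache_position >= window - 1                       # tensor([True])
indices = (slicing + to_shift[-1].int() - 1) % window         # tensor([1, 2, 3, 0])

k_out = k_cache[:, :, indices]                                # roll: oldest slot moves to the end
k_out[:, :, cache_position] = torch.full((1, 1, 1, 1), 99.0)  # overwrite it with the new key
print(k_out.flatten())                                        # tensor([1., 2., 3., 99.])

This mirrors the single-token decode path of `update`; the prefill branch above instead keeps only the last `sliding_window_size` tokens of the prompt in the cache while returning the full key/value states.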
43 changes: 39 additions & 4 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -94,9 +94,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}


@dataclass
@@ -1342,6 +1340,33 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
            self._static_cache.reset()  # reset the cache for a new generation
        return self._static_cache

    # maybe a better way is to use a single factory function to set caches for models?
    def _get_sliding_window_cache(self, max_batch_size: int, max_cache_len: int) -> SlidingWindowCache:
        """
        Sets a sliding window cache for `generate`, that will persist across calls. A new cache will only be
        initialized if a new `generate` call requires a larger cache.

        Returns the resulting sliding window cache object.
        """
        needs_new_cache = not hasattr(self, "_sliding_window_cache") or self._sliding_window_cache.need_new_cache(
            max_batch_size, max_cache_len
        )
        if needs_new_cache:
            if hasattr(self.config, "_pre_quantization_dtype"):
                cache_dtype = self.config._pre_quantization_dtype
            else:
                cache_dtype = self.dtype
            self._sliding_window_cache = SlidingWindowCache(
                config=self.config,
                max_batch_size=max_batch_size,
                max_cache_len=max_cache_len,
                device=self.device,
                dtype=cache_dtype,
            )
        else:
            self._sliding_window_cache.reset()  # reset the cache for a new generation
        return self._sliding_window_cache

    @torch.no_grad()
    def generate(
        self,
@@ -1557,6 +1582,16 @@ def generate(
                )
            if generation_config.cache_implementation == "static":
                model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
            elif generation_config.cache_implementation == "sliding_window":
                if not hasattr(self.config, "sliding_window") or self.config.sliding_window is None:
                    raise ValueError(
                        "Setting `cache_implementation` to 'sliding_window' requires the model config to support "
                        "sliding window attention; please check that there is a `sliding_window` field in the model "
                        "config and that it is not set to None."
                    )
                model_kwargs["past_key_values"] = self._get_sliding_window_cache(
                    batch_size, generation_config.max_length
                )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

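For reference, a minimal sketch of how the new cache might be invoked once this lands; the checkpoint name and generation settings below are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# any model whose config defines `sliding_window` (e.g. a Mistral checkpoint) should qualify;
# the checkpoint below is only an example
checkpoint = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Sliding window caches keep only the most recent tokens", return_tensors="pt").to(model.device)
# "sliding_window" is routed through NEED_SETUP_CACHE_CLASSES_MAPPING to SlidingWindowCache
output = model.generate(**inputs, max_new_tokens=32, cache_implementation="sliding_window")
print(tokenizer.decode(output[0], skip_special_tokens=True))

Repeated `generate` calls on the same model reuse the persistent `_sliding_window_cache`, resetting it between generations unless a larger batch or window forces a new allocation.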
@@ -63,7 +63,8 @@ def forward(self, hidden_states):
        return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
# TODO @longjie no longer copied from Mistral after static cache
class OpenLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
@@ -154,7 +155,8 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
6 changes: 4 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -84,7 +84,8 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -126,7 +127,8 @@ def _get_unpad_data(attention_mask):
    )


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
# TODO @longjie no longer copied from Mistral after static cache
class FalconRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
6 changes: 4 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -525,7 +525,8 @@ def attention_mask_func(attention_scores, ltor_mask):


class GPTNeoXRotaryEmbedding(nn.Module):
    # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
    # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
    # TODO @longjie no longer copied from Mistral after static cache
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

@@ -617,7 +618,8 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -233,7 +233,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
    # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
    # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
    # TODO @longjie no longer copied from Mistral after static cache
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

3 changes: 2 additions & 1 deletion src/transformers/models/idefics/modeling_idefics.py
@@ -481,7 +481,8 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# TODO @longjie no longer copied from Mistral after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.