Add torch.compile for Mistral (#30642)
* first version

* fix sliding window

* fix style

* add sliding window cache

* fix style

* address comments

* fix test

* fix style

* move sliding window check inside cache init

* revert changes on irrelevant files & add comment on SlidingWindowCache

* address comments & fix style

fix style

* update causal mask

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] llama

* [run-slow] mistral

* [run-slow] mistral

* [run-slow] mistral

* revert CI from a10 to t4

* wrap up
zhenglongjiepheonix authored May 20, 2024
1 parent 92d1d97 commit 616bb11
Showing 19 changed files with 511 additions and 236 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/llm_optims.md
@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values instead
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.

> [!WARNING]
> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and torch.compile. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.
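
The workflow the doc describes (pre-allocating a static kv-cache and compiling the forward pass) can be sketched roughly as below; this is not part of the diff, and the checkpoint and generation settings are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rough sketch of static-cache + torch.compile generation; model id and
# settings are illustrative, not taken from this diff.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", torch_dtype=torch.float16, device_map="auto"
)

# Pre-allocate the kv-cache to a fixed size so tensor shapes stay static under compile
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```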

121 changes: 121 additions & 0 deletions src/transformers/cache_utils.py
@@ -448,3 +448,124 @@ def reset(self):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()


class SlidingWindowCache(Cache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time we try to update the cache, we compute `indices` based on `cache_position >= self.config.sliding_window_size - 1`.
If this is true (meaning the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
we do a cyclic shift based on `indices` to replace the oldest states with the new key/value states passed in.
`to_shift` only becomes true once `cache_position` reaches the last slot of the window. Thus, with `sliding_window_size == 64`:
indices = (slicing + to_shift[-1].int() - 1) % self.config.sliding_window_size
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
We overwrite the cache with these indices, then always write the new states at `cache_position` (clamped to at most `sliding_window_size - 1`).
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)

super().__init__()
self.max_batch_size = max_batch_size
# take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
# when we do short-sentence generation
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.model_sliding_window_size = config.sliding_window
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

cache_shape = (
config.num_hidden_layers,
max_batch_size,
self.num_key_value_heads,
self.sliding_window_size,
self.head_dim,
)

self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)

torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

# assume this only happens in prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > self.sliding_window_size:
k_out = key_states[:, :, -self.sliding_window_size :, :]
v_out = value_states[:, :, -self.sliding_window_size :, :]
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building the kv cache, instead of just throwing away tokens outside of the window
return key_states, value_states

slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
to_shift = cache_position >= self.sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size

k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# assume this will be called only in the first generation step
# `cache_position` will be used in other cases
return 0

def get_max_length(self) -> Optional[int]:
# in theory there is no limit because the sliding window size is fixed
# no matter how long the sentence is
return None

def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()
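
For reference, the rolling-index trick in `update` can be exercised in isolation. A minimal sketch, with an illustrative `sliding_window_size` of 8:

```python
import torch

# Minimal sketch of the index computation used in SlidingWindowCache.update,
# with an illustrative sliding_window_size of 8.
window = 8
slicing = torch.ones(window, dtype=torch.long).cumsum(0)  # tensor([1, 2, ..., 8])

for pos in (3, window - 1):  # one position inside the window, one at its last slot
    cache_position = torch.tensor([pos])
    to_shift = cache_position >= window - 1  # True only at the last slot
    indices = (slicing + to_shift[-1].int() - 1) % window
    print(pos, indices.tolist())

# 3 [0, 1, 2, 3, 4, 5, 6, 7]   -> identity: cache slots keep their order
# 7 [1, 2, 3, 4, 5, 6, 7, 0]   -> rotate left by one: the oldest slot is pushed to the
#                                 end, where it is overwritten at the clamped cache_position
```

This mirrors how `k_out` and `v_out` are re-indexed before the new key/value states are written at the clamped position.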
53 changes: 30 additions & 23 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -96,9 +96,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}


@dataclass
@@ -1326,33 +1324,42 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
"""
Sets a static cache for `generate` that will persist across calls. A new cache will only be initialized if a
Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.
Returns the resulting static cache object.
Returns the resulting cache object.
"""
needs_new_cache = (
not hasattr(self, "_static_cache")
or self._static_cache.max_batch_size < max_batch_size
or self._static_cache.max_cache_len < max_cache_len
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(self._cache, cache_cls))
or self._cache.max_batch_size < max_batch_size
)
if needs_new_cache:
if cache_implementation == "sliding_window":
need_new_cache = need_new_cache or (
self._cache.sliding_window_size < self._cache.model_sliding_window_size
and max_cache_len > self._cache.max_cache_len
)
elif cache_implementation == "static":
need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len

if need_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
self._cache = cache_cls(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache
self._cache.reset()
return self._cache

def _prepare_special_tokens(
self,
@@ -1615,14 +1622,14 @@ def generate(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. determine generation mode
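
A usage sketch of the new code path (not part of the diff; checkpoint and settings are assumptions): setting `cache_implementation` to `"sliding_window"` routes `generate` through `_get_cache` and the `SlidingWindowCache` above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch of requesting the sliding-window cache added in this commit;
# the model id and settings are illustrative.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)

model.generation_config.cache_implementation = "sliding_window"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```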
@@ -63,7 +63,7 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
4 changes: 2 additions & 2 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -81,7 +81,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -123,7 +123,7 @@ def _get_unpad_data(attention_mask):
)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
5 changes: 2 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -253,7 +253,6 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -265,8 +264,8 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -522,7 +522,7 @@ def attention_mask_func(attention_scores, ltor_mask):


class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

@@ -614,7 +614,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -230,7 +230,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

2 changes: 1 addition & 1 deletion src/transformers/models/idefics/modeling_idefics.py
@@ -478,7 +478,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
2 changes: 0 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -303,7 +303,6 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -402,7 +401,6 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(