Commit 5064b84: fix style
zhenglongjiepheonix committed May 9, 2024
1 parent 857ddfa

Showing 3 changed files with 58 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/llm_optims.md
@@ -29,7 +29,7 @@ To optimize this, you can use a kv-cache to store the past keys and values inste
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with torch.compile for up to a 4x speed up.

> [!WARNING]
> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma) and [Llama](./model_doc/llama2) models support static kv-cache and torch.compile.
> Currently, only [Command R](./model_doc/cohere), [Gemma](./model_doc/gemma), [Llama](./model_doc/llama2) and [Mistral](./model_doc/mistral.md) models support static kv-cache and torch.compile.
For this example, let's load the [Gemma](https://hf.co/google/gemma-2b) model.

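As context for this doc change, here is a minimal sketch of the pattern the section describes (a static kv-cache combined with torch.compile). It is not part of the diff; the Mistral checkpoint and generation settings are illustrative, and it assumes a transformers version where `cache_implementation="static"` is supported for this model family.

```python
# Sketch of static kv-cache + torch.compile generation (illustrative checkpoint).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # assumption: any checkpoint from a supported family
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Pre-allocate the kv-cache to a fixed size so tensor shapes stay static,
# then compile the forward pass used by the decoding loop.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

With the cache pre-allocated, the compiled decode step sees fixed tensor shapes, which is where the speedup mentioned above comes from.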
89 changes: 43 additions & 46 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -17,10 +17,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mistral model."""
"""PyTorch Mistral model."""

import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
@@ -88,8 +88,7 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
class MistralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -123,7 +122,7 @@ def cos_cached(self):
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._cos_cached

@torch.no_grad
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
@@ -149,8 +148,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# TODO @Arthur no longer copied from LLama after static cache
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -261,12 +259,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -278,12 +271,12 @@
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

past_key_value = getattr(self, "past_key_value", past_key_value)

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position" : cache_position} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
@@ -342,7 +335,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
cache_position: Optional[torch.LongTensor] = None,
):
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -365,7 +358,7 @@
past_key_value = getattr(self, "past_key_value", past_key_value)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_seq_length()
kv_seq_len += cache_position[0]

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -605,8 +598,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)


# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -669,7 +661,6 @@ def forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]


# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
@@ -680,14 +671,14 @@
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
# inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2).contiguous()
@@ -741,11 +732,6 @@ def forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""

if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
@@ -954,7 +940,7 @@ def forward(
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -972,13 +958,14 @@
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, use_cache
)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, use_cache)

hidden_states = inputs_embeds

# decoder layers
@@ -1041,21 +1028,20 @@ def forward(
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

# copied from Llama implementation

def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
use_cache: bool
use_cache: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
@@ -1065,9 +1051,10 @@ def _update_causal_mask(
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask: return attention_mask
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -1076,8 +1063,11 @@
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window, is_training=self.training,
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None

@@ -1096,8 +1086,10 @@
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() != 4):
exclude_mask |= torch.arange(target_length, device=device) < (cache_position.reshape(-1,1) - self.config.sliding_window)
if self.config.sliding_window is not None and (attention_mask is None or attention_mask.dim() == 2):
exclude_mask |= torch.arange(target_length, device=device) < (
cache_position.reshape(-1, 1) - self.config.sliding_window
)
causal_mask *= exclude_mask

causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
@@ -1120,9 +1112,9 @@
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
@@ -1258,6 +1250,7 @@ def forward(
attentions=outputs.attentions,
)

# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
@@ -1272,16 +1265,14 @@ def prepare_inputs_for_generation(
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
# cache_length = past_key_values.get_seq_length()
# past_length = past_key_values.seen_tokens
# max_cache_length = past_key_values.get_max_length()
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1320,6 +1311,12 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()}

input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
elif use_cache:
cache_position = cache_position[-input_length:]

model_inputs.update(
{
"position_ids": position_ids,
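For reference, a small self-contained sketch (not from the diff) of the sliding-window masking rule that `_update_causal_mask` applies above: a query position can attend to itself and to at most the previous `sliding_window` tokens, and everything else is filled with the dtype minimum. The square prefill-shaped mask and float32 dtype are simplifying assumptions here; the real code also handles padding masks, dtype, and device placement.

```python
# Standalone illustration of the sliding-window causal mask construction.
import torch

target_length, sliding_window = 8, 3
cache_position = torch.arange(target_length)   # query positions
min_dtype = torch.finfo(torch.float32).min

causal_mask = torch.full((target_length, target_length), fill_value=min_dtype)
# Exclude future positions (causal constraint) ...
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ... and positions further back than the sliding window.
exclude_mask |= torch.arange(target_length) < (cache_position.reshape(-1, 1) - sliding_window)
causal_mask *= exclude_mask                    # 0.0 where attention is allowed

print((causal_mask == 0.0).int())              # banded lower-triangular pattern
```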
23 changes: 14 additions & 9 deletions tests/models/mistral/test_modeling_mistral.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Mistral model. """

"""Testing suite for the PyTorch Mistral model."""

import gc
import tempfile
@@ -471,13 +470,14 @@ def test_flash_attn_2_generate_use_cache(self):
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Mistral flash attention does not support right padding")
# copied from Llama tests to supress errors for now

# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_new_cache_format
@unittest.skip("TODO @gante fix this for Mistral")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass


@require_torch_gpu
class MistralIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -634,7 +634,7 @@ def test_speculative_generation(self):
del model
backend_empty_cache(torch_device)
gc.collect()

@slow
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
@@ -644,9 +644,14 @@ def test_compile_static_cache(self):

NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: ['My favourite condiment is 100% ketchup. I love it on everything. '
'I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles'],
# 7: [],
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
],
}

prompts = ["My favourite condiment is "]
@@ -675,4 +680,4 @@ def test_compile_static_cache(self):
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
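Regarding the expectation dict keyed by 7 and 8 above: the key is the major CUDA compute capability of the CI runner (7 for T4, 8 for A10). A hedged sketch of how such a key can be derived follows; how the test class actually populates `cuda_compute_capability_major_version` is not shown in this diff, so the lookup below is only illustrative.

```python
# Sketch: derive the major CUDA compute capability used to key expected outputs.
import torch

expected_by_capability = {7: "completion checked on T4 runners", 8: "completion checked on A10 runners"}

if torch.cuda.is_available():
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    print(expected_by_capability.get(major, "no expectation recorded for this GPU"))
```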
