[WIP] Enable Real Varlen Attention #1065

Draft
wants to merge 15 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
transformers-version: ["4.46.0", "4.46.3"]
torch-version: ["2.4.0", "2.5.*"]
torch-version: ["2.5.*"]

runs-on: ubuntu-22.04

2 changes: 1 addition & 1 deletion optimum/exporters/ipex/cache_utils.py
@@ -44,7 +44,7 @@ def __init__(
self.batch_size = batch_size
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
self.block_size = 16
self.block_size = 64
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
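Note on the cache sizing above: the expression `max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)` is a ceiling division, so each sequence gets just enough 64-token blocks to cover `max_cache_len`, replicated per batch entry. A minimal sketch of the arithmetic with illustrative numbers (not taken from the PR):

    import math

    batch_size, max_cache_len, block_size = 2, 1000, 64
    # Equivalent to: max_cache_len // block_size + (max_cache_len % block_size != 0)
    blocks_per_sequence = math.ceil(max_cache_len / block_size)   # 16 blocks of 64 slots each
    num_blocks = blocks_per_sequence * batch_size                 # 32 blocks in the shared pool
    print(blocks_per_sequence, num_blocks)                        # 16 32

Raising the block size from 16 to 64 shrinks the block table by a factor of four for the same `max_cache_len`.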
4 changes: 3 additions & 1 deletion optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
@@ -27,6 +27,7 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXGPT2MLP,
_falcon_model_forward,
_gpt2_block_forward,
_gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
return model


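`convert_functions` and `convert_class` are defined elsewhere in the exporter; the sketch below is only an assumption of what such patching helpers typically do (rebinding a method on matching submodules, and replacing matching submodules with a wrapper class), not the actual implementation:

    import torch.nn as nn

    def convert_functions_sketch(model: nn.Module, target_cls, name: str, new_fn) -> None:
        # Rebind `name` (e.g. "forward") on every submodule of type `target_cls`.
        for module in model.modules():
            if isinstance(module, target_cls):
                setattr(module, name, new_fn.__get__(module, module.__class__))

    def convert_class_sketch(model: nn.Module, target_cls, new_cls, config) -> None:
        # Replace every child of type `target_cls` with the patched wrapper class.
        for child_name, child in model.named_children():
            if isinstance(child, target_cls):
                setattr(model, child_name, new_cls(child, config))
            else:
                convert_class_sketch(child, target_cls, new_cls, config)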
174 changes: 142 additions & 32 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,6 +19,9 @@
import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions

from optimum.intel.utils.import_utils import is_ipex_version
@@ -29,20 +32,21 @@

logger = logging.getLogger(__name__)

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"


if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
logger.warning(
f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
)
else:
from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention
from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding
from intel_extension_for_pytorch.llm.modules import (
Linear2SiluMul,
LinearAdd,
LinearAddAdd,
LinearGelu,
LinearNewGelu,
PagedAttention,
)

@@ -194,7 +198,10 @@ def _llama_model_forward(
next_decoder_cache = () if use_cache else None

position_embeddings = self.rotary_emb(hidden_states, position_ids)
if past_key_values_length == 0:

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values_length == 0 and past_key_values is not None:
# first token: remove the padding from hidden_states, varlen attention does not accept an attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
@@ -207,7 +214,13 @@
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
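The unpadding bookkeeping a few lines above can be illustrated with a toy batch: `input_lens` is the per-sequence count of non-padding tokens, and when a paged cache is present the padded positions are dropped so the varlen kernel sees densely packed tokens. A minimal sketch with made-up shapes:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]], dtype=torch.int32)

    # Last column of the running sum gives the true length of each sequence -> [3, 5]
    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

    hidden_states = torch.randn(2, 5, 8)                      # (batch, seq, hidden)
    index = attention_mask.view(-1) != 0                      # keep only real tokens
    packed = hidden_states.view(-1, hidden_states.shape[-1])[index]
    print(input_lens.tolist(), packed.shape)                  # [3, 5] torch.Size([8, 8])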

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -309,7 +322,9 @@ def _falcon_model_forward(
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

if past_key_values_length == 0:
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values_length == 0 and past_key_values is not None:
# first token: remove the padding from hidden_states, varlen attention does not accept an attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
@@ -321,7 +336,14 @@
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

next_decoder_cache = None
all_self_attentions = () if output_attentions else None
@@ -436,15 +458,23 @@ def _gpt2_model_forward(

hidden_states = self.drop(hidden_states)

if past_length == 0:
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

if past_length == 0 and past_key_values is not None:
# first token: remove the padding from hidden_states, varlen attention does not accept an attention mask
hidden_states_copy = hidden_states
index = attention_mask.view(-1) != 0
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
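When there is no paged cache (`past_key_values is None`), the 2D padding mask is expanded into the 4D additive mask that `torch.nn.functional.scaled_dot_product_attention` expects. A minimal sketch of that call, assuming the private transformers helper keeps the signature used here:

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

    batch, seq, hidden = 2, 5, 8
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    inputs_embeds = torch.randn(batch, seq, hidden)

    mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask=attention_mask,
        input_shape=(batch, seq),
        inputs_embeds=inputs_embeds,
        past_key_values_length=0,
    )
    # Either None (SDPA can apply causality on its own) or an additive (batch, 1, seq, seq) mask
    print(None if mask_4d is None else mask_4d.shape)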

presents = None
all_self_attentions = () if output_attentions else None
@@ -528,7 +558,10 @@ def _gpt2_block_forward(
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
if hasattr(self.attn, "linear_add"):
hidden_states = self.attn.linear_add(attn_output, residual)
else:
hidden_states = attn_output + residual

if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
@@ -557,7 +590,10 @@ def _gpt2_block_forward(
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
if hasattr(self.mlp, "linear_add"):
hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
else:
hidden_states = residual + feed_forward_hidden_states

if use_cache:
outputs = (hidden_states,) + outputs
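The `hasattr` dispatch above relies on IPEX's `LinearAdd`, which wraps an existing linear layer and fuses its projection with the residual add; when the fused module was not created, the block falls back to the eager projection followed by an addition. A pure-PyTorch stand-in that approximates the semantics (not the IPEX kernel itself):

    import torch
    from torch import nn

    class LinearAddReference(nn.Module):
        """Reference semantics for the fused projection + residual add."""
        def __init__(self, linear: nn.Linear):
            super().__init__()
            self.linear = linear

        def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
            return self.linear(x) + residual  # IPEX performs this in one fused kernel

    class ResidualBlock(nn.Module):
        def __init__(self, hidden: int, fuse: bool = True):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)
            if fuse:
                self.linear_add = LinearAddReference(self.proj)

        def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
            # Same pattern as the patched block: prefer the fused path when it exists.
            if hasattr(self, "linear_add"):
                return self.linear_add(x, residual)
            return self.proj(x) + residual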
@@ -577,6 +613,7 @@ def __init__(self, module, config) -> None:
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
).repeat_interleave(self.num_groups)
self.use_sdpa = False

def qkv_gemm(self, hidden_states):
raise NotImplementedError("Need to implement in specific model class")
@@ -585,9 +622,32 @@ def rope(self, *args, **kwargs):
raise NotImplementedError("Need to implement in specific model class")

def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
return attn_output

def varlen_attn(self, query, key, value, past_key_value, input_lens):
# prefill, remove padding
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
PagedAttention.flash_attn_varlen_func(
attn_output,
query,
key,
value,
seq_len_tensor,
seq_len_tensor,
input_lens.max(),
input_lens.max(),
1.0 / math.sqrt(self.head_dim),
True,
past_key_value.block_tables,
None,
)

return attn_output
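`seq_len_tensor` is the cumulative-sequence-length vector (often called `cu_seqlens`) that variable-length flash-attention kernels use to locate each packed sequence, and `input_lens.max()` is the maximum sequence length passed for queries and keys. A tiny worked example:

    import torch

    input_lens = torch.tensor([3, 5], dtype=torch.int32)
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    print(seq_len_tensor.tolist())  # [0, 3, 8]: sequence i owns packed rows seq_len_tensor[i]:seq_len_tensor[i+1]
    print(int(input_lens.max()))    # 5: the max sequence length for both query and key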

def forward(
self,
hidden_states: torch.Tensor,
@@ -610,28 +670,28 @@ def forward(
if past_key_value is not None:
key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)

attn_output = torch.empty_like(query)
if past_len == 0:
# prefill, remove padding
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
attn_output,
seq_len_tensor,
seq_len_tensor,
input_lens.max(),
input_lens.max(),
0.0,
1.0 / math.sqrt(self.head_dim),
False,
True,
False,
None,
)
# prefill
if past_key_value is None:
n_rep = query.shape[1] // key.shape[1]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
.transpose(1, 2)
.repeat_interleave(n_rep, 1),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=True,
)
self.use_sdpa = True
else:
attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
else:
# decode
attn_output = torch.empty_like(query)
PagedAttention.single_query_cached_kv_attention(
attn_output,
query,
@@ -720,9 +780,23 @@ class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, config)
_setattr_from_module(self, module)
self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
self.c_attn_linear.bias = self.c_attn.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)
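GPT-2 stores its projections as `Conv1D` modules whose weight is laid out as (in_features, out_features), so transposing the weight into an `nn.Linear` with the usual (out_features, in_features) layout yields a numerically identical layer that the fused IPEX modules can consume. A small equivalence check (illustrative only):

    import torch
    from torch import nn
    from transformers.pytorch_utils import Conv1D

    torch.manual_seed(0)
    conv = Conv1D(6, 4)                              # nf=6 outputs, nx=4 inputs, weight shape (4, 6)
    linear = nn.Linear(conv.weight.shape[0], conv.weight.shape[1])
    linear.weight = nn.Parameter(conv.weight.t())    # (6, 4), the layout nn.Linear expects
    linear.bias = conv.bias

    x = torch.randn(2, 4)
    torch.testing.assert_close(conv(x), linear(x))   # same outputs; the Linear layout enables the fusions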

def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_heads, self.head_dim)
key = key.view(-1, self.num_heads, self.head_dim)
value = value.view(-1, self.num_heads, self.head_dim)
@@ -732,9 +806,11 @@ def rope(self, query, key, *args, **kwargs):
return query, key

def postprocess_attention_output(self, attn_output):
if self.use_sdpa:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
if not hasattr(self, "linear_add"):
attn_output = self.c_proj(attn_output)
return attn_output


@@ -805,6 +881,40 @@ def forward(
return output


class _IPEXGPT2MLP(nn.Module):
def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device
self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
self.c_fc_linear.bias = self.c_fc.bias
self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
self.c_proj_linear.bias = self.c_proj.bias
if self.module_device.type == "cpu":
self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)

if self.module_device.type == "cpu":
if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = LinearAdd(self.c_proj_linear)

elif self.module_device.type == "xpu":
if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.linear_add = XPULinearAdd(self.c_proj_linear)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
if hasattr(self, "linear_new_gelu"):
hidden_states = self.linear_new_gelu(hidden_states)
else:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
if not hasattr(self, "linear_add"):
hidden_states = self.c_proj(hidden_states)
return hidden_states
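`LinearNewGelu` fuses `c_fc` with GPT-2's "new" GELU, while the fallback path keeps the eager `c_fc` followed by `self.act`. For reference, the activation being fused is the tanh approximation implemented by transformers' `NewGELUActivation` (a sketch, assuming the standard formula):

    import math
    import torch

    def new_gelu(x: torch.Tensor) -> torch.Tensor:
        # GPT-2's "new" GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

    x = torch.randn(4)
    torch.testing.assert_close(new_gelu(x), torch.nn.functional.gelu(x, approximate="tanh"))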


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayer(nn.Module):
def __init__(self, module, config):
4 changes: 2 additions & 2 deletions optimum/intel/ipex/modeling_base.py
@@ -316,9 +316,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def generate(self, *args, **kwargs):
if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
if self._add_patch and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
# Patch functions to support ipex_paged cache
if self._add_patch:
2 changes: 1 addition & 1 deletion setup.py
@@ -66,7 +66,7 @@
"nncf": ["nncf>=2.14.0"],
"openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47", "accelerate"],
"ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47", "accelerate"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,