Commit
* llama model patcher
* fix jit model
* fix jit model
* rm autocast in model
* add llama model patcher
* support assisted decoding and add reorder cache function
* add comment for _prepare_past_key_values
* rebase main
* fix model_dtype
* rm useless comments
* fix llama
* add comments for ipex_rope and ipex_scale_dot_product
* fix comments
* add enable_tpp comments
* fix import
* fix review round 2
* add torch.no_grad to avoid auto_kernel_selection issue
* use torch.no_grad in jit trace
* fix ipex model testing
* add tests for ipex model generation with multiple inputs
* fix code style
* remove __get__(self) as _reorder_cache is a static method of the class
* fix reorder_cache
* use model_type
* check if reorder_cache is a static method
* fix _reorder_cache
* fix raise import error
* test ipex patching
* fix comments
* update API name and testing
* disable until ipex version 2.5.0
* update testing name
* Update optimum/intel/ipex/modeling_base.py (Co-authored-by: Ella Charlaix <[email protected]>)
* Update tests/ipex/test_modeling.py (Co-authored-by: Ella Charlaix <[email protected]>)
* fix tests

---------
Co-authored-by: Ella Charlaix <[email protected]>
Commit 6e8cd3d (1 parent: c356aa3). Showing 5 changed files with 558 additions and 14 deletions.
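The patcher below is wired into optimum-intel's IPEX export path. As a minimal usage sketch (the checkpoint name is illustrative, and the IPEXModelForCausalLM / export=True API is assumed from optimum-intel rather than introduced by this commit):

from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint, not part of this commit
# export=True converts the model for IPEX; with a recent enough IPEX version the
# llama patching defined in the files below is applied during that conversion.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("IPEX-optimized Llama says:", max_new_tokens=16))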
@@ -0,0 +1,91 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)

from optimum.intel.utils.import_utils import is_ipex_version

from .modeling_utils import (
    _IPEXLlamaDecoderLayerRef,
    _llama_attn_forward,
    _llama_layer_norm_forward,
    _llama_model_forward,
)


_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
_IPEX_EXPORTED_TASK = ("text-generation",)

# Bind `new_function` to module `m` as the method named `func_name`.
def convert_func(m, func_name, new_function):
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)


# Recursively replace the method `new_function_name` on every submodule of type `target_m`.
def convert_functions(m, target_m, new_function_name, new_function):
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            convert_func(sub_m, new_function_name, new_function)
        convert_functions(sub_m, target_m, new_function_name, new_function)


# Recursively replace every submodule of type `target_m` with an instance of `new_class`.
def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)


# Recursively attach `new_op` as the attribute `new_op_name` on every submodule of type `target_m`.
def patch_op(m, target_m, new_op_name, new_op):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_op_name, new_op)
        patch_op(sub_m, target_m, new_op_name, new_op)

def _patch_llama_model(model):
    if is_ipex_version("<", "2.5.0"):
        raise ImportError("Only ipex version >= 2.5.0 supports RotaryEmbedding and IndirectAccessKVCache")

    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding

    ipex_rope = RotaryEmbedding(
        model.config.max_position_embeddings,
        model.config.hidden_size // model.config.num_attention_heads,
        model.config.rope_theta,
        model.config.architectures[0],
    )
    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
    convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
    return model


def _patch_model(model):
    if isinstance(model, LlamaForCausalLM):
        model = _patch_llama_model(model)
    return model
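For illustration only, a hypothetical way to apply `_patch_model` directly to a Transformers checkpoint (the model id is an assumption; in optimum-intel this function is called internally before jit tracing rather than by the user):

import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
model.eval()
with torch.no_grad():  # the commit notes torch.no_grad avoids auto_kernel_selection issues
    patched = _patch_model(model)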
@@ -0,0 +1,307 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def _llama_layer_norm_forward(self, hidden_states):
    return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)

# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
def _llama_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len

    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    # Use the IPEX op to apply rotary position embeddings more efficiently.
    key = self.ipex_rope(
        key,
        position_ids,
        self.num_key_value_heads,
        self.head_dim,
        self.head_dim // 2,
        self.head_dim,
        kv_seq_len,
    )
    query = self.ipex_rope(
        query,
        position_ids,
        self.num_heads,
        self.head_dim,
        self.head_dim // 2,
        self.head_dim,
        kv_seq_len,
    )

    if use_cache:
        # This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
        # to decide which beam to gather from, making the scaled dot-product attention more efficient.
        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
            query,
            key,
            value,
            math.sqrt(self.head_dim),
            past_key_value,
            None,
            attention_mask,
        )
    else:
        value_states = value.transpose(1, 2)
        query_states = query.transpose(1, 2)
        key_states = key.transpose(1, 2)
        kv_seq_len = key_states.shape[-2]

        past_key_value = None
        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
def _llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
    def __init__(self, module, config, distributed=False):
        if is_ipex_version("<", "2.5.0"):
            raise ImportError("Only ipex version >= 2.5.0 supports Linear2SiluMul and LinearAdd")

        from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

        super().__init__()
        for k, v in module.__dict__.items():
            setattr(self, k, v)
        for k, v in module.__class__.__dict__.items():
            if k.startswith("__") or k.startswith("forward"):
                continue
            setattr(self.__class__, k, getattr(module.__class__, k))
        self.distributed = distributed
        if not self.distributed:
            self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
            self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
            del self.__dict__["_modules"]["self_attn"].o_proj
            del self.__dict__["_modules"]["mlp"].down_proj
        self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
        del self.__dict__["_modules"]["mlp"].gate_proj
        del self.__dict__["_modules"]["mlp"].up_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        if not self.distributed:
            hidden_states = self.mha_linear_add(hidden_states, residual)
        else:
            hidden_states = self.self_attn.o_proj(hidden_states)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        mlp_gate = self.linear_silu_mul(hidden_states)

        if not self.distributed:
            hidden_states = self.mlp_linear_add(mlp_gate, residual)
        else:
            hidden_states = self.mlp.down_proj(mlp_gate)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs