Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model patcher #567

Merged
merged 41 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
6c09841
llama model patcher
jiqing-feng Feb 18, 2024
8749a5a
fix jit model
jiqing-feng Feb 18, 2024
e05557a
fix jit model
jiqing-feng Feb 18, 2024
151712d
rm autocast in model
jiqing-feng Feb 18, 2024
c81a5f8
add llama model patcher
jiqing-feng Feb 19, 2024
1782a50
support assisted decoding and add reorder cache function
jiqing-feng Feb 21, 2024
c0c9f5b
Merge branch 'main' into model_patcher
jiqing-feng Feb 26, 2024
41bf0f5
add comment for _prepare_past_key_values
jiqing-feng Feb 26, 2024
6509035
Merge branch 'main' into jit
jiqing-feng Feb 26, 2024
dd63ee7
rebase main
jiqing-feng Feb 26, 2024
16706d3
fix model_dtype
jiqing-feng Feb 26, 2024
1244772
rm useless comments
jiqing-feng Feb 26, 2024
daabe80
merge jit branch
jiqing-feng Feb 26, 2024
4c1c636
fix llama
jiqing-feng Feb 26, 2024
b04b435
add comments for ipex_rope and ipex_scale_dot_product
jiqing-feng Feb 26, 2024
38ed051
fix comments
jiqing-feng Feb 28, 2024
0dbde50
add enable_tpp comments
jiqing-feng Feb 28, 2024
e5b7afd
fix import
jiqing-feng Feb 28, 2024
eb6ab6a
fix review aroun2
jiqing-feng Mar 1, 2024
41ca8c4
add torch.no_grad to avoid auto_kernel_selection issue
jiqing-feng Mar 1, 2024
7b67c1f
use torch.no_grad in jit trace
jiqing-feng Mar 1, 2024
f1598a9
fix ipex model testing
jiqing-feng Mar 1, 2024
eeac729
add tests for ipex model generation with multi inputs
jiqing-feng Mar 1, 2024
5acfac8
fix code style
jiqing-feng Mar 1, 2024
fea73d3
remove __get__(self) as _reorder_cache is static method for the class
jiqing-feng Mar 1, 2024
2e45bcc
fix reorder_cache
jiqing-feng Mar 1, 2024
1ef2c96
use model_type
jiqing-feng Mar 1, 2024
de2f468
check if reorder_cache is a static method
jiqing-feng Mar 1, 2024
31d42a5
fix _reorder_cache
jiqing-feng Mar 1, 2024
809542c
Merge branch 'huggingface:main' into model_patcher
jiqing-feng Mar 1, 2024
3a86c40
fix raise import error
jiqing-feng Mar 4, 2024
e03259c
test ipex patching
jiqing-feng Mar 4, 2024
4c3335b
fix comments
jiqing-feng Mar 6, 2024
b1f704a
Merge branch 'huggingface:main' into model_patcher
jiqing-feng Mar 6, 2024
4e0ec0a
Merge branch 'huggingface:main' into model_patcher
jiqing-feng Mar 7, 2024
aa3008f
update API name and testing
jiqing-feng Mar 7, 2024
37e8cc4
disable untill ipex version 2.5.0
jiqing-feng Mar 7, 2024
e3a7024
update testing name
jiqing-feng Mar 7, 2024
070a0dc
Update optimum/intel/ipex/modeling_base.py
jiqing-feng Mar 8, 2024
8c60c7a
Update tests/ipex/test_modeling.py
jiqing-feng Mar 8, 2024
f8d3f74
fix tests
jiqing-feng Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optimum/exporters/ipex/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model_patcher import export_model
jiqing-feng marked this conversation as resolved.
Show resolved Hide resolved
337 changes: 337 additions & 0 deletions optimum/exporters/ipex/llama_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
import math
from typing import List, Optional, Tuple, Union

import torch
from intel_extension_for_pytorch.llm.modules import linear2SiluMul, linearAdd
jiqing-feng marked this conversation as resolved.
Show resolved Hide resolved
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv


def llama_layer_norm_forward(self, hidden_states):
return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


def llama_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)

kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len

query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# Use ipex op to rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
# This ipex op pre-allocates buffers for past_key_values and use beam index history
# which to decide which beam should be used to make attention scale dot more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
None,
attention_mask,
)
else:
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
kv_seq_len = key_states.shape[-2]

past_key_value = None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


def prepare_inputs_for_generation(
jiqing-feng marked this conversation as resolved.
Show resolved Hide resolved
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs


def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# embed positions
hidden_states = inputs_embeds

if self.gradient_checkpointing and self.training:
use_cache = False

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)


class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
super().__init__()
for k, v in module.__dict__.items():
setattr(self, k, v)
for k, v in module.__class__.__dict__.items():
if k.startswith("__") or k.startswith("forward"):
continue
setattr(self.__class__, k, getattr(module.__class__, k))
self.distributed = distributed
if not self.distributed:
self.mha_linear_add = linearAdd(module.self_attn.o_proj)
self.mlp_linear_add = linearAdd(module.mlp.down_proj)
del self.__dict__["_modules"]["self_attn"].o_proj
del self.__dict__["_modules"]["mlp"].down_proj
self.linear_silu_mul = linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
del self.__dict__["_modules"]["mlp"].gate_proj
del self.__dict__["_modules"]["mlp"].up_proj

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
if not self.distributed:
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

mlp_gate = self.linear_silu_mul(hidden_states)

if not self.distributed:
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)

if use_cache:
outputs += (present_key_value,)

return outputs
Loading
Loading