Model patcher #567

Merged 41 commits on Mar 8, 2024
Changes shown are from 32 of the 41 commits.

Commits
6c09841  llama model patcher (jiqing-feng, Feb 18, 2024)
8749a5a  fix jit model (jiqing-feng, Feb 18, 2024)
e05557a  fix jit model (jiqing-feng, Feb 18, 2024)
151712d  rm autocast in model (jiqing-feng, Feb 18, 2024)
c81a5f8  add llama model patcher (jiqing-feng, Feb 19, 2024)
1782a50  support assisted decoding and add reorder cache function (jiqing-feng, Feb 21, 2024)
c0c9f5b  Merge branch 'main' into model_patcher (jiqing-feng, Feb 26, 2024)
41bf0f5  add comment for _prepare_past_key_values (jiqing-feng, Feb 26, 2024)
6509035  Merge branch 'main' into jit (jiqing-feng, Feb 26, 2024)
dd63ee7  rebase main (jiqing-feng, Feb 26, 2024)
16706d3  fix model_dtype (jiqing-feng, Feb 26, 2024)
1244772  rm useless comments (jiqing-feng, Feb 26, 2024)
daabe80  merge jit branch (jiqing-feng, Feb 26, 2024)
4c1c636  fix llama (jiqing-feng, Feb 26, 2024)
b04b435  add comments for ipex_rope and ipex_scale_dot_product (jiqing-feng, Feb 26, 2024)
38ed051  fix comments (jiqing-feng, Feb 28, 2024)
0dbde50  add enable_tpp comments (jiqing-feng, Feb 28, 2024)
e5b7afd  fix import (jiqing-feng, Feb 28, 2024)
eb6ab6a  fix review round 2 (jiqing-feng, Mar 1, 2024)
41ca8c4  add torch.no_grad to avoid auto_kernel_selection issue (jiqing-feng, Mar 1, 2024)
7b67c1f  use torch.no_grad in jit trace (jiqing-feng, Mar 1, 2024)
f1598a9  fix ipex model testing (jiqing-feng, Mar 1, 2024)
eeac729  add tests for ipex model generation with multiple inputs (jiqing-feng, Mar 1, 2024)
5acfac8  fix code style (jiqing-feng, Mar 1, 2024)
fea73d3  remove __get__(self) as _reorder_cache is a static method of the class (jiqing-feng, Mar 1, 2024)
2e45bcc  fix reorder_cache (jiqing-feng, Mar 1, 2024)
1ef2c96  use model_type (jiqing-feng, Mar 1, 2024)
de2f468  check if reorder_cache is a static method (jiqing-feng, Mar 1, 2024)
31d42a5  fix _reorder_cache (jiqing-feng, Mar 1, 2024)
809542c  Merge branch 'huggingface:main' into model_patcher (jiqing-feng, Mar 1, 2024)
3a86c40  fix raise import error (jiqing-feng, Mar 4, 2024)
e03259c  test ipex patching (jiqing-feng, Mar 4, 2024)
4c3335b  fix comments (jiqing-feng, Mar 6, 2024)
b1f704a  Merge branch 'huggingface:main' into model_patcher (jiqing-feng, Mar 6, 2024)
4e0ec0a  Merge branch 'huggingface:main' into model_patcher (jiqing-feng, Mar 7, 2024)
aa3008f  update API name and testing (jiqing-feng, Mar 7, 2024)
37e8cc4  disable until ipex version 2.5.0 (jiqing-feng, Mar 7, 2024)
e3a7024  update testing name (jiqing-feng, Mar 7, 2024)
070a0dc  Update optimum/intel/ipex/modeling_base.py (jiqing-feng, Mar 8, 2024)
8c60c7a  Update tests/ipex/test_modeling.py (jiqing-feng, Mar 8, 2024)
f8d3f74  fix tests (jiqing-feng, Mar 8, 2024)

Files changed
91 changes: 91 additions & 0 deletions optimum/exporters/ipex/model_patcher.py
@@ -0,0 +1,91 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)

from optimum.intel.utils.import_utils import is_ipex_version

from .modeling_utils import (
    _IPEXLlamaDecoderLayerRef,
    llama_attn_forward,
    llama_layer_norm_forward,
    llama_model_forward,
)


_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
_IPEX_EXPORTED_TASK = ("text-generation",)


def convert_func(m, func_name, new_function):
    # Bind new_function to module m as the method named func_name.
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)


def convert_functions(m, target_m, new_function_name, new_function):
    # Recursively replace the method named new_function_name on every
    # submodule that is an instance of target_m.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            convert_func(sub_m, new_function_name, new_function)
        convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, config, distributed=False):
    # Recursively replace every submodule of type target_m with an instance
    # of new_class wrapping the original submodule.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)


def patch_op(m, target_m, new_op_name, new_op):
    # Recursively attach the operator new_op as attribute new_op_name on
    # every submodule of type target_m.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_op_name, new_op)
        patch_op(sub_m, target_m, new_op_name, new_op)


def _patch_llama_model(model):
    if is_ipex_version("<=", "2.3.0"):
        raise ImportError("Only ipex version > 2.3.0 supports ApplyRotaryEmbedding and IndirectAccessKVCache")

    from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache

    # ipex_rope fuses the rotary position embedding (RoPE) computation.
    ipex_rope = ApplyRotaryEmbedding(
        model.config.max_position_embeddings,
        model.config.hidden_size // model.config.num_attention_heads,
        model.config.rope_theta,
        model.config.architectures[0],
    )
    # ipex_scale_dot_product computes scaled dot-product attention on top of
    # an indirect-access KV cache, which speeds up iterative decoding.
    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

    # Swap in the IPEX-optimized forward implementations module by module.
    convert_functions(model, LlamaModel, "forward", llama_model_forward)
    convert_functions(model, LlamaAttention, "forward", llama_attn_forward)
    convert_functions(model, LlamaRMSNorm, "forward", llama_layer_norm_forward)

    # Replace each decoder layer with the IPEX reference decoder layer.
    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
    return model


def _patch_model(model):
    if isinstance(model, LlamaForCausalLM):
        model = _patch_llama_model(model)
    return model
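
For orientation, here is a minimal sketch of exercising the patcher directly. This is a sketch only: _patch_model is private and, per the commit history, the merged PR drives it from optimum/intel/ipex/modeling_base.py rather than expecting user calls, and the checkpoint name below is purely illustrative.

# Minimal sketch, assuming transformers is installed, an ipex build newer
# than 2.3.0 is available, and the checkpoint name is illustrative only.
from transformers import AutoModelForCausalLM

from optimum.exporters.ipex.model_patcher import _patch_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# _patch_model is a no-op for unsupported architectures; for LlamaForCausalLM
# it patches ops, forward functions, and decoder layers in place.
model = _patch_model(model)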