Refactor CPU llama inference code (huggingface#728)
* ipex 2.3 released

* refactor IPEXLlamaAttention

* change to Ref

* remove Ref

* skip tests

* skip tests

* skip testing without pkv

* add tests skip

* only llama2 with a head size of at least 64 supports IAKV (see the head-size sketch below)

* cannot assert identical outputs because do_sample=True

* remove tiny-llama model testing because it does not work with IAKV

* fix code style

* refine docstring

* fix duplicated code

* refactor attention forward

* add use_cache for rope

* use with and without cache

* refine code

* add reference link

* bug fix

* use reshape

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <[email protected]>

* fix

---------

Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
3 people authored Jun 7, 2024
1 parent 1ab78d5 commit 36e5b23
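
The IAKV constraint mentioned in the commit message depends on the attention head size, which follows directly from the model config; a minimal sketch of that check (the model id below is only an illustration, and the exact gating logic lives in code not shown in this diff):

    from transformers import AutoConfig

    # Hypothetical model id, used only to illustrate the head-size arithmetic.
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    head_size = config.hidden_size // config.num_attention_heads  # 4096 // 32 = 128
    # Per the commit notes, IndirectAccessKVCacheAttention (IAKV) requires a head size of at least 64.
    supports_iakv = head_size >= 64
    print(head_size, supports_iakv)
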
Showing 3 changed files with 228 additions and 152 deletions.
33 changes: 12 additions & 21 deletions optimum/exporters/ipex/model_patcher.py
@@ -13,24 +13,26 @@
 # limitations under the License.

 from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
 )

-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXLlamaDecoderLayerRef,
-    _llama_attn_forward,
+    _IPEXLlamaDecoderLayer,
     _llama_layer_norm_forward,
     _llama_model_forward,
 )


+# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
+_TRANSFORMERS_MIN_VERSION = "4.39.0"
+_TRANSFORMERS_MAX_VERSION = "4.41.2"
+
 _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
 _IPEX_EXPORTED_TASK = ("text-generation",)
@@ -64,27 +66,16 @@ def patch_op(m, target_m, new_op_name, new_op):

 def _patch_llama_model(model):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+        raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
+    if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
+        ">", _TRANSFORMERS_MAX_VERSION
+    ):
         raise ImportError(
-            f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports RotaryEmbedding and IndirectAccessKVCacheAttention"
+            f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
         )

-    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding
-
-    ipex_rope = RotaryEmbedding(
-        model.config.max_position_embeddings,
-        model.config.hidden_size // model.config.num_attention_heads,
-        model.config.rope_theta,
-        model.config.architectures[0],
-    )
-    ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
-    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
-    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
-
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
-    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
     return model
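
For reference, a minimal sketch of what the patching helpers used above might do; this is an assumption for illustration only, as the actual patch_op, convert_functions and convert_class implementations live elsewhere in model_patcher.py and may differ in detail:

    import types

    def patch_op(m, target_m, new_op_name, new_op):
        # Attach a new op (e.g. a fused IPEX kernel) as an attribute on every
        # submodule that is an instance of the target class.
        for _, sub_m in m.named_modules():
            if isinstance(sub_m, target_m):
                setattr(sub_m, new_op_name, new_op)

    def convert_functions(m, target_m, new_function_name, new_function):
        # Rebind a method (typically "forward") on every submodule of the target class.
        for _, sub_m in m.named_modules():
            if isinstance(sub_m, target_m):
                setattr(sub_m, new_function_name, types.MethodType(new_function, sub_m))

    def convert_class(m, target_m, new_class, config):
        # Recursively swap every child of the target class for the wrapper class;
        # passing the original module plus the config to the constructor is an
        # assumption based on the convert_class(..., model.config) call above.
        for name, sub_m in m.named_children():
            if isinstance(sub_m, target_m):
                setattr(m, name, new_class(sub_m, config))
            else:
                convert_class(sub_m, target_m, new_class, config)

With helpers along these lines, _patch_llama_model rebinds the LlamaModel and LlamaRMSNorm forwards and replaces each LlamaDecoderLayer with _IPEXLlamaDecoderLayer in place.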

