From 6e8cd3d8134438b7a488d98a48d62607c3185412 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Fri, 8 Mar 2024 18:48:02 +0800 Subject: [PATCH] Add IPEX model patcher (#567) * llama model patcher * fix jit model * fix jit model * rm autocast in model * add llama model patcher * support assisted decoding and add reorder cache function * add comment for _prepare_past_key_values * rebase main * fix model_dtype * rm useless comments * fix llama * add comments for ipex_rope and ipex_scale_dot_product * fix comments * add enable_tpp comments * fix import * fix review aroun2 * add torch.no_grad to avoid auto_kernel_selection issue * use torch.no_grad in jit trace * fix ipex model testing * add tests for ipex model generation with multi inputs * fix code style * remove __get__(self) as _reorder_cache is static method for the class * fix reorder_cache * use model_type * check if reorder_cache is a static method * fix _reorder_cache * fix raise import error * test ipex patching * fix comments * update API name and testing * disable untill ipex version 2.5.0 * update testing name * Update optimum/intel/ipex/modeling_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update tests/ipex/test_modeling.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * fix tests --------- Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> --- optimum/exporters/ipex/__init__.py | 0 optimum/exporters/ipex/model_patcher.py | 91 +++++++ optimum/exporters/ipex/modeling_utils.py | 307 +++++++++++++++++++++++ optimum/intel/ipex/modeling_base.py | 135 ++++++++-- tests/ipex/test_modeling.py | 39 +++ 5 files changed, 558 insertions(+), 14 deletions(-) create mode 100644 optimum/exporters/ipex/__init__.py create mode 100644 optimum/exporters/ipex/model_patcher.py create mode 100644 optimum/exporters/ipex/modeling_utils.py diff --git a/optimum/exporters/ipex/__init__.py b/optimum/exporters/ipex/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py new file mode 100644 index 0000000000..60ff3b721b --- /dev/null +++ b/optimum/exporters/ipex/model_patcher.py @@ -0,0 +1,91 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) + +from optimum.intel.utils.import_utils import is_ipex_version + +from .modeling_utils import ( + _IPEXLlamaDecoderLayerRef, + _llama_attn_forward, + _llama_layer_norm_forward, + _llama_model_forward, +) + + +_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) +_IPEX_EXPORTED_TASK = ("text-generation",) + + +def convert_func(m, func_name, new_function): + bound_method = new_function.__get__(m, m.__class__) + setattr(m, func_name, bound_method) + + +def convert_functions(m, target_m, new_function_name, new_function): + for _, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + convert_func(sub_m, new_function_name, new_function) + convert_functions(sub_m, target_m, new_function_name, new_function) + + +def convert_class(m, target_m, new_class, config, distributed=False): + for name, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + new_m = new_class(sub_m, config, distributed) + setattr(m, name, new_m) + convert_class(sub_m, target_m, new_class, config, distributed) + + +def patch_op(m, target_m, new_op_name, new_op): + for name, sub_m in m.named_children(): + if isinstance(sub_m, target_m): + setattr(sub_m, new_op_name, new_op) + patch_op(sub_m, target_m, new_op_name, new_op) + + +def _patch_llama_model(model): + if is_ipex_version("<", "2.5.0"): + raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache") + + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding + + ipex_rope = RotaryEmbedding( + model.config.max_position_embeddings, + model.config.hidden_size // model.config.num_attention_heads, + model.config.rope_theta, + model.config.architectures[0], + ) + ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) + patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) + patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) + + convert_functions(model, LlamaModel, "forward", _llama_model_forward) + convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) + convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) + + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + return model + + +def _patch_model(model): + if isinstance(model, LlamaForCausalLM): + model = _patch_llama_model(model) + return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py new file mode 100644 index 0000000000..f75e559eaf --- /dev/null +++ b/optimum/exporters/ipex/modeling_utils.py @@ -0,0 +1,307 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import repeat_kv + +from optimum.intel.utils.import_utils import is_ipex_version + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 +def _llama_layer_norm_forward(self, hidden_states): + return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 +def _llama_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len + + query = query.view(bsz, q_len, self.num_heads, self.head_dim) + key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + # Use ipex op to rotary position embedding more efficient. + key = self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + + if use_cache: + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + else: + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + kv_seq_len = key_states.shape[-2] + + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 +def _llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 +class _IPEXLlamaDecoderLayerRef(nn.Module): + def __init__(self, module, config, distributed=False): + if is_ipex_version("<", "2.5.0"): + raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") + + from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + + super().__init__() + for k, v in module.__dict__.items(): + setattr(self, k, v) + for k, v in module.__class__.__dict__.items(): + if k.startswith("__") or k.startswith("forward"): + continue + setattr(self.__class__, k, getattr(module.__class__, k)) + self.distributed = distributed + if not self.distributed: + self.mha_linear_add = LinearAdd(module.self_attn.o_proj) + self.mlp_linear_add = LinearAdd(module.mlp.down_proj) + del self.__dict__["_modules"]["self_attn"].o_proj + del self.__dict__["_modules"]["mlp"].down_proj + self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj) + del self.__dict__["_modules"]["mlp"].gate_proj + del self.__dict__["_modules"]["mlp"].up_proj + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + if not self.distributed: + hidden_states = self.mha_linear_add(hidden_states, residual) + else: + hidden_states = self.self_attn.o_proj(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + mlp_gate = self.linear_silu_mul(hidden_states) + + if not self.distributed: + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.mlp.down_proj(mlp_gate) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 9928977ead..00fe3de115 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -22,6 +22,8 @@ import intel_extension_for_pytorch as ipex import torch from huggingface_hub import hf_hub_download +from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp +from intel_extension_for_pytorch.transformers.optimize import get_dummy_input from transformers import ( AutoConfig, AutoModel, @@ -45,14 +47,63 @@ from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager -from ..generation.modeling import jit_trace, prepare_jit_inputs -from ..utils.import_utils import is_torch_version, is_transformers_version +from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model +from ..generation.modeling import prepare_jit_inputs +from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask logger = logging.getLogger(__name__) +_IPEX_SUPPORT_MODEL_TYPES = ("llama",) + + +def _is_patched_with_ipex(model, task): + if is_ipex_version("<", "2.5.0"): + return False + + if isinstance(model, torch.jit.ScriptModule): + for node in model.graph.nodes(): + # Jit will record the codes position so we can check if the node use ipex exporter. + if "torch_ipex::rotary_position_embedding" in node.__str__(): + return True + return False + else: + return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK + + +def ipex_jit_trace(model, task, use_cache): + # Only support torch version >= 2.1.0 to support example_kwarg_inputs in jit.trace + if is_torch_version("<", "2.1.0"): + raise ImportError("`torch>=2.1.0` is needed to trace your model") + + if _is_patched_with_ipex(model, task): + model = _patch_model(model) + sample_inputs = get_dummy_input(model, return_dict=True) + # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. + _enable_tpp() + else: + model = patch_decoder_attention_mask(model) + sample_inputs = prepare_jit_inputs(model, task, use_cache) + + model.config.return_dict = False + + model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) + with torch.no_grad(): + trace_model = torch.jit.trace( + model, + example_kwarg_inputs=sample_inputs, + strict=False, + check_trace=False, + ) + trace_model = torch.jit.freeze(trace_model) + trace_model(**sample_inputs) + trace_model(**sample_inputs) + + return trace_model + + class IPEXModel(OptimizedModel): auto_model_class = AutoModel export_feature = "feature-extraction" @@ -74,6 +125,7 @@ def __init__( self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model.to(self._device) self.model_save_dir = model_save_dir + self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) self.input_names = { inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" @@ -91,13 +143,13 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, + use_cache: bool = True, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, subfolder: str = "", local_files_only: bool = False, - use_cache: bool = True, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, trust_remote_code: bool = False, ): @@ -117,14 +169,13 @@ def _from_transformers( } model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - model = patch_decoder_attention_mask(model) - model = ipex.optimize(model, dtype=torch_dtype, level="O1", auto_kernel_selection=True) - traced_model = jit_trace(model, task, use_cache) + traced_model = ipex_jit_trace(model, task, use_cache) save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME) config.torchscript = True + config.torch_dtype = torch_dtype return cls._from_pretrained( model_id=save_dir_path, @@ -135,6 +186,7 @@ def _from_transformers( cache_dir=cache_dir, local_files_only=local_files_only, use_cache=use_cache, + model_dtype=torch_dtype, ) @classmethod @@ -213,6 +265,13 @@ def device(self) -> torch.device: def dtype(self) -> torch.dtype: return self._dtype + @property + def model_dtype(self): + logger.warning( + "access to the `model_dtype` attribute is deprecated and will be removed after v1.18.0, please use `_dtype` instead." + ) + return self._dtype + def to(self, device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -223,7 +282,7 @@ def can_generate(self): def _call_model(self, *args, **kwargs): try: - with torch.autocast(self.device.type, self.dtype): + with torch.autocast(self.device.type, self.dtype), torch.no_grad(): out = self.model(*args, **kwargs) except RuntimeError: out = self.model(*args, **kwargs) @@ -232,10 +291,12 @@ def _call_model(self, *args, **kwargs): def _init_warmup(self): # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and # the results of the compute are unpredictable - use_cache = "past_key_values" in self.input_names - dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) - for _ in range(2): - self(**dummy_inputs) + # TODO : add warmup for IPEX exported model + if not self._is_ipex_exported: + use_cache = "past_key_values" in self.input_names + dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) + for _ in range(2): + self(**dummy_inputs) class IPEXModelForSequenceClassification(IPEXModel): @@ -334,10 +395,10 @@ def __init__( ): # Perform the initial warmup at the end of __init__ super().__init__(model, config, model_save_dir=model_save_dir, warmup=False) + GenerationMixin.__init__(self) model_type = config.model_type.replace("_", "-") self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config) - self.model_dtype = kwargs.get("model_dtype", self.dtype) self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: @@ -357,7 +418,15 @@ def __init__( ) except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - self._reorder_cache = self.model_cls._reorder_cache.__get__(self) + + if self._is_ipex_exported: + self._reorder_cache = _ipex_reorder_cache + else: + # Check if _reorder_cache is a static method + if isinstance(self.model_cls.__dict__["_reorder_cache"], staticmethod): + self._reorder_cache = self.model_cls._reorder_cache + else: + self._reorder_cache = self.model_cls._reorder_cache.__get__(self) if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}: self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama @@ -383,7 +452,25 @@ def _prepare_past_key_values(self, input_ids): else: num_attention_heads = self.normalized_config.num_attention_heads - if model_type == "bloom": + if self._is_ipex_exported: + # Indirect access kv cache has a different data layout compared with most transformers model, + # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache + beam_idx_tmp = torch.zeros( + (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long + ).contiguous() + past_key_values = tuple( + [ + ( + torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), + torch.zeros([1, 1, 1, 1]).contiguous(), + torch.zeros([1, 1, 1, 1]).contiguous(), + beam_idx_tmp, + ) + for i in range(num_layers) + ] + ) + return past_key_values + elif model_type == "bloom": shape_key = (batch_size * num_attention_heads, d_k, 0) shape_value = (batch_size * num_attention_heads, 0, d_k) key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device) @@ -505,3 +592,23 @@ def _prepare_inputs_for_generation_for_llama( } ) return model_inputs + + +def _ipex_reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor +) -> Tuple[Tuple[torch.Tensor]]: + # Ipex patched model uses indirect access kv cache which has a different shape with other transformers models + if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1: + for layer_past in past_key_values: + layer_past[3][layer_past[0].size(-2) - 1] = beam_idx + return past_key_values + elif len(past_key_values[0]) == 8: + for layer_past in past_key_values: + layer_past[3][layer_past[0].size(-2) - 1] = beam_idx + layer_past[7][layer_past[0].size(-2) - 1] = beam_idx + return past_key_values + else: + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 03b7d015d1..68119287d8 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -26,6 +26,7 @@ AutoModelForCausalLM, AutoModelForQuestionAnswering, AutoTokenizer, + GenerationConfig, PretrainedConfig, pipeline, set_seed, @@ -42,6 +43,8 @@ IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) +from optimum.intel.utils.import_utils import is_ipex_version +from optimum.utils.testing_utils import grid_parameters SEED = 42 @@ -216,6 +219,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "mpt", "opt", ) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",) GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 @@ -259,6 +263,41 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + @parameterized.expand( + grid_parameters( + { + "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES, + "use_cache": [True, False], + } + ) + ) + @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") + def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache) + self.assertEqual(model.use_cache, use_cache) + trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + # Test with batch_size is 1 and 2. + texts = ["This is a sample", ["This is the first input", "This is the second input"]] + generation_configs = ( + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True), + GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6), + GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0), + ) + for text in texts: + tokens = tokenizer(text, padding=True, return_tensors="pt") + for generation_config in generation_configs: + outputs = model.generate(**tokens, generation_config=generation_config) + transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) + self.assertIsInstance(outputs, torch.Tensor) + self.assertEqual(outputs, transformers_outputs) + def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id)