Add an initial warmup step to IPEXModels (#543)
* Handle autocast in IPEXModel.forward

* Handle missing torch_dtype in config

* Warmup IPEX models at init

* Minor fix

* Fix _init_warmup use_cache condition

* Fix output handling in IPEX question answering
ofirzaf authored Jan 31, 2024
1 parent 8ee487d commit 788e458
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions optimum/intel/ipex/modeling_base.py
@@ -15,6 +15,7 @@
 
 import logging
 import os
+from functools import wraps
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Tuple, Union
@@ -45,7 +46,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
-from ..generation.modeling import jit_trace
+from ..generation.modeling import jit_trace, prepare_jit_inputs
 from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
@@ -64,6 +65,7 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: bool = True,
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
@@ -81,6 +83,8 @@ def __init__(
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
+        if warmup:
+            self._init_warmup()
 
     @classmethod
     def _from_transformers(
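
For reference, a minimal sketch of opting out of the automatic warmup with the new flag. The traced-model and config loading here are illustrative assumptions, not part of this diff; only the `warmup=False` constructor argument comes from the change above:

```python
import torch
from transformers import AutoConfig
from optimum.intel import IPEXModelForSequenceClassification

# Assumed setup: a TorchScript model previously traced for this task (path is a placeholder).
traced_model = torch.jit.load("exported_model.pt")
config = AutoConfig.from_pretrained("bert-base-uncased")

# warmup=False skips the two dummy forward passes that __init__ now runs by default.
model = IPEXModelForSequenceClassification(traced_model, config=config, warmup=False)
```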
@@ -220,6 +224,14 @@ def _call_model(self, *args, **kwargs):
         out = self.model(*args, **kwargs)
         return out
 
+    def _init_warmup(self):
+        # Warmup: the first two forward passes of an IPEX model include some
+        # one-time preprocessing steps, so their performance is unpredictable.
+        use_cache = "past_key_values" in self.input_names
+        dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+        for _ in range(2):
+            self(**dummy_inputs)
+
 
 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
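
The effect of `_init_warmup` is easiest to see by timing successive calls. A rough sketch, assuming an already exported model (the checkpoint paths are placeholders):

```python
import time
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSequenceClassification

model = IPEXModelForSequenceClassification.from_pretrained("./exported-model")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("./exported-model")  # placeholder path
inputs = tokenizer("Warmup check", return_tensors="pt")

for i in range(4):
    start = time.perf_counter()
    model(**inputs)
    print(f"call {i}: {time.perf_counter() - start:.4f}s")
# With the default warmup=True, even call 0 should be fast, because the two
# preprocessing forward passes already ran inside __init__.
```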
@@ -278,8 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     auto_model_class = AutoModelForQuestionAnswering
     export_feature = "question-answering"
 
-    def forward(self, *args, **kwargs):
-        outputs = self._call_model(*args, **kwargs)
+    def forward(self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self._call_model(**inputs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
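
A short usage sketch of the rewritten `forward`, which now passes only the inputs the traced graph declares (`token_type_ids` is dropped when it is not among `self.input_names`). The checkpoint paths are placeholders:

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForQuestionAnswering

model = IPEXModelForQuestionAnswering.from_pretrained("./exported-qa-model")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("./exported-qa-model")  # placeholder path

inputs = tokenizer("Who wrote it?", "It was written by Jane.", return_tensors="pt")
outputs = model(**inputs)

# The ModelOutput returned above exposes start/end logits for span extraction.
start = int(outputs.start_logits.argmax(-1))
end = int(outputs.end_logits.argmax(-1))
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```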
@@ -295,9 +320,11 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: bool = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir)
+        # Perform the initial warmup at the end of __init__
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -325,6 +352,8 @@ def __init__(
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
+        if warmup:
+            self._init_warmup()
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
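
For decoder models, `use_cache` means the dummy inputs built by `prepare_jit_inputs` include `past_key_values`, so the deferred warmup exercises the same cached path that `generate` uses. A hedged end-to-end sketch, assuming `export=True` traces the checkpoint on the fly as elsewhere in optimum-intel:

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Warmup runs at the end of __init__, so the first generate call below
# should not pay the one-time IPEX preprocessing cost.
model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=8)
print(tokenizer.decode(output[0]))
```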