Remove position_ids generation in IPEXModel forward (#566)
* fix jit model

* rm autocast in model

* support assisted decoding and add reorder cache function

* add comment for _prepare_past_key_values

* rebase main

* fix model_dtype

* rm useless comments

* fix class name

* revert _call_model

* fix model_dtype warning log

* testing low-precision ipex model

* add assisted decoding

* remove low-precision testing as CI node does not support bf16

* fix conflict

* remove prepare position_ids in forward

* fix code style
jiqing-feng authored Mar 19, 2024
1 parent 45dab01 commit 9813f90
Showing 2 changed files with 18 additions and 12 deletions.
optimum/intel/ipex/modeling_base.py: 6 changes (0 additions, 6 deletions)

@@ -506,12 +506,6 @@ def forward(
             "attention_mask": attention_mask,
         }
 
-        if "position_ids" in self.input_names and position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
         if "position_ids" in self.input_names or not self.input_names:
             inputs["position_ids"] = position_ids
 
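For context, the deleted block derived position_ids from the attention mask whenever the caller did not pass them; after this change the caller is expected to supply position_ids itself, typically via prepare_inputs_for_generation. A minimal standalone sketch of that derivation, mirroring the removed code (the helper name is illustrative, not an optimum-intel API):

import torch

def build_position_ids(attention_mask: torch.Tensor, past_key_values=None) -> torch.Tensor:
    # Illustrative helper that mirrors the logic removed from forward.
    # Cumulative sum over the mask yields 0-based positions for real tokens.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Padding positions get a dummy value of 1 (they are masked out by attention anyway).
    position_ids.masked_fill_(attention_mask == 0, 1)
    if past_key_values:
        # With a KV cache, only the newest token's position is needed.
        position_ids = position_ids[:, -1].unsqueeze(-1)
    return position_ids

mask = torch.tensor([[0, 1, 1, 1]])  # one left-padded token
print(build_position_ids(mask))  # tensor([[1, 0, 1, 2]])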
tests/ipex/test_modeling.py: 24 changes (18 additions, 6 deletions)

@@ -32,7 +32,6 @@
     set_seed,
 )
 
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.intel import (
     IPEXModel,
     IPEXModelForAudioClassification,
@@ -236,11 +235,8 @@ def test_compare_to_transformers(self, model_arch):
             return_tensors="pt",
             return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
         )
-        position_ids = None
-        if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            input_shape = tokens["input_ids"].shape
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
-        outputs = ipex_model(**tokens, position_ids=position_ids)
+        inputs = ipex_model.prepare_inputs_for_generation(**tokens)
+        outputs = ipex_model(**inputs)
 
         self.assertIsInstance(outputs.logits, torch.Tensor)
         self.assertIsInstance(outputs.past_key_values, (tuple, list))
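The updated test no longer hand-builds position_ids for architectures that require them; it asks the model's own prepare_inputs_for_generation for the inputs its traced graph expects. A rough usage sketch of that flow; the checkpoint name below is a placeholder, not one of the test's MODEL_NAMES:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint; any causal LM supported by IPEXModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)

tokens = tokenizer("This is a sample", return_tensors="pt")
# Let the model assemble its own forward inputs, including position_ids if it needs them.
inputs = model.prepare_inputs_for_generation(**tokens)
outputs = model(**inputs)
print(outputs.logits.shape)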
@@ -263,6 +259,22 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_assisted_decoding(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
+        ipex_output = ipex_model.generate(**tokens, do_sample=False)
+        ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model)
+        transformers_output = transformers_model.generate(**tokens, do_sample=False)
+        transformers_output_assisted = transformers_model.generate(
+            **tokens, do_sample=False, assistant_model=ipex_model
+        )
+        self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
+        self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))
+
     @parameterized.expand(
         grid_parameters(
             {
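The new test_assisted_decoding case checks that greedy generation is unchanged when an assistant model is attached, in both directions (IPEX main model with a transformers assistant, and the reverse). A minimal usage sketch along the same lines; the checkpoint is a placeholder and, in real use, the assistant would normally be a smaller draft model than the main one:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder; the test sweeps all SUPPORTED_ARCHITECTURES instead
tokenizer = AutoTokenizer.from_pretrained(model_id)
main_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
assistant = AutoModelForCausalLM.from_pretrained(model_id)

tokens = tokenizer("This is a sample input", return_tensors="pt")
# With greedy decoding, assisted generation should reproduce the unassisted tokens exactly.
plain = main_model.generate(**tokens, do_sample=False, max_new_tokens=20)
assisted = main_model.generate(**tokens, do_sample=False, max_new_tokens=20, assistant_model=assistant)
print(torch.equal(plain, assisted))  # expected: True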
