fix compatibility with legacy models
echarlaix committed Sep 15, 2023
1 parent b643308 commit ca9ce30
Showing 2 changed files with 13 additions and 7 deletions.
14 changes: 9 additions & 5 deletions optimum/onnxruntime/modeling_decoder.py
@@ -685,8 +685,10 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: None = None,
**kwargs,
) -> CausalLMOutputWithPast:
# adding use_cache_branch in the signature here is just a hack for IO Binding
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

@@ -768,6 +770,11 @@ def forward(
if use_cache_branch is not None:
inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch

for output in self.model.get_outputs():
if output.name == "logits" and output.shape[1] == 1:
# TODO : modify the static graph
raise ValueError("The model needs to be re-exported or set use_cache=False.")

outputs = self.model.run(None, inputs)

if self.use_cache:
@@ -938,11 +945,8 @@ def _from_pretrained(
file_name = decoder_path.name

regular_file_names = []
for regular_file_name in [
ONNX_WEIGHTS_NAME,
ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME,
]:
regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(regular_file_name)
for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name)

if file_name not in regular_file_names:
logger.warning(
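
For context, the check added above raises for decoders whose logits output has a static sequence length of 1, which this commit treats as a legacy export. Below is a minimal sketch of the two workarounds named in the error message, written against the public optimum.onnxruntime API; the model id "gpt2" and the paths are placeholders used only for illustration.

from optimum.onnxruntime import ORTModelForCausalLM

# Option 1: re-export the model with the current version of optimum so the
# logits output keeps a dynamic sequence length ("gpt2" is a placeholder model id).
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
model.save_pretrained("gpt2_onnx")  # reuse this fresh export for later loads

# Option 2: keep the legacy ONNX files but disable the cache branch entirely,
# trading generation speed for compatibility (placeholder local path).
legacy_model = ORTModelForCausalLM.from_pretrained("path/to/legacy_onnx_model", use_cache=False)
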
6 changes: 4 additions & 2 deletions tests/onnxruntime/test_optimization.py
@@ -578,12 +578,14 @@ def test_optimization_levels_gpu(
use_io_binding=use_io_binding,
)


def test_merged_optimization(self):
ort_model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge")
self.assertTrue(ort_model.use_cache)

with self.assertRaises(NotImplementedError) as cm:
optimizer = ORTOptimizer.from_pretrained(ort_model)

self.assertTrue("ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged" in str(cm.exception))
self.assertTrue(
"ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged"
in str(cm.exception)
)
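
The reworked assertion only reflows the expected message; the limitation itself is unchanged: ORTOptimizer still refuses ORTModelForCausalLM instances whose without/with-past branches are merged into a single ONNX graph. A rough sketch of one way around it, assuming it is acceptable to export a plain, non-merged decoder without the cache branch ("gpt2" and the save directory are placeholders):

from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Export a decoder without the KV-cache branch, so no merged graph is produced.
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=False)

# Graph optimizations can then be applied as usual.
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="gpt2_onnx_optimized",
    optimization_config=OptimizationConfig(optimization_level=2),
)
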
