Skip to content

Commit

Permalink
make input dynamic and enable sdpa
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 18, 2024
1 parent 0249b17 commit 9d6d4bb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
6 changes: 5 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ def main_export(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
if (
is_transformers_version(">=", "4.36")
and is_transformers_version("<=", "4.45.0")
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
):
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that does not support load model on cpu
Expand Down
14 changes: 14 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,13 @@ def patch_model_for_export(
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
    """Return the parent config's input spec, adjusted for stateful decoders.

    When this config has a truthy ``stateful`` attribute and its behavior is
    ``ConfigBehavior.DECODER``, the ``decoder_input_ids`` entry is overwritten
    so that axis 0 is named ``"batch_size"`` and axis 1 is named
    ``"seq_length"``. In every other case the parent's spec is returned
    unchanged.
    """
    input_axes = super().inputs
    # ``stateful`` may not be set on every config instance, hence the
    # defaulted lookup; the behavior check is only evaluated when it is truthy.
    is_stateful_decoder = getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER
    if is_stateful_decoder:
        input_axes["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
    return input_axes


@register_in_tasks_manager(
"t5",
Expand All @@ -2299,6 +2306,13 @@ def patch_model_for_export(
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
    """Return the parent config's input spec, adjusted for stateful decoders.

    When this config has a truthy ``stateful`` attribute and its behavior is
    ``ConfigBehavior.DECODER``, the ``decoder_input_ids`` entry is overwritten
    so that axis 0 is named ``"batch_size"`` and axis 1 is named
    ``"seq_length"``. In every other case the parent's spec is returned
    unchanged.
    """
    input_axes = super().inputs
    # ``stateful`` may not be set on every config instance, hence the
    # defaulted lookup; the behavior check is only evaluated when it is truthy.
    is_stateful_decoder = getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER
    if is_stateful_decoder:
        input_axes["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
    return input_axes


@register_in_tasks_manager(
"mt5",
Expand Down

0 comments on commit 9d6d4bb

Please sign in to comment.