Skip to content

Commit

Permalink
Add ONNX export support for MGP-STR (#2099)
Browse files Browse the repository at this point in the history
* Enable mpg-str ONNX export

* No longer needed

* Improve model patcher

* Formatting

* `ruff`

* Also support image-to-text task

* Add unit tests

* Add listed support for MGP-STR
  • Loading branch information
xenova authored Dec 12, 2024
1 parent 12b3b35 commit 22d93e7
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Marian
- MarkupLM
- MBart
- MGP-STR
- Mistral
- MobileBert
- MobileVit
Expand Down
16 changes: 16 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MgpstrModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
Expand Down Expand Up @@ -933,6 +934,21 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
return {"x": "pixel_values"}


class MgpstrOnnxConfig(ViTOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"char_logits": {0: "batch_size"},
"bpe_logits": {0: "batch_size"},
"wp_logits": {0: "batch_size"},
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MgpstrModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DEFAULT_ONNX_OPSET = 14 # Some bottleneck transformers models require a specific ONNX opset to be successfully exported. We put a rather high opset here for the export to work for all architectures.
Expand Down
26 changes: 26 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,32 @@ def patched_forward(*args, **kwargs):
self.patched_forward = patched_forward


class MgpstrModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

# logits is a tuple, so we unpack it and return them as separate outputs
char_logits, bpe_logits, wp_logits = self.orig_forward(*args, **kwargs).logits

return {
"char_logits": char_logits,
"bpe_logits": bpe_logits,
"wp_logits": wp_logits,
}

self.patched_forward = patched_forward


class SAMModelPatcher(ModelPatcher):
def __init__(
self,
Expand Down
7 changes: 6 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class TasksManager:
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
"image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
"multiple-choice": "AutoModelForMultipleChoice",
Expand Down Expand Up @@ -824,6 +824,11 @@ class TasksManager:
"question-answering",
onnx="MBartOnnxConfig",
),
"mgp-str": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="MgpstrOnnxConfig",
),
"mistral": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
Expand Down Expand Up @@ -247,6 +248,7 @@
"marian": "Helsinki-NLP/opus-mt-en-de",
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "sshleifer/tiny-mbart",
"mgp-str": "alibaba-damo/mgp-str-base",
"mobilebert": "google/mobilebert-uncased",
# "mobilenet_v1": "google/mobilenet_v1_0.75_192",
# "mobilenet_v2": "google/mobilenet_v2_0.35_96",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "echarlaix/tiny-random-marian",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
Expand Down

0 comments on commit 22d93e7

Please sign in to comment.