From 22d93e74ceffba796d8fb0dd47d99680be4b5608 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 12 Dec 2024 15:17:18 +0200 Subject: [PATCH] Add ONNX export support for MGP-STR (#2099) * Enable mpg-str ONNX export * No longer needed * Improve model patcher * Formatting * `ruff` * Also support image-to-text task * Add unit tests * Add listed support for MGP-STR --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 16 ++++++++++++ optimum/exporters/onnx/model_patcher.py | 26 ++++++++++++++++++++ optimum/exporters/tasks.py | 7 +++++- tests/exporters/exporters_utils.py | 2 ++ tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 6 files changed, 52 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 57005b85678..46ab3cb8a64 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -65,6 +65,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Marian - MarkupLM - MBart +- MGP-STR - Mistral - MobileBert - MobileVit diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b39d19ec782..85e235f9a96 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -82,6 +82,7 @@ from .model_patcher import ( CLIPModelPatcher, FalconModelPatcher, + MgpstrModelPatcher, MistralModelPatcher, MusicgenModelPatcher, SAMModelPatcher, @@ -933,6 +934,21 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]: return {"x": "pixel_values"} +class MgpstrOnnxConfig(ViTOnnxConfig): + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "char_logits": {0: "batch_size"}, + "bpe_logits": {0: "batch_size"}, + "wp_logits": {0: "batch_size"}, + } + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MgpstrModelPatcher(self, model, model_kwargs=model_kwargs) + + class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig DEFAULT_ONNX_OPSET = 14 # Some bottleneck transformers models require a specific ONNX opset to be successfully exported. We put a rather high opset here for the export to work for all architectures. diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 2c0f9aeba67..083bc127999 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -509,6 +509,32 @@ def patched_forward(*args, **kwargs): self.patched_forward = patched_forward +class MgpstrModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + @functools.wraps(self.orig_forward) + def patched_forward(*args, **kwargs): + signature = inspect.signature(self.orig_forward) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + + # logits is a tuple, so we unpack it and return them as separate outputs + char_logits, bpe_logits, wp_logits = self.orig_forward(*args, **kwargs).logits + + return { + "char_logits": char_logits, + "bpe_logits": bpe_logits, + "wp_logits": wp_logits, + } + + self.patched_forward = patched_forward + + class SAMModelPatcher(ModelPatcher): def __init__( self, diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 0a3758e97cf..ba17730f9d9 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -211,7 +211,7 @@ class TasksManager: "image-classification": "AutoModelForImageClassification", "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), "image-to-image": "AutoModelForImageToImage", - "image-to-text": "AutoModelForVision2Seq", + "image-to-text": ("AutoModelForVision2Seq", "AutoModel"), "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", "multiple-choice": "AutoModelForMultipleChoice", @@ -824,6 +824,11 @@ class TasksManager: "question-answering", onnx="MBartOnnxConfig", ), + "mgp-str": supported_tasks_mapping( + "feature-extraction", + "image-to-text", + onnx="MgpstrOnnxConfig", + ), "mistral": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 32156d9eebf..5f071e0f9eb 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -116,6 +116,7 @@ "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken "markuplm": "hf-internal-testing/tiny-random-MarkupLMModel", "mbart": "hf-internal-testing/tiny-random-mbart", + "mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition", "mistral": "echarlaix/tiny-random-mistral", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", @@ -247,6 +248,7 @@ "marian": "Helsinki-NLP/opus-mt-en-de", "markuplm": "hf-internal-testing/tiny-random-MarkupLMModel", "mbart": "sshleifer/tiny-mbart", + "mgp-str": "alibaba-damo/mgp-str-base", "mobilebert": "google/mobilebert-uncased", # "mobilenet_v1": "google/mobilenet_v1_0.75_192", # "mobilenet_v2": "google/mobilenet_v2_0.35_96", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index cccecd53817..c33c07fc7b1 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -118,6 +118,7 @@ "m2m_100": "hf-internal-testing/tiny-random-m2m_100", "marian": "echarlaix/tiny-random-marian", "mbart": "hf-internal-testing/tiny-random-mbart", + "mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition", "mistral": "echarlaix/tiny-random-mistral", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet_v1": "google/mobilenet_v1_0.75_192",