diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index cc752779d30..a3c44dbd065 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -118,6 +118,29 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         }
 
 
+class VisualBertOnnxConfig(TextAndVisionOnnxConfig):
+    DEFAULT_ONNX_OPSET = 11
+
+    # NOTE(review): VisualBERT consumes pre-extracted region features
+    # (visual_embeds), not raw images, so the visual_* inputs may need a
+    # custom dummy input generator — confirm before merging.
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        # VisualBertModel.forward takes visual_embeds/visual_attention_mask/
+        # visual_token_type_ids (it has no pixel_values parameter). Outputs
+        # are inherited from the base config so each registered task
+        # (feature-extraction, question-answering, multiple-choice) gets
+        # the task-appropriate output set.
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "token_type_ids": {0: "batch_size", 1: "sequence_length"},
+            "visual_embeds": {0: "batch_size", 1: "visual_seq_length"},
+            "visual_attention_mask": {0: "batch_size", 1: "visual_seq_length"},
+            "visual_token_type_ids": {0: "batch_size", 1: "visual_seq_length"},
+        }
+
+
 class AlbertOnnxConfig(BertOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index fdc8bfcb539..e1162593082 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1108,6 +1108,12 @@ class TasksManager:
             "text-to-audio",
             onnx="VitsOnnxConfig",
         ),
+        "visual-bert": supported_tasks_mapping(
+            "feature-extraction",
+            "multiple-choice",
+            "question-answering",
+            onnx="VisualBertOnnxConfig",
+        ),
         "wavlm": supported_tasks_mapping(
             "feature-extraction",
             "automatic-speech-recognition",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index ccccb5510bf..22996dbe81e 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -197,6 +197,7 @@
             "document-question-answering-with-past",
         ],
     },
+    "visual-bert": "hf-internal-testing/tiny-random-VisualBertModel",
 }
 
 
@@ -286,6 +287,7 @@
     "speech-to-text": "codenamewei/speech-to-text",
+    "visual-bert": "uclanlp/visualbert-vqa-coco-pre",
     "xlm": "xlm-clm-ende-1024",
     "xlm-roberta": "Unbabel/xlm-roberta-comet-small",
 }