diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index b5129c23f2..a10be0083c 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -47,6 +47,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - ESM
 - Falcon
 - Flaubert
+- GIT
 - GPT-2
 - GPT-BigCode
 - GPT-J
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index b5adb4522a..210bdf73d8 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -150,6 +150,7 @@ class OnnxConfig(ExportConfig, ABC):
         "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "image-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
+        "image-text-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "image-to-image": OrderedDict(
             {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 3a48a579c2..d1a169e356 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -2621,3 +2621,35 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
 
     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+
+
+class GITOnnxConfig(TextAndVisionOnnxConfig):
+    """ONNX export config for GIT (GenerativeImage2Text): task-dependent text and/or vision dummy inputs."""
+    NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args(vision_config="vision_config")
+    DUMMY_INPUT_GENERATOR_CLASSES_MAP = {
+        "feature-extraction": (DummyVisionInputGenerator,),
+        "image-text-to-text": (DummyTextInputGenerator, DummyVisionInputGenerator),
+        "image-to-text": (DummyVisionInputGenerator,),
+    }
+
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        dummy_inputs_generators = []
+        for dummy_input_generator in self.DUMMY_INPUT_GENERATOR_CLASSES_MAP[self.task]:
+            dummy_input_generator_instantiated = dummy_input_generator(
+                self.task, self._normalized_config, **kwargs
+            )
+            dummy_inputs_generators.append(dummy_input_generator_instantiated)
+
+        return dummy_inputs_generators
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self.task == "image-text-to-text":
+            return {
+                "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+                "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
+            }
+        else:
+            return {
+                "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
+            }
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 7cb5a31d2d..d748de7a08 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -215,6 +215,7 @@ class TasksManager:
             "AutoModelForInstanceSegmentation",
             "AutoModelForUniversalSegmentation",
         ),
+        "image-text-to-text": ("AutoModelForCausalLM", "AutoModel"),
         "image-to-image": "AutoModelForImageToImage",
         "image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
         "mask-generation": "AutoModel",
@@ -692,6 +693,12 @@ class TasksManager:
         "gemma": supported_tasks_mapping(
             "text-classification",
             onnx="GemmaOnnxConfig",
         ),
+        "git": supported_tasks_mapping(
+            "feature-extraction",
+            "image-text-to-text",
+            "image-to-text",
+            onnx="GITOnnxConfig",
+        ),
         "glpn": supported_tasks_mapping(
             "feature-extraction",
             "depth-estimation",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 900b5f3b5c..a46013d8d3 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -96,6 +96,14 @@
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "git": {
+        "hf-internal-testing/tiny-random-GitModel": [
+            "feature-extraction",
+        ],
+        "hf-internal-testing/tiny-random-GitForCausalLM": [
+            "image-text-to-text",
+        ],
+    },
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index 02ced3be3a..5215552e6b 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -101,6 +101,14 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "git": {
+        "hf-internal-testing/tiny-random-GitModel": [
+            "feature-extraction",
+        ],
+        "hf-internal-testing/tiny-random-GitForCausalLM": [
+            "image-text-to-text",
+        ],
+    },
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",