From ea2321cd68cd63668454a2958e5fc923bb5e8950 Mon Sep 17 00:00:00 2001 From: marcindulak Date: Fri, 20 Dec 2024 20:37:01 +0100 Subject: [PATCH] Use single GITOnnxConfig class --- optimum/exporters/onnx/base.py | 1 + optimum/exporters/onnx/model_configs.py | 40 ++++++++++++++++--------- optimum/exporters/tasks.py | 6 +--- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index b5adb4522a..210bdf73d8 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -150,6 +150,7 @@ class OnnxConfig(ExportConfig, ABC): "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-classification": OrderedDict({"logits": {0: "batch_size"}}), "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), + "image-text-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-to-image": OrderedDict( {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 73739a041f..d1a169e356 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2623,21 +2623,33 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. -class GITOnnxConfig(VisionOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig - DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator) - - @property - def inputs(self) -> Dict[str, Dict[int, str]]: - return { - "input_ids": {0: "text_batch_size", 1: "sequence_length"}, - "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"} - } - +class GITOnnxConfig(TextAndVisionOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args(vision_config="vision_config") + DUMMY_INPUT_GENERATOR_CLASSES_MAP = { + "feature-extraction": (DummyVisionInputGenerator,), + "image-text-to-text": (DummyTextInputGenerator, DummyVisionInputGenerator,), + "image-to-text": (DummyVisionInputGenerator,), + } -class GITVisionModelOnnxConfig(VisionOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_inputs_generators = [] + for dummy_input_generator in self.DUMMY_INPUT_GENERATOR_CLASSES_MAP[self.task]: + print(self.task, dummy_input_generator) + dummy_input_generator_instantiated = dummy_input_generator( + self.task, self._normalized_config, **kwargs + ) + dummy_inputs_generators.append(dummy_input_generator_instantiated) + return dummy_inputs_generators + @property def inputs(self) -> Dict[str, Dict[int, str]]: - return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} + if self.task == "image-text-to-text": + return { + "input_ids": {0: "text_batch_size", 1: "sequence_length"}, + "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}, + } + else: + return { + "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}, + } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 1fbbbc5ff1..d748de7a08 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -215,6 +215,7 @@ class TasksManager: "AutoModelForInstanceSegmentation", "AutoModelForUniversalSegmentation", ), + "image-text-to-text": ("AutoModelForCausalLM", "AutoModel"), "image-to-image": "AutoModelForImageToImage", "image-to-text": ("AutoModelForVision2Seq", "AutoModel"), "mask-generation": "AutoModel", @@ -698,11 +699,6 @@ class TasksManager: "image-to-text", onnx="GITOnnxConfig", ), - "git-vision-model": supported_tasks_mapping( - "feature-extraction", - "image-to-text", - onnx="GITVisionModelOnnxConfig", - ), "glpn": supported_tasks_mapping( "feature-extraction", "depth-estimation",