Skip to content

Commit

Permalink
Use single GITOnnxConfig class
Browse files Browse the repository at this point in the history
  • Loading branch information
marcindulak committed Dec 20, 2024
1 parent e369e24 commit ea2321c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class OnnxConfig(ExportConfig, ABC):
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-text-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-image": OrderedDict(
{"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
Expand Down
40 changes: 26 additions & 14 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,21 +2623,33 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.


class GITOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"input_ids": {0: "text_batch_size", 1: "sequence_length"},
"pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}
}

class GITOnnxConfig(TextAndVisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args(vision_config="vision_config")
DUMMY_INPUT_GENERATOR_CLASSES_MAP = {
"feature-extraction": (DummyVisionInputGenerator,),
"image-text-to-text": (DummyTextInputGenerator, DummyVisionInputGenerator,),
"image-to-text": (DummyVisionInputGenerator,),
}

class GITVisionModelOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
dummy_inputs_generators = []
for dummy_input_generator in self.DUMMY_INPUT_GENERATOR_CLASSES_MAP[self.task]:
print(self.task, dummy_input_generator)
dummy_input_generator_instantiated = dummy_input_generator(
self.task, self._normalized_config, **kwargs
)
dummy_inputs_generators.append(dummy_input_generator_instantiated)

return dummy_inputs_generators

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
if self.task == "image-text-to-text":
return {
"input_ids": {0: "text_batch_size", 1: "sequence_length"},
"pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
}
else:
return {
"pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
}
6 changes: 1 addition & 5 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class TasksManager:
"AutoModelForInstanceSegmentation",
"AutoModelForUniversalSegmentation",
),
"image-text-to-text": ("AutoModelForCausalLM", "AutoModel"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
"mask-generation": "AutoModel",
Expand Down Expand Up @@ -698,11 +699,6 @@ class TasksManager:
"image-to-text",
onnx="GITOnnxConfig",
),
"git-vision-model": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="GITVisionModelOnnxConfig",
),
"glpn": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
Expand Down

0 comments on commit ea2321c

Please sign in to comment.