diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 8ca6e85be8c..9e808e392b9 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -30,7 +30,6 @@ DummySeq2SeqPastKeyValuesGenerator, DummyTextInputGenerator, DummyVisionInputGenerator, - check_if_transformers_greater, is_diffusers_available, logging, ) @@ -291,11 +290,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} self.add_past_key_values(common_inputs, direction="inputs") - - if check_if_transformers_greater("4.43.0"): - # shape is [1] when using cache and [sequence_length] when not using it - common_inputs["cache_position"] = {0: "1"} - else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c39c91b4a6a..4edbeaa4600 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -53,6 +53,7 @@ NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedVisionConfig, + check_if_transformers_greater, is_diffusers_available, logging, ) @@ -1473,7 +1474,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self._behavior is not ConfigBehavior.DECODER: common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis. - if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False: + if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs: + if check_if_transformers_greater("4.43.0"): + # since https://github.com/huggingface/transformers/pull/31166 + common_inputs["cache_position"] = {0: "decoder_sequence_length"} + + if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs: common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" return common_inputs