Skip to content

Commit

Permalink
cache position in whisper only with dynamic axis decoder_sequence_length
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 30, 2024
1 parent 825cc6d commit 3fe0cac
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 0 additions & 6 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
check_if_transformers_greater,
is_diffusers_available,
logging,
)
Expand Down Expand Up @@ -291,11 +290,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
self.add_past_key_values(common_inputs, direction="inputs")

if check_if_transformers_greater("4.43.0"):
# shape is [1] when using cache and [sequence_length] when not using it
common_inputs["cache_position"] = {0: "1"}

else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

Expand Down
8 changes: 7 additions & 1 deletion optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
check_if_transformers_greater,
is_diffusers_available,
logging,
)
Expand Down Expand Up @@ -1473,7 +1474,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if check_if_transformers_greater("4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
return common_inputs

Expand Down

0 comments on commit 3fe0cac

Please sign in to comment.