-
Notifications
You must be signed in to change notification settings - Fork 487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SpeechT5 ONNX support #1404
Merged
Merged
SpeechT5 ONNX support #1404
Changes from 13 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
2dd5209
wip
fxmarty be26f71
wip bis
fxmarty 02259a8
nit
fxmarty d181ad2
nit^2
fxmarty 54d3bc7
working export
fxmarty b107b2d
working with-past version
fxmarty f8f69ab
add test
fxmarty 69313a1
add doc
fxmarty b88ed06
working merged onnx
fxmarty 918893e
fix dropout with training=True export
fxmarty 74ba08c
test fix
fxmarty c5a8a1d
fix custom models
fxmarty 2f9661d
some cleaning
fxmarty 5a2ccde
Merge branch 'master' into speecht5-onnx
fxmarty 595ff14
merge mess
fxmarty 563424c
address review comments
fxmarty 2c7a73f
Merge branch 'master' into speecht5-onnx
fxmarty bce548a
fix tests
fxmarty File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,6 +140,7 @@ class OnnxConfig(ExportConfig, ABC): | |
MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION | ||
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None | ||
VARIANTS = {"default": "The default ONNX variant."} | ||
DEFAULT_VARIANT = "default" | ||
_TASK_TO_COMMON_OUTPUTS = { | ||
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}), | ||
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), | ||
|
@@ -200,17 +201,14 @@ def __init__( | |
int_dtype: str = "int64", | ||
float_dtype: str = "fp32", | ||
): | ||
if task not in self._TASK_TO_COMMON_OUTPUTS: | ||
raise ValueError( | ||
f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}" | ||
) | ||
self.task = task | ||
self.int_dtype = int_dtype | ||
self.float_dtype = float_dtype | ||
|
||
self._config = config | ||
self._preprocessors = preprocessors | ||
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) | ||
self.variant = "default" | ||
|
||
def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: | ||
""" | ||
|
@@ -808,7 +806,8 @@ def with_behavior( | |
""" | ||
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): | ||
behavior = ConfigBehavior(behavior) | ||
return self.__class__( | ||
|
||
onnx_config = self.__class__( | ||
self._config, | ||
task=self.task, | ||
int_dtype=self.int_dtype, | ||
|
@@ -818,6 +817,8 @@ def with_behavior( | |
behavior=behavior, | ||
preprocessors=self._preprocessors, | ||
) | ||
onnx_config.variant = self.variant | ||
return onnx_config | ||
|
||
@property | ||
def outputs(self) -> Dict[str, Dict[int, str]]: | ||
|
@@ -902,8 +903,8 @@ def post_process_exported_models( | |
path, models_and_onnx_configs, onnx_files_subpaths | ||
) | ||
|
||
# Attempt to merge only if the decoder was exported without/with past | ||
if self.use_past is True and len(models_and_onnx_configs) == 3: | ||
# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task | ||
if len(onnx_files_subpaths) >= 3 and self.use_past is True or self.variant == "with-past": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need to check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure. I'll need to double check. |
||
decoder_path = Path(path, onnx_files_subpaths[1]) | ||
decoder_with_past_path = Path(path, onnx_files_subpaths[2]) | ||
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") | ||
|
@@ -922,7 +923,8 @@ def post_process_exported_models( | |
# In order to do the validation of the two branches on the same file | ||
encoder_path = onnx_files_subpaths[0] | ||
|
||
onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name] | ||
onnx_files_subpaths_new = [encoder_path, decoder_merged_path.name, decoder_merged_path.name] | ||
onnx_files_subpaths_new.extend(onnx_files_subpaths[3:]) | ||
|
||
# We validate the two branches of the decoder model then | ||
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True | ||
|
@@ -933,8 +935,10 @@ def post_process_exported_models( | |
|
||
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True | ||
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True | ||
else: | ||
onnx_files_subpaths_new = onnx_files_subpaths | ||
|
||
return models_and_onnx_configs, onnx_files_subpaths | ||
return models_and_onnx_configs, onnx_files_subpaths_new | ||
|
||
def generate_dummy_inputs_for_validation( | ||
self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None | ||
|
@@ -1006,6 +1010,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st | |
self.float_dtype = float_dtype | ||
self._normalized_config = self._onnx_config._normalized_config | ||
self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS | ||
self.variant = "default" | ||
|
||
@classmethod | ||
def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss": | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fxmarty this line currently does nothing since it is set to False again in line 381. Do you want to have a look?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I'll fix