diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 474628d546b..0ea17b6afec 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -83,6 +83,7 @@ Supported architectures: - SEW - SEW-D - Speech2Text +- SpeechT5 - Splinter - SqueezeBert - Stable Diffusion diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index 85661ccf6cf..55f8b9dc1d3 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -14,6 +14,7 @@ """Defines the command line for the export with ONNX.""" import argparse +import json from pathlib import Path from typing import TYPE_CHECKING @@ -136,6 +137,20 @@ def parse_args_onnx(parser): default=None, help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"), ) + optional_group.add_argument( + "--model-kwargs", + type=json.loads, + help=("Any kwargs passed to the model forward, or used to customize the export for a given model."), + ) + optional_group.add_argument( + "--legacy", + action="store_true", + help=( + "Export decoder only models in three files (without + with past and the resulting merged model)." + "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." + ), + ) + input_group = parser.add_argument_group( "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)." ) @@ -209,14 +224,6 @@ def parse_args_onnx(parser): default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"], help="For Segment Anything. It corresponds to the number of points per segmentation masks.", ) - optional_group.add_argument( - "--legacy", - action="store_true", - help=( - "Export decoder only models in three files (without + with past and the resulting merged model)." - "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." 
- ), - ) # deprecated argument parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS) @@ -256,5 +263,6 @@ def run(self): _variant=self.args.variant, library_name=self.args.library_name, legacy=self.args.legacy, + model_kwargs=self.args.model_kwargs, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 1b601cdfb8d..851be1b8f6f 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -19,7 +19,7 @@ from pathlib import Path from requests.exceptions import ConnectionError as RequestsConnectionError -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from transformers.utils import is_torch_available from ...commands.export.onnx import parse_args_onnx @@ -38,6 +38,7 @@ get_decoder_models_for_export, get_encoder_decoder_models_for_export, get_sam_models_for_export, + get_speecht5_models_for_export, get_stable_diffusion_models_for_export, ) @@ -69,6 +70,7 @@ def _get_submodels_and_onnx_configs( fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, legacy: bool = False, + model_kwargs: Optional[Dict] = None, ): is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: @@ -95,10 +97,11 @@ def _get_submodels_and_onnx_configs( onnx_config.variant = _variant all_variants = "\n".join( - [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()] + [f" - {name}: {description}" for name, description in onnx_config.VARIANTS.items()] ) logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}") + # TODO: this succession of if/else strongly suggests a refactor is needed. if ( model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS) @@ -109,6 +112,8 @@ def _get_submodels_and_onnx_configs( models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy) elif model.config.model_type == "sam": models_and_onnx_configs = get_sam_models_for_export(model, onnx_config) + elif model.config.model_type == "speecht5": + models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs) else: models_and_onnx_configs = {"model": (model, onnx_config)} @@ -333,6 +338,30 @@ def main_export( f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) + if library_name == "transformers": + config = AutoConfig.from_pretrained( + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) + model_type = config.model_type.replace("_", "-") + if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + custom_architecture = True + elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"): + if original_task == "auto": + autodetected_message = " (auto-detected)" + else: + autodetected_message = "" + model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx") + raise ValueError( + f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. 
Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." + ) + model = TasksManager.get_model_from_task( task, model_name_or_path, @@ -361,18 +390,16 @@ def main_export( if not is_stable_diffusion: if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: raise ValueError( - f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"{model_type} is not supported yet. Only {list(TasksManager._SUPPORTED_CLI_MODEL_TYPE.keys())} are supported. " f"If you want to support {model_type} please propose a PR or open up an issue." ) - if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( - task, exporter="onnx" - ): + if model.config.model_type.replace("_", "-") not in TasksManager._SUPPORTED_MODEL_TYPE: custom_architecture = True # TODO: support onnx_config.py in the model repo if custom_architecture and custom_onnx_configs is None: raise ValueError( - f"Trying to export a {model.config.model_type.replace('-', '_')} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. For the task {task}, the Optimum ONNX exporter supports natively the architectures: {TasksManager.get_supported_model_type_for_task(task, exporter='onnx')}." + f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export." 
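As a brief aside on how the new `--model-kwargs` option and the `model_kwargs` argument of `main_export` fit together for SpeechT5 (which, as shown further down, requires a vocoder checkpoint): the sketch below is illustrative only; the checkpoint names and output directory are placeholders, not values mandated by this PR.

```python
# Hedged usage sketch: exporting a SpeechT5 TTS checkpoint, passing the vocoder
# through model_kwargs. Checkpoint names and the output path are placeholders.
from optimum.exporters.onnx import main_export

# Assumed CLI equivalent of the call below:
# optimum-cli export onnx --model microsoft/speecht5_tts --task text-to-audio \
#     --model-kwargs '{"vocoder": "microsoft/speecht5_hifigan"}' speecht5_onnx/
main_export(
    "microsoft/speecht5_tts",
    output="speecht5_onnx",
    task="text-to-audio",
    model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},
)
```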
) if custom_architecture and original_task == "auto": @@ -425,6 +452,7 @@ def main_export( preprocessors=preprocessors, _variant=_variant, legacy=legacy, + model_kwargs=model_kwargs, ) if not is_stable_diffusion: diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index a65374346ac..1e5704e8937 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -140,6 +140,7 @@ class OnnxConfig(ExportConfig, ABC): MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION PATCHING_SPECS: Optional[List["PatchingSpec"]] = None VARIANTS = {"default": "The default ONNX variant."} + DEFAULT_VARIANT = "default" _TASK_TO_COMMON_OUTPUTS = { "audio-classification": OrderedDict({"logits": {0: "batch_size"}}), "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), @@ -200,10 +201,6 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", ): - if task not in self._TASK_TO_COMMON_OUTPUTS: - raise ValueError( - f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}" - ) self.task = task self.int_dtype = int_dtype self.float_dtype = float_dtype @@ -211,6 +208,7 @@ def __init__( self._config = config self._preprocessors = preprocessors self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self.variant = "default" def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: """ @@ -808,7 +806,8 @@ def with_behavior( """ if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): behavior = ConfigBehavior(behavior) - return self.__class__( + + onnx_config = self.__class__( self._config, task=self.task, int_dtype=self.int_dtype, @@ -818,6 +817,8 @@ def with_behavior( behavior=behavior, preprocessors=self._preprocessors, ) + onnx_config.variant = self.variant + return onnx_config @property def outputs(self) -> Dict[str, Dict[int, str]]: @@ -902,8 +903,8 @@ def post_process_exported_models( path, models_and_onnx_configs, onnx_files_subpaths ) - # Attempt to merge only if the decoder was exported without/with past - if self.use_past is True and len(models_and_onnx_configs) == 3: + # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task + if len(onnx_files_subpaths) >= 3 and self.use_past is True: decoder_path = Path(path, onnx_files_subpaths[1]) decoder_with_past_path = Path(path, onnx_files_subpaths[2]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") @@ -922,7 +923,8 @@ def post_process_exported_models( # In order to do the validation of the two branches on the same file encoder_path = onnx_files_subpaths[0] - onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name] + onnx_files_subpaths_new = [encoder_path, decoder_merged_path.name, decoder_merged_path.name] + onnx_files_subpaths_new.extend(onnx_files_subpaths[3:]) # We validate the two branches of the decoder model then models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True @@ -933,8 +935,10 @@ def post_process_exported_models( models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True + else: + onnx_files_subpaths_new = onnx_files_subpaths - return models_and_onnx_configs, onnx_files_subpaths + return models_and_onnx_configs, onnx_files_subpaths_new def generate_dummy_inputs_for_validation( self, reference_model_inputs: Dict[str, Any], 
onnx_input_names: Optional[List[str]] = None @@ -1006,6 +1010,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st self.float_dtype = float_dtype self._normalized_config = self._onnx_config._normalized_config self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS + self.variant = "default" @classmethod def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss": diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index f637da07804..0b00667e6c8 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -38,6 +38,7 @@ ) from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError from .base import OnnxConfig +from .model_configs import SpeechT5OnnxConfig from .utils import PickableInferenceSession, recursive_to_device @@ -142,7 +143,6 @@ def validate_models_outputs( if use_subprocess: logger.info("Validating models in subprocesses...") exceptions = [] # run all validations before raising - onnx_paths = [] for i, model_name in enumerate(models_and_onnx_configs.keys()): submodel, sub_onnx_config = models_and_onnx_configs[model_name] onnx_model_path = ( @@ -150,7 +150,6 @@ def validate_models_outputs( if onnx_files_subpaths is not None else output_dir.joinpath(model_name + ".onnx") ) - onnx_paths.append(onnx_model_path) try: # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of # not releasing memory once an InferenceSession is initialized. @@ -168,12 +167,12 @@ def validate_models_outputs( model_kwargs=model_kwargs, ) except Exception as e: - exceptions.append(e) + exceptions.append((onnx_model_path, e)) if len(exceptions) != 0: for i, exception in enumerate(exceptions[:-1]): - logger.error(f"Validation {i} for the model {onnx_paths[i].as_posix()} raised: {exception}") - raise exceptions[-1] + logger.error(f"Validation for the model {exception[0].as_posix()} raised: {exception[1]}") + raise exceptions[-1][1] def validate_model_outputs( @@ -423,9 +422,11 @@ def _run_validation( if value_failures: msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures) - raise AtolError( - f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}" - ) + atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}" + + if isinstance(config, SpeechT5OnnxConfig): + atol_msg += "\nIMPORTANT NOTE: SpeechT5 uses a dropout at inference and the output validation of ONNX Runtime inference vs PyTorch is expected to fail. 
Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727" + raise AtolError(atol_msg) class ValidationProcess(mp.Process): @@ -526,7 +527,7 @@ def export_pytorch( with torch.no_grad(): model.config.return_dict = True - model.eval() + model = model.eval() # Check if we need to override certain configuration item if config.values_override is not None: diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 4c96bbdbe9a..39276017103 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -24,11 +24,13 @@ BloomDummyPastKeyValuesGenerator, DummyAudioInputGenerator, DummyDecoderTextInputGenerator, + DummyInputGenerator, DummyPastKeyValuesGenerator, DummyPix2StructInputGenerator, DummyPointsGenerator, DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, + DummySpeechT5InputGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, DummyVisionEmbeddingsGenerator, @@ -65,6 +67,7 @@ MistralModelPatcher, OPTModelPatcher, SAMModelPatcher, + SpeechT5ModelPatcher, WavLMModelPatcher, ) @@ -73,7 +76,6 @@ from transformers import PretrainedConfig from transformers.modeling_utils import PreTrainedModel - from ...utils import DummyInputGenerator from .model_patcher import ModelPatcher if is_tf_available(): @@ -1298,6 +1300,151 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs +class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast): + # TODO: Transformers batched generation for Speecht5 is BROKEN (https://github.com/huggingface/transformers/pull/25943), + # so we won't support for now. + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(decoder_num_layers="decoder_layers") + NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( + hidden_size="hidden_size", + num_attention_heads="encoder_attention_heads", # TODO: bugged in case encoder and decoder have different number of heads + encoder_num_layers="encoder_layers", + decoder_num_layers="decoder_layers", + allow_new=True, + ) + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTextInputGenerator, + DummySeq2SeqDecoderTextInputGenerator, + DummySeq2SeqPastKeyValuesGenerator, + DummySpeechT5InputGenerator, + ) + DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator + + VARIANTS = { + "with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.", + "without-past": "The same as `with-past`, just without KV cache support. 
This export is not recommended, as it is slower than `with-past`.", + } + DEFAULT_VARIANT = "with-past" + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, + is_postnet_and_vocoder: bool = False, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=preprocessors, + ) + if float_dtype == "fp16": + raise ValueError( + "The ONNX export of SpeechT5 in float16 is currently not supported due to a bug in PyTorch: https://github.com/pytorch/pytorch/pull/110078. Please open an issue in Optimum if you would like to export SpeechT5 in float16." + ) + self.is_postnet_and_vocoder = is_postnet_and_vocoder + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = {} + + # Batched inference is not supported in Transformers. + if self._behavior is ConfigBehavior.ENCODER: + common_inputs["input_ids"] = {1: "encoder_sequence_length"} + elif self._behavior is ConfigBehavior.DECODER: + # NOTE: even when past is used, the decoder takes the full sequence as input as the prenet seems to require it: + # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573 + common_inputs["output_sequence"] = {1: "decoder_sequence_length"} + common_inputs["speaker_embeddings"] = {} # No dynamic shape here. + common_inputs["encoder_outputs"] = {1: "encoder_sequence_length"} + common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"} + + if self.variant == "with-past" and self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + elif self.is_postnet_and_vocoder: + common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"} + else: + raise ValueError( + "self._behavior is neither encoder nor decoder, and is_postnet_and_vocoder=False. This should not happen." + ) + + return common_inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = {} + if self._behavior is ConfigBehavior.ENCODER: + common_outputs["encoder_outputs"] = {1: "encoder_sequence_length"} + common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"} + elif self._behavior is ConfigBehavior.DECODER: + common_outputs["output_sequence_out"] = {1: "decoder_sequence_length + 1"} + common_outputs["spectrum"] = {} # No dynamic shape here. + common_outputs["prob"] = {} # No dynamic shape here. + + if self.variant == "with-past" and self.use_past: + # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. + self.add_past_key_values(common_outputs, direction="outputs") + elif self.is_postnet_and_vocoder: + common_outputs["waveform"] = {0: "n_samples"} + else: + raise ValueError( + "self._behavior is neither encoder nor decoder, and is_postnet_and_vocoder=False. This should not happen."
+ ) + + return common_outputs + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs) + + @property + def torch_to_onnx_input_map(self) -> Dict[str, str]: + return {"encoder_outputs": "encoder_hidden_states"} + + def overwrite_shape_and_generate_input( + self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict + ): + dummy_input_gen.batch_size = 1 + dummy_input = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + return dummy_input + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_decoder_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_decoder_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.decoder_num_layers): + inputs_or_outputs[f"{name}.{i}.decoder.key"] = {2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.decoder.value"] = {2: decoder_sequence_name} + + if ( + self.is_merged is True + or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) + or direction == "inputs" + ): + inputs_or_outputs[f"{name}.{i}.encoder.key"] = {2: "encoder_sequence_length_out"} + inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"} + + class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): shape = [self.batch_size, self.sequence_length, self.normalized_config.input_features_per_channel] diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index b6f5a4dcd82..4dd584e6f3c 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -21,6 +21,7 @@ import transformers from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor +from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from transformers.utils import is_torch_available from ...utils.modeling_utils import ( @@ -586,6 +587,170 @@ def patched_forward( self.patched_forward = patched_forward + +def patched_speecht5_prenet_forward( + self, + input_values: torch.Tensor, + speaker_embeddings: Optional[torch.Tensor] = None, +): + # Dropout is always applied, even when evaluating. See §2.2 in https://arxiv.org/abs/1712.05884. + + inputs_embeds = input_values + for layer in self.layers: + inputs_embeds = torch.nn.functional.relu(layer(inputs_embeds)) + + # NOTE: we patch the prenet to avoid using torch.nn.functional.dropout, which would be exported as a `Dropout` node in the ONNX graph + # and ignored at inference time by some runtimes such as ONNX Runtime.
+ # Reference: https://github.com/microsoft/onnxruntime/issues/9333 & https://github.com/microsoft/onnxruntime/issues/5549 + mask = torch.rand(inputs_embeds.shape, device=inputs_embeds.device) > self.config.speech_decoder_prenet_dropout + inputs_embeds = inputs_embeds * mask / (1 - self.config.speech_decoder_prenet_dropout) + + # inputs_embeds = nn.functional.dropout( + # inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True + # ) + + inputs_embeds = self.final_layer(inputs_embeds) + inputs_embeds = self.encode_positions(inputs_embeds) + + if speaker_embeddings is not None: + speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings) + speaker_embeddings = speaker_embeddings.unsqueeze(1) + speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1) + inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1) + inputs_embeds = torch.nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) + + return inputs_embeds + + +class SpeechT5ModelPatcher(ModelPatcher): + def __enter__(self): + self.patch_ops() + self._model.speecht5.decoder.prenet.forward = types.MethodType( + patched_speecht5_prenet_forward, self._model.speecht5.decoder.prenet + ) + setattr(self._model, self.orig_forward_name, self.patched_forward) + + def __exit__(self, exc_type, exc_value, traceback): + self.restore_ops() + setattr(self._model, self.orig_forward_name, self.orig_forward) + self._model.speecht5.decoder.prenet.forward = types.MethodType( + self.original_speecht5_prenet_forward, self._model.speecht5.decoder.prenet + ) + + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, model_kwargs) + + self.original_speecht5_prenet_forward = model.speecht5.decoder.prenet.forward + + model.vocoder = model_kwargs["vocoder_model"].eval() + + def patched_forward( + input_ids=None, + speaker_embeddings=None, + encoder_outputs=None, + past_key_values=None, + output_sequence=None, + spectrogram=None, + encoder_attention_mask=None, + ): + use_cache = self.real_config.use_past and self.real_config.variant == "with-past" + if self.real_config._behavior == "encoder": + encoder_attention_mask = torch.ones_like(input_ids) + + encoder_out = model.speecht5.encoder( + input_values=input_ids, + attention_mask=encoder_attention_mask, + return_dict=True, + ) + # downsample encoder attention mask + if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( + encoder_out[0].shape[1], encoder_attention_mask + ) + + result = { + "encoder_outputs": encoder_out.last_hidden_state, + "encoder_attention_mask": encoder_attention_mask, + } + + elif self.real_config._behavior == "decoder": + # TODO: and self.real_config.use_past_in_inputs + encoder_hidden_states = encoder_outputs[0] + + decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) + + # Run the decoder layers on the last element of the prenet output. 
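For readers of the prenet patch above: a standalone sketch (not from this PR; `p` stands in for `config.speech_decoder_prenet_dropout`) showing that the explicit mask-and-rescale used in `patched_speecht5_prenet_forward` matches `torch.nn.functional.dropout(x, p, training=True)` in expectation, while staying as plain element-wise ops in the exported graph rather than a `Dropout` node that ONNX Runtime may skip.

```python
# Standalone sketch: inverted dropout expressed as an explicit mask and rescale.
import torch

p = 0.5  # stand-in for config.speech_decoder_prenet_dropout
x = torch.ones(2, 4, 8)
mask = torch.rand(x.shape) > p  # keep each activation with probability 1 - p
y = x * mask / (1 - p)          # rescale survivors so that E[y] == x
print(y.mean())                 # ~1.0 on average, like dropout in training mode
```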
+ decoder_out = model.speecht5.decoder.wrapped_decoder( + hidden_states=decoder_hidden_states[:, -1:], + attention_mask=None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=False, + return_dict=True, + ) + + last_decoder_output = decoder_out.last_hidden_state[0, -1] + past_key_values = decoder_out.past_key_values + + # Predict the new mel spectrum for this step in the sequence. + spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) + spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins) + + # NOTE: extending the spectrogram is expected to be handled outside of the ONNX export. + # spectrogram.append(spectrum) + + # Extend the output sequence with the new mel spectrum. + output_sequence = torch.cat( + (output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1 + ) + + # Predict the probability that this is the stop token. + prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) + + result = { + "output_sequence_out": output_sequence, + "spectrum": spectrum, + "prob": prob, + "past_key_values": past_key_values, + } + elif self.real_config.is_postnet_and_vocoder: + # NOTE: the following concatenation is expected to be handled outside of the ONNX export: + # spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0) + spectrogram = spectrogram.unsqueeze(0) + spectrogram = model.speech_decoder_postnet.postnet(spectrogram) + spectrogram = spectrogram.squeeze(0) + + waveform = model.vocoder(spectrogram) + + result = {"waveform": waveform} + else: + raise ValueError("Should not happen") + + # Filter out the cross-attention past key values output from the decoder using the KV cache, as they are constants. + filtered_outputs = {} + for name, value in result.items(): + if name != "past_key_values": + filtered_outputs[name] = value + else: + if self.real_config._behavior == "decoder" and ( + self.real_config.is_merged or not self.real_config.use_past_in_inputs + ): + filtered_outputs[name] = value + elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs: + # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. + filtered_outputs[name] = tuple([v[:2] for v in value]) + + return filtered_outputs + + self.patched_forward = patched_forward + + class CausalAttentionMaskModelPatcher(ModelPatcher): def __init__( self, diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 6fce8b4f2d8..ef6206e8d06 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -19,6 +19,7 @@ import torch from packaging import version +from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan from transformers.utils import is_tf_available, is_torch_available from ...utils import ( @@ -377,7 +378,7 @@ def _get_submodels_for_export_sam(model, variant): if variant == "monolith": models_for_export["model"] = model else: - # We use the model patcher to patch their forward method. + # We rather use the model patcher to patch their forward method.
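Because the patched decoder emits a single spectrum step per call and, as the comments above state, the spectrogram concatenation and the stopping criterion are handled outside of the ONNX graphs, here is a rough, hypothetical sketch of the autoregressive loop a client would run around the exported subgraphs. `encoder`, `decoder` and `postnet_vocoder` stand for already-loaded ONNX sessions wrapped as callables; the input/output names follow the declarations in `SpeechT5OnnxConfig`, and values such as `num_mel_bins=80` or `threshold=0.5` are assumptions, not numbers taken from this PR. With the `with-past` variant, the decoder additionally consumes and produces the `past_key_values.*.decoder/encoder.key/value` tensors declared by `add_past_key_values`, omitted here for brevity.

```python
# Hypothetical driver loop around the exported SpeechT5 subgraphs (illustrative only).
import numpy as np

def generate_speech(input_ids, speaker_embeddings, encoder, decoder, postnet_vocoder,
                    num_mel_bins=80, threshold=0.5, max_steps=1000):
    enc = encoder(input_ids=input_ids)  # returns encoder_outputs and encoder_attention_mask
    output_sequence = np.zeros((1, 1, num_mel_bins), dtype=np.float32)
    spectrogram = []
    for _ in range(max_steps):
        out = decoder(
            output_sequence=output_sequence,
            speaker_embeddings=speaker_embeddings,
            encoder_outputs=enc["encoder_outputs"],
            encoder_attention_mask=enc["encoder_attention_mask"],
        )
        output_sequence = out["output_sequence_out"]
        spectrogram.append(out["spectrum"])  # the concatenation happens outside of the ONNX
        if float(out["prob"]) >= threshold:  # stop-token probability
            break
    # decoder_postnet_and_vocoder turns the stacked spectrogram into a waveform.
    return postnet_vocoder(spectrogram=np.concatenate(spectrogram, axis=0))["waveform"]
```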
models_for_export["vision_encoder"] = model models_for_export["prompt_encoder_mask_decoder"] = model @@ -406,6 +407,63 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel return models_for_export +def get_speecht5_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", model_kwargs: Optional[Dict] +): + if model_kwargs is None or "vocoder" not in model_kwargs: + raise ValueError( + 'The ONNX export of SpeechT5 requires a vocoder. Please pass `--model-kwargs \'{"vocoder": "vocoder_model_name_or_path"}\'` from the command line, or `model_kwargs={"vocoder": "vocoder_model_name_or_path"}` if calling main_export.' + ) + + models_for_export = {} + + # We rather use the model patcher to patch their forward method. + models_for_export["encoder_model"] = model + models_for_export["decoder_model"] = model + + if config.variant == "with-past": + models_for_export["decoder_with_past_model"] = model + + # TODO: more flexibility in the vocoder class? + vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"]).eval() + model_kwargs["vocoder_model"] = vocoder + + models_for_export["decoder_postnet_and_vocoder"] = model + + encoder_onnx_config = config.with_behavior("encoder") + + use_past = config.variant == "with-past" + decoder_onnx_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False) + + models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config) + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config) + if config.variant == "with-past": + decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + decoder_onnx_config_with_past, + ) + + postnet_and_vocoder_onnx_config = config.__class__( + config._config, + task=config.task, + int_dtype=config.int_dtype, + float_dtype=config.float_dtype, + use_past=use_past, + use_past_in_inputs=False, # Irrelevant here. + behavior=config._behavior, # Irrelevant here. + preprocessors=config._preprocessors, + is_postnet_and_vocoder=True, + ) + postnet_and_vocoder_onnx_config.variant = config.variant + models_for_export["decoder_postnet_and_vocoder"] = ( + models_for_export["decoder_postnet_and_vocoder"], + postnet_and_vocoder_onnx_config, + ) + + return models_for_export + + def override_diffusers_2_0_attn_processors(model): for _, submodule in model.named_modules(): if isinstance(submodule, Attention): diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 89107fa053a..6a0e01c5268 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -159,26 +159,27 @@ class TasksManager: # task in a Hub repo that has no pipeline_tag, and no transformersInfo.pipeline_tag, as we then rely on # on transformersInfo["auto_model"] and this dictionary. 
_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = { + "audio-classification": "AutoModelForAudioClassification", + "audio-frame-classification": "AutoModelForAudioFrameClassification", + "audio-xvector": "AutoModelForAudioXVector", + "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), "conversational": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "feature-extraction": "AutoModel", "fill-mask": "AutoModelForMaskedLM", - "text-generation": "AutoModelForCausalLM", - "text2text-generation": "AutoModelForSeq2SeqLM", - "text-classification": "AutoModelForSequenceClassification", - "token-classification": "AutoModelForTokenClassification", - "multiple-choice": "AutoModelForMultipleChoice", - "object-detection": "AutoModelForObjectDetection", - "question-answering": "AutoModelForQuestionAnswering", "image-classification": "AutoModelForImageClassification", "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), + "image-to-text": "AutoModelForVision2Seq", "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", + "multiple-choice": "AutoModelForMultipleChoice", + "object-detection": "AutoModelForObjectDetection", + "question-answering": "AutoModelForQuestionAnswering", "semantic-segmentation": "AutoModelForSemanticSegmentation", - "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), - "audio-classification": "AutoModelForAudioClassification", - "audio-frame-classification": "AutoModelForAudioFrameClassification", - "audio-xvector": "AutoModelForAudioXVector", - "image-to-text": "AutoModelForVision2Seq", + "text-to-audio": "AutoModelForTextToSpectrogram", + "text-generation": "AutoModelForCausalLM", + "text2text-generation": "AutoModelForSeq2SeqLM", + "text-classification": "AutoModelForSequenceClassification", + "token-classification": "AutoModelForTokenClassification", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } @@ -229,22 +230,23 @@ class TasksManager: } _SYNONYM_TASK_MAP = { - "sequence-classification": "text-classification", + "audio-ctc": "automatic-speech-recognition", "causal-lm": "text-generation", "causal-lm-with-past": "text-generation-with-past", + "default": "feature-extraction", + "default-with-past": "feature-extraction-with-past", + "masked-lm": "fill-mask", + "mask-generation": "feature-extraction", + "sentence-similarity": "feature-extraction", "seq2seq-lm": "text2text-generation", "seq2seq-lm-with-past": "text2text-generation-with-past", + "sequence-classification": "text-classification", "speech2seq-lm": "automatic-speech-recognition", "speech2seq-lm-with-past": "automatic-speech-recognition-with-past", - "masked-lm": "fill-mask", - "mask-generation": "feature-extraction", - "vision2seq-lm": "image-to-text", - "default": "feature-extraction", - "default-with-past": "feature-extraction-with-past", - "audio-ctc": "automatic-speech-recognition", - "translation": "text2text-generation", - "sentence-similarity": "feature-extraction", "summarization": "text2text-generation", + "text-to-speech": "text-to-audio", + "translation": "text2text-generation", + "vision2seq-lm": "image-to-text", "zero-shot-classification": "text-classification", } @@ -268,12 +270,12 @@ class TasksManager: # TODO: why feature-extraction-with-past is here? 
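With the `text-to-audio` entry above (and its `text-to-speech` synonym), the exporter resolves such checkpoints through `AutoModelForTextToSpectrogram`. A small, hedged illustration using the tiny test checkpoint referenced elsewhere in this diff:

```python
# Hedged illustration of the new task-to-loader mapping for text-to-audio models.
from transformers import AutoModelForTextToSpectrogram

model = AutoModelForTextToSpectrogram.from_pretrained("hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech")
print(type(model).__name__)  # SpeechT5ForTextToSpeech
```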
_ENCODER_DECODER_TASKS = ( - "text2text-generation", "automatic-speech-recognition", - "image-to-text", + "document-question-answering", "feature-extraction-with-past", + "image-to-text", + "text2text-generation", "visual-question-answering", - "document-question-answering", ) # TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM @@ -857,6 +859,11 @@ class TasksManager: "automatic-speech-recognition-with-past", onnx="Speech2TextOnnxConfig", ), + # TODO: SpeechT5 can also support audio-to-audio and automatic-speech-recognition. + "speecht5": supported_tasks_mapping( + "text-to-audio", + onnx="SpeechT5OnnxConfig", + ), "splinter": supported_tasks_mapping( "feature-extraction", "question-answering", @@ -1065,12 +1072,12 @@ def get_supported_tasks_for_model_type( `TaskNameToExportConfigDict`: The dictionary mapping each task to a corresponding `ExportConfig` constructor. """ - model_type = model_type.lower() + model_type = model_type.lower().replace("_", "-") model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: raise KeyError( f"{model_type_and_model_name} is not supported yet. " - f"Only {TasksManager._SUPPORTED_MODEL_TYPE} are supported. " + f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. " f"If you want to support {model_type} please propose a PR or open up an issue." ) elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]: diff --git a/optimum/onnx/transformations_utils.py b/optimum/onnx/transformations_utils.py index c5b3ad417ba..05931753bfd 100644 --- a/optimum/onnx/transformations_utils.py +++ b/optimum/onnx/transformations_utils.py @@ -178,16 +178,20 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool): if strict is False and model_output_1.name not in model2_outputs: data_type = model_output_1.type.tensor_type.elem_type dims_output_1 = _infer_output_shape(model_output_1) - if not isinstance(dims_output_1[0], str): + if not any(isinstance(dim_output, str) for dim_output in dims_output_1): raise ValueError( - f"Expected a dynamic shape for the axis zero of {model_output_1.name}, found a static shape: {dims_output_1[0]}" + f"Expected at least one dynamic input shape for the output {model_output_1.name}, found a static shape: {dims_output_1}" ) - # fill the constant shape with the original shape, except for the axis zero that is 0 for an empty constant, + # fill the constant shape with the original shape, except for the first dynamic axis that is 0 for an empty constant, # and the dynamic axis set to 1 - dims_dummy_output = [0] - for dim in dims_output_1[1:]: - if isinstance(dim, str): + dims_dummy_output = [] + dummy_axis = None + for j, dim in enumerate(dims_output_1): + if isinstance(dim, str) and dummy_axis is None: + dims_dummy_output.append(0) + dummy_axis = j + elif isinstance(dim, str) and dummy_axis is not None: dims_dummy_output.append(1) else: dims_dummy_output.append(dim) diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 1555f846f32..4da66ddbab9 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -55,6 +55,7 @@ DummyPointsGenerator, DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, + DummySpeechT5InputGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, DummyVisionEmbeddingsGenerator, diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 
765f489a341..7c067922e8b 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -323,6 +323,7 @@ class DummyTextInputGenerator(DummyInputGenerator): SUPPORTED_INPUT_NAMES = ( "input_ids", "attention_mask", + "encoder_attention_mask", "token_type_ids", "position_ids", ) @@ -982,3 +983,39 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) for _ in range(self.num_layers) ] + + +class DummySpeechT5InputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("output_sequence", "speaker_embeddings", "spectrogram") + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + self.task = task + self.batch_size = 1 # TODO: SpeechT5 does not support batch inference in Transformers for now. + + self.sequence_length = sequence_length + self.speaker_embedding_dim = normalized_config.speaker_embedding_dim + self.num_mel_bins = normalized_config.num_mel_bins + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "output_sequence": + shape = [self.batch_size, self.sequence_length, self.num_mel_bins] + elif input_name == "speaker_embeddings": + shape = [self.batch_size, self.speaker_embedding_dim] + elif input_name == "spectrogram": + shape = [20, self.num_mel_bins] # NOTE: the first axis length is arbitrary and dynamic + else: + raise ValueError(f"Unsupported input {input_name} for DummySpeechT5InputGenerator") + + return self.random_float_tensor( + shape=shape, + min_value=0, + max_value=1, + framework=framework, + dtype=float_dtype, + ) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index ba2030d6742..55f588fc012 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -152,6 +152,7 @@ # Disabled for now because some operator seems to not be supported by ONNX. # "mctct": "hf-internal-testing/tiny-random-MCTCTModel", "speech-to-text": "hf-internal-testing/tiny-random-Speech2TextModel", + "speecht5": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", "vision-encoder-decoder": { diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index efdbaba4235..07c6fa292e9 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -160,6 +160,7 @@ def _onnx_export( device: str = "cpu", fp16: bool = False, variant: str = "default", + model_kwargs: Optional[Dict] = None, ): with TemporaryDirectory() as tmpdir: try: @@ -173,6 +174,7 @@ def _onnx_export( monolith=monolith, no_post_process=no_post_process, _variant=variant, + model_kwargs=model_kwargs, ) except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. 
Full error: {e}") @@ -276,7 +278,12 @@ def test_exporters_cli_pytorch_cpu( # masked-im models use MaskedImageModelingOutput if model_type in ["vit", "deit"] and task == "masked-im": self.skipTest("Temporarily disabled upon transformers 4.28 release") - self._onnx_export(model_name, task, monolith, no_post_process, variant=variant) + + model_kwargs = None + if model_type == "speecht5": + model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} + + self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, model_kwargs=model_kwargs) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_vision @@ -300,7 +307,13 @@ def test_exporters_cli_pytorch_gpu( if model_type == "sam": self.skipTest("sam export on cuda is not supported due to a bug in PyTorch") - self._onnx_export(model_name, task, monolith, no_post_process, device="cuda", variant=variant) + model_kwargs = None + if model_type == "speecht5": + model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} + + self._onnx_export( + model_name, task, monolith, no_post_process, device="cuda", variant=variant, model_kwargs=model_kwargs + ) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_torch @@ -317,10 +330,20 @@ def test_exporters_cli_pytorch_with_optimization( monolith: bool, no_post_process: bool, ): + model_kwargs = None + if model_type == "speecht5": + model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} + for optimization_level in ["O1", "O2", "O3"]: try: self._onnx_export( - model_name, task, monolith, no_post_process, optimization_level=optimization_level, variant=variant + model_name, + task, + monolith, + no_post_process, + optimization_level=optimization_level, + variant=variant, + model_kwargs=model_kwargs, ) except NotImplementedError as e: if "Tried to use ORTOptimizer for the model type" in str( @@ -354,9 +377,20 @@ def test_exporters_cli_pytorch_with_O4_optimization( if model_type == "sam": self.skipTest("sam export on cuda is not supported due to a bug in PyTorch") + model_kwargs = None + if model_type == "speecht5": + model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} + try: self._onnx_export( - model_name, task, monolith, no_post_process, optimization_level="O4", device="cuda", variant=variant + model_name, + task, + monolith, + no_post_process, + optimization_level="O4", + device="cuda", + variant=variant, + model_kwargs=model_kwargs, ) except NotImplementedError as e: if "Tried to use ORTOptimizer for the model type" in str( @@ -474,6 +508,10 @@ def test_export_on_fp16( if model_type == "ibert": self.skipTest("ibert can not be supported in fp16") + # TODO: test once https://github.com/pytorch/pytorch/pull/110078 is fixed + if model_type == "speecht5": + self.skipTest("speecht5 can not be supported in fp16 due to a pytorch bug") + self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, fp16=True, device="cuda") @parameterized.expand( diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 11e6a53da36..eba2f01f61a 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -40,6 +40,7 @@ from optimum.exporters.onnx.base import ConfigBehavior from optimum.exporters.onnx.config import TextDecoderOnnxConfig from optimum.exporters.onnx.model_configs import WhisperOnnxConfig +from optimum.exporters.onnx.utils import get_speecht5_models_for_export from optimum.utils import ONNX_WEIGHTS_NAME, DummyPastKeyValuesGenerator, 
NormalizedTextConfig from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_timm @@ -216,6 +217,7 @@ def _onnx_export( if isinstance(atol, dict): atol = atol[task.replace("-with-past", "")] + model_kwargs = None if ( model.config.is_encoder_decoder and task.startswith( @@ -231,6 +233,9 @@ def _onnx_export( models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) elif task.startswith("text-generation") and monolith is False: models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) + elif model.config.model_type == "speecht5": + model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"} + models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs) else: models_and_onnx_configs = {"model": (model, onnx_config)} @@ -240,6 +245,7 @@ def _onnx_export( opset=onnx_config.DEFAULT_ONNX_OPSET, output_dir=Path(tmpdirname), device=device, + model_kwargs=model_kwargs, ) input_shapes_iterator = grid_parameters(shapes_to_validate, yield_dict=True, add_test_name=False) for input_shapes in input_shapes_iterator: @@ -269,6 +275,7 @@ def _onnx_export( output_dir=Path(tmpdirname), input_shapes=input_shapes, device=device, + model_kwargs=model_kwargs, ) except AtolError as e: print(f"The ONNX export succeeded with the warning: {e}") @@ -318,15 +325,18 @@ def test_all_models_tested(self): def test_pytorch_export_on_cpu( self, test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, monolith: bool, ): + if model_type == "speecht5" and monolith: + self.skipTest("unsupported export") + self._onnx_export( test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, @@ -344,15 +354,18 @@ def test_pytorch_export_on_cpu( def test_pytorch_export_on_cuda( self, test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, monolith: bool, ): + if model_type == "speecht5" and monolith: + self.skipTest("unsupported export") + self._onnx_export( test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, @@ -367,11 +380,13 @@ def test_pytorch_export_on_cuda( @require_tf @require_vision @pytest.mark.tensorflow_test - def test_tensorflow_export(self, test_name, name, model_name, task, onnx_config_class_constructor, monolith: bool): + def test_tensorflow_export( + self, test_name, model_type, model_name, task, onnx_config_class_constructor, monolith: bool + ): if monolith is False: return 0 - self._onnx_export(test_name, name, model_name, task, onnx_config_class_constructor, monolith=monolith) + self._onnx_export(test_name, model_type, model_name, task, onnx_config_class_constructor, monolith=monolith) @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) @require_torch @@ -401,7 +416,7 @@ def test_pytorch_export_for_stable_diffusion_models_cuda(self, model_type, model def test_pytorch_export_for_timm_on_cpu( self, test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, @@ -409,7 +424,7 @@ def test_pytorch_export_for_timm_on_cpu( ): self._onnx_export( test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, @@ -429,7 +444,7 @@ def test_pytorch_export_for_timm_on_cpu( def test_pytorch_export_for_timm_on_cuda( self, test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, @@ -437,7 +452,7 @@ def test_pytorch_export_for_timm_on_cuda( ): self._onnx_export( test_name, - name, + model_type, model_name, task, onnx_config_class_constructor, diff --git 
a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 7794bd7b2d4..6bcbf111e9c 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1090,7 +1090,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForQuestionAnswering.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("custom or unsupported architecture", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -1252,7 +1252,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForMaskedLM.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -1409,7 +1409,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("that is a custom or unsupported", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -1582,7 +1582,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForTokenClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -1994,7 +1994,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForCausalLM.from_pretrained(MODEL_NAMES["vit"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): @@ -2400,7 +2400,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForImageClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -2540,7 +2540,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForSemanticSegmentation.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -2695,7 +2695,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = 
ORTModelForAudioClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -2847,7 +2847,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForCTC.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -2906,7 +2906,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForAudioXVector.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -2998,7 +2998,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForAudioFrameClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -3087,7 +3087,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForSeq2SeqLM.from_pretrained(MODEL_NAMES["bert"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): @@ -3697,7 +3697,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForSpeechSeq2Seq.from_pretrained(MODEL_NAMES["bert"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): @@ -4066,7 +4066,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForVision2Seq.from_pretrained(MODEL_NAMES["bert"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand( grid_parameters( @@ -4480,7 +4480,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForPix2Struct.from_pretrained(MODEL_NAMES["bert"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("only supports the tasks", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def 
test_merge_from_transformers_and_save(self, model_arch):
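A closing, hypothetical check (not part of this PR): after an export with the `with-past` variant, the decoder interface described earlier can be inspected directly with ONNX Runtime. The output directory below is a placeholder.

```python
# Hedged post-export check: list the inputs of the exported decoder that uses the KV cache.
import onnxruntime as ort

session = ort.InferenceSession("speecht5_onnx/decoder_with_past_model.onnx")
print([inp.name for inp in session.get_inputs()])
# Expected to include output_sequence, speaker_embeddings, encoder_attention_mask,
# encoder_hidden_states (the ONNX name mapped from encoder_outputs) and the
# past_key_values.*.decoder/encoder.key/value tensors.
```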