diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 3170ebdcdd2..1c86d8a86e6 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -29,6 +29,8 @@
     logging,
 )
 from ...utils.import_utils import _diffusers_version
+from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
+
 from ..tasks import TasksManager
 from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
 
@@ -146,16 +148,28 @@ def _get_submodels_for_export_stable_diffusion(
 
 
 def _get_submodels_for_export_decoder(
-    model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    use_past: bool,
+    legacy: bool = False,
 ) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]:
     """
     Returns the decoder part of the model.
     """
     models_for_export = {}
 
-    models_for_export[ONNX_DECODER_NAME] = model
-    if use_past:
-        models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model
+    if legacy:
+        models_for_export[ONNX_DECODER_NAME] = model
+        if use_past:
+            models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model
+    else:
+        if model.config.model_type in {"bloom", "mpt"}:
+            model.transformer._prepare_attn_mask = _prepare_attn_mask
+        elif model.config.model_type == "llama":
+            model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        elif model.config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"):
+            model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+
+        models_for_export["model"] = model
 
     return models_for_export
 
@@ -214,6 +228,7 @@ def get_encoder_decoder_models_for_export(
 def get_decoder_models_for_export(
     model: Union["PreTrainedModel", "TFPreTrainedModel"],
     config: "OnnxConfig",
+    legacy: bool = False,
 ) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]:
     """
     Returns two versions of the decoder that can be used together to perform fast generation:
@@ -233,31 +248,43 @@ def get_decoder_models_for_export(
         `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and
         onnx configs for the encoder and decoder parts of the model.
""" - models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) - onnx_config = config.__class__( - model.config, - task=config.task, - use_past=config.use_past, - use_past_in_inputs=False, - float_dtype=config.float_dtype, - int_dtype=config.int_dtype, - ) - models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) + models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) - if config.use_past: - onnx_config_with_past = config.__class__( + if legacy: + onnx_config = config.__class__( model.config, task=config.task, - use_past=True, - use_past_in_inputs=True, + use_past=config.use_past, + use_past_in_inputs=False, float_dtype=config.float_dtype, int_dtype=config.int_dtype, ) - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( - models_for_export[ONNX_DECODER_WITH_PAST_NAME], - onnx_config_with_past, + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) + + if config.use_past: + onnx_config_with_past = config.__class__( + model.config, + task=config.task, + use_past=True, + use_past_in_inputs=True, + float_dtype=config.float_dtype, + int_dtype=config.int_dtype, + ) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + onnx_config_with_past, + ) + else: + onnx_config = config.__class__( + model.config, + task=config.task, + use_past=config.use_past, + use_past_in_inputs=config.use_past, + float_dtype=config.float_dtype, + int_dtype=config.int_dtype, ) + models_for_export["model"] = (models_for_export["model"], onnx_config) return models_for_export diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 31410121b53..26b693b4e05 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1024,33 +1024,27 @@ def _from_transformers( if task is None: task = cls._auto_model_to_task(cls.auto_model_class) + if use_cache: + task += "-with-past" + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) - model_kwargs = { - "revision": revision, - "use_auth_token": use_auth_token, - "cache_dir": cache_dir, - "subfolder": subfolder, - "local_files_only": local_files_only, - "force_download": force_download, - "trust_remote_code": trust_remote_code, - } - - model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_constructor(model.config, use_past=use_cache, use_past_in_inputs=use_cache) - if config.model_type in {"bloom", "mpt"}: - model.transformer._prepare_attn_mask = _prepare_attn_mask - elif config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - elif config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"): - model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - - # Export the model to the ONNX format - export(model=model, config=onnx_config, output=save_dir_path / file_name) + main_export( + model_name_or_path=model_id, + output=save_dir_path, + task=task, + do_validation=False, + no_post_process=False, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) - # TODO : use main_export 
         config.save_pretrained(save_dir_path)
         maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
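
Reviewer note: a minimal sketch of the `get_decoder_models_for_export` change from the caller's side. Everything below is illustrative rather than part of the diff (the `gpt2` checkpoint and task string are assumptions), and it presumes this branch is installed. The config-constructor pattern mirrors the lines removed from `_from_transformers` above.

```python
from transformers import AutoModelForCausalLM

from optimum.exporters.onnx import get_decoder_models_for_export
from optimum.exporters.tasks import TasksManager

# Illustrative checkpoint and task; any decoder-only architecture works.
model = AutoModelForCausalLM.from_pretrained("gpt2")
constructor = TasksManager.get_exporter_config_constructor(
    model=model, exporter="onnx", task="text-generation"
)
onnx_config = constructor(model.config, use_past=True)

# legacy=True keeps the old two-entry layout, both entries sharing weights.
legacy_submodels = get_decoder_models_for_export(model, onnx_config, legacy=True)
print(sorted(legacy_submodels))  # ['decoder_model', 'decoder_with_past_model']

# The new default returns a single "model" entry whose OnnxConfig consumes past
# key/values as inputs; for the listed architectures (bloom, mpt, llama, opt,
# bart, ...) it also patches the attention-mask helpers in place first.
submodels = get_decoder_models_for_export(model, onnx_config)
print(sorted(submodels))  # ['model']
```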
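
On the `modeling_decoder.py` side, the user-facing behavior is unchanged: `from_pretrained(..., export=True)` still works, it just routes the whole export through `main_export` now. A minimal usage sketch, again assuming an illustrative `gpt2` checkpoint:

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

# export=True triggers _from_transformers, which now appends "-with-past" to
# the task (because use_cache=True) and delegates model loading, attention-mask
# patching, ONNX conversion and post-processing to main_export.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)

inputs = tokenizer("Hello,", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))
```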