Commit
add export to main_export
echarlaix committed Sep 15, 2023
1 parent d794141 commit a34a16e
Showing 2 changed files with 65 additions and 44 deletions.
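
In short, this commit makes `ORTModelForCausalLM._from_transformers` delegate the ONNX export to `main_export` (resolving the `# TODO : use main_export` removed below), and threads a new `legacy` flag (default `False`) through the decoder export helpers in `optimum/exporters/onnx/utils.py`: `legacy=True` keeps the previous separate decoder / decoder-with-past graphs, while the new default exports a single decoder graph, patching the attention-mask preparation of bloom, mpt, llama, blenderbot-small, blenderbot, opt, pegasus and bart models before export.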
69 changes: 48 additions & 21 deletions optimum/exporters/onnx/utils.py
@@ -29,6 +29,8 @@
     logging,
 )
 from ...utils.import_utils import _diffusers_version
+from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
+
 from ..tasks import TasksManager
 from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
@@ -146,16 +148,28 @@ def _get_submodels_for_export_stable_diffusion(


 def _get_submodels_for_export_decoder(
-    model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    use_past: bool,
+    legacy: bool = False,
 ) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]:
     """
     Returns the decoder part of the model.
     """
     models_for_export = {}

-    models_for_export[ONNX_DECODER_NAME] = model
-    if use_past:
-        models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model
+    if legacy:
+        models_for_export[ONNX_DECODER_NAME] = model
+        if use_past:
+            models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model
+    else:
+        if model.config.model_type in {"bloom", "mpt"}:
+            model.transformer._prepare_attn_mask = _prepare_attn_mask
+        elif model.config.model_type == "llama":
+            model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        elif model.config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"):
+            model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+
+        models_for_export["model"] = model

     return models_for_export
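
For context, here is a hedged sketch (not part of the commit) of the two layouts this helper can now return. It assumes a GPT-2 checkpoint; the dictionary keys come from the `ONNX_DECODER_NAME` / `ONNX_DECODER_WITH_PAST_NAME` constants imported above.

```python
# Illustrative only: compare the legacy two-graph layout with the new
# single-graph layout. Uses the private helper shown in the diff above.
from transformers import AutoModelForCausalLM

from optimum.exporters.onnx.utils import _get_submodels_for_export_decoder

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Legacy layout: a decoder without past-key-value inputs plus, when
# use_past=True, a second copy of the decoder that consumes the KV cache.
legacy_submodels = _get_submodels_for_export_decoder(model, use_past=True, legacy=True)
print(sorted(legacy_submodels))  # the ONNX_DECODER_NAME and ONNX_DECODER_WITH_PAST_NAME keys

# New default: a single "model" entry; for the architectures listed in the
# diff above (not gpt2), the attention-mask helpers are patched first.
submodels = _get_submodels_for_export_decoder(model, use_past=True, legacy=False)
print(sorted(submodels))  # ["model"]
```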

@@ -214,6 +228,7 @@ def get_encoder_decoder_models_for_export(
 def get_decoder_models_for_export(
     model: Union["PreTrainedModel", "TFPreTrainedModel"],
     config: "OnnxConfig",
+    legacy: bool = False,
 ) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]:
     """
     Returns two versions of the decoder that can be used together to perform fast generation:
@@ -233,31 +248,43 @@ def get_decoder_models_for_export(
         `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]`: A Dict containing the model and
         onnx configs for the encoder and decoder parts of the model.
     """
-    models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past)
-
-    onnx_config = config.__class__(
-        model.config,
-        task=config.task,
-        use_past=config.use_past,
-        use_past_in_inputs=False,
-        float_dtype=config.float_dtype,
-        int_dtype=config.int_dtype,
-    )
-    models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config)
+    models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy)

-    if config.use_past:
-        onnx_config_with_past = config.__class__(
-            model.config,
-            task=config.task,
-            use_past=True,
-            use_past_in_inputs=True,
-            float_dtype=config.float_dtype,
-            int_dtype=config.int_dtype,
-        )
-        models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
-            models_for_export[ONNX_DECODER_WITH_PAST_NAME],
-            onnx_config_with_past,
-        )
+    if legacy:
+        onnx_config = config.__class__(
+            model.config,
+            task=config.task,
+            use_past=config.use_past,
+            use_past_in_inputs=False,
+            float_dtype=config.float_dtype,
+            int_dtype=config.int_dtype,
+        )
+        models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config)
+
+        if config.use_past:
+            onnx_config_with_past = config.__class__(
+                model.config,
+                task=config.task,
+                use_past=True,
+                use_past_in_inputs=True,
+                float_dtype=config.float_dtype,
+                int_dtype=config.int_dtype,
+            )
+            models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
+                models_for_export[ONNX_DECODER_WITH_PAST_NAME],
+                onnx_config_with_past,
+            )
+    else:
+        onnx_config = config.__class__(
+            model.config,
+            task=config.task,
+            use_past=config.use_past,
+            use_past_in_inputs=config.use_past,
+            float_dtype=config.float_dtype,
+            int_dtype=config.int_dtype,
+        )
+        models_for_export["model"] = (models_for_export["model"], onnx_config)

     return models_for_export
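
A hedged usage sketch (not part of the commit) of the updated function, assuming a GPT-2 checkpoint and the `TasksManager` constructor API that appears in the second file of this diff:

```python
# Illustrative only: pair each exported submodel with its ONNX config.
from transformers import AutoModelForCausalLM

from optimum.exporters.onnx.utils import get_decoder_models_for_export
from optimum.exporters.tasks import TasksManager

model = AutoModelForCausalLM.from_pretrained("gpt2")
constructor = TasksManager.get_exporter_config_constructor(
    model=model, exporter="onnx", task="text-generation"
)
onnx_config = constructor(model.config, use_past=True)

# legacy=False (the new default) yields a single ("model", config) pair whose
# config takes the KV cache as inputs (use_past_in_inputs=config.use_past);
# legacy=True would instead yield the decoder / decoder-with-past pairs.
models_for_export = get_decoder_models_for_export(model, onnx_config, legacy=False)
for name, (submodel, sub_config) in models_for_export.items():
    print(name, sub_config.use_past_in_inputs)
```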

40 changes: 17 additions & 23 deletions optimum/onnxruntime/modeling_decoder.py
@@ -1024,33 +1024,27 @@ def _from_transformers(
         if task is None:
             task = cls._auto_model_to_task(cls.auto_model_class)

+        if use_cache:
+            task += "-with-past"
+
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "trust_remote_code": trust_remote_code,
-        }
-
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
-
-        if config.model_type in {"bloom", "mpt"}:
-            model.transformer._prepare_attn_mask = _prepare_attn_mask
-        elif config.model_type == "llama":
-            model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-        elif config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"):
-            model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-
-        # Export the model to the ONNX format
-        export(model=model, config=onnx_config, output=save_dir_path / file_name)
+        main_export(
+            model_name_or_path=model_id,
+            output=save_dir_path,
+            task=task,
+            do_validation=False,
+            no_post_process=False,
+            subfolder=subfolder,
+            revision=revision,
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+        )

-        # TODO : use main_export
         config.save_pretrained(save_dir_path)
         maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
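
Since `_from_transformers` now defers to the same entry point as the `optimum-cli export onnx` command, the export it performs is roughly equivalent to the direct call below (a hedged sketch, not part of the commit; the checkpoint and output directory are hypothetical):

```python
# Illustrative only: the export that _from_transformers now delegates to
# main_export, called directly.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="gpt2",         # hypothetical checkpoint
    output="onnx_out",                 # hypothetical output directory
    task="text-generation-with-past",  # "-with-past" mirrors use_cache=True
    do_validation=False,
    no_post_process=False,
)
```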

