Enable the export of only one decoder #1257

Merged Oct 16, 2023 (84 commits)

Commits
41b8f98
ONNX export decoder model refactorization
echarlaix Aug 3, 2023
f91a018
fix style
echarlaix Aug 4, 2023
4ce5fbe
fix index
echarlaix Aug 4, 2023
552eebc
merge main in branch
echarlaix Sep 8, 2023
aa40ba4
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 12, 2023
9fa05e4
fix IO bindings
echarlaix Sep 12, 2023
3a0d76a
format
echarlaix Sep 12, 2023
b0aa234
enable mpt support
echarlaix Sep 12, 2023
dfabefd
format
echarlaix Sep 12, 2023
35df7bd
add trust remote code
echarlaix Sep 13, 2023
469edc8
fix test
echarlaix Sep 13, 2023
77cc527
format
echarlaix Sep 13, 2023
4f72a7e
rm redundant
echarlaix Sep 13, 2023
599c31c
format
echarlaix Sep 13, 2023
dac2376
merge main in branch
echarlaix Sep 13, 2023
c13b645
fix
echarlaix Sep 13, 2023
0e83cd1
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 14, 2023
1f81f0b
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 14, 2023
a0d0802
fix quantization
echarlaix Sep 14, 2023
7f65ce1
add test
echarlaix Sep 14, 2023
2840b81
format
echarlaix Sep 14, 2023
5fa7b20
format
echarlaix Sep 14, 2023
8011982
fix optimization
echarlaix Sep 14, 2023
b643308
fix opitmization
echarlaix Sep 15, 2023
ca9ce30
fix compatibility with legacy models
echarlaix Sep 15, 2023
144753a
format
echarlaix Sep 15, 2023
4ee6167
fix legacy models
echarlaix Sep 15, 2023
f2d0f84
format
echarlaix Sep 15, 2023
3ff719a
fix style
echarlaix Sep 15, 2023
d794141
format
echarlaix Sep 15, 2023
a34a16e
add export to main_export
echarlaix Sep 15, 2023
dfe7e5e
add legacy to ONNX export
echarlaix Sep 18, 2023
8d102f7
fix test
echarlaix Sep 18, 2023
62b8974
fix
echarlaix Sep 18, 2023
b8e18c3
rm unused import
echarlaix Sep 18, 2023
819691e
patch model to fix causal lm generation
echarlaix Sep 18, 2023
e259670
rm commen
echarlaix Sep 18, 2023
2f26201
add no psot process
echarlaix Sep 18, 2023
bed73d4
merge main in branch
echarlaix Sep 18, 2023
6d8acb4
fix
echarlaix Sep 18, 2023
52c1745
remove bloom caching
echarlaix Sep 18, 2023
1e9ba7e
fix
echarlaix Sep 19, 2023
4b68caa
format
echarlaix Sep 19, 2023
e5fd9f8
fix dynamic axis for position ids
echarlaix Sep 19, 2023
addad92
fix external data
echarlaix Sep 19, 2023
2c063c0
format
echarlaix Sep 19, 2023
1b47093
test
echarlaix Sep 19, 2023
35caaf2
test
echarlaix Sep 19, 2023
725857b
add model patcher
echarlaix Sep 19, 2023
46b26b5
format
echarlaix Sep 19, 2023
33957af
fix
echarlaix Sep 19, 2023
c2ec382
fix bart model patcher
echarlaix Sep 19, 2023
d86bce6
format
echarlaix Sep 19, 2023
be836b5
format
echarlaix Sep 20, 2023
b05f599
fix model patcher for opt models
echarlaix Sep 20, 2023
26d97e8
fix format
echarlaix Sep 20, 2023
4b6c3ed
add tmp onnxruntime max version
echarlaix Sep 20, 2023
615a219
add test
echarlaix Sep 20, 2023
b3525f8
format
echarlaix Sep 20, 2023
e0e2bae
tmp fix onnxruntime max version
echarlaix Sep 20, 2023
cbc935f
format
echarlaix Sep 20, 2023
624d91d
add test
echarlaix Sep 20, 2023
c558450
fix ort docker
echarlaix Sep 20, 2023
e72526d
fix format
echarlaix Sep 20, 2023
7926999
merge main in branch
echarlaix Sep 22, 2023
44ef0f1
add test
echarlaix Sep 22, 2023
ed8e74f
fix bart model patcher
echarlaix Sep 25, 2023
c13a170
raise when unsupported model
echarlaix Sep 25, 2023
524b668
add cached file
echarlaix Sep 25, 2023
8951ddf
minor
echarlaix Oct 3, 2023
2491ef3
add position warning
echarlaix Oct 4, 2023
0ab6e61
fixes
echarlaix Oct 5, 2023
1a7d491
enable post process after export to remove tied weights
echarlaix Oct 5, 2023
cd8d4be
comment
echarlaix Oct 5, 2023
e6de5e7
remove test
echarlaix Oct 5, 2023
4a32f7a
fix test
echarlaix Oct 5, 2023
a51686e
modify model
echarlaix Oct 6, 2023
e2f8a3b
remove deprecated use_merged in test
echarlaix Oct 6, 2023
52ce2d7
Merge branch 'main' into refactorization-decoder-ort
echarlaix Oct 9, 2023
b76f43a
Add mistral model patcher
echarlaix Oct 9, 2023
5b3d445
fix test
echarlaix Oct 9, 2023
5406f95
add slow test
echarlaix Oct 9, 2023
52e0c69
add workflow
echarlaix Oct 9, 2023
8883323
fix
echarlaix Oct 9, 2023
18 changes: 9 additions & 9 deletions optimum/commands/export/onnx.py
@@ -136,14 +136,6 @@ def parse_args_onnx(parser):
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
optional_group.add_argument(
"--no-position-ids",
action="store_true",
help=(
"Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
)
@@ -217,6 +209,14 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
optional_group.add_argument(
"--legacy",
[Review comment, Contributor]: We could merge the no_position_ids and legacy as they correspond to the previous export behavior and no_position_ids is not in a release yet. WDYT?

[Reply, Collaborator (Author)]: Yes, good idea, will merge both.

action="store_true",
help=(
"Export decoder only models in three files (without + with past and the resulting merged model)."
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -255,6 +255,6 @@ def run(self):
use_subprocess=True,
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
legacy=self.args.legacy,
**input_shapes,
)
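For illustration, a minimal sketch of the resulting CLI usage, assuming a checkout that includes this PR; the checkpoint name, task, and output directory are placeholder choices, not taken from the PR:

# Sketch: call the exporter CLI with the new --legacy flag via subprocess.
# "gpt2", the task, and "gpt2_onnx/" are placeholders.
import subprocess

subprocess.run(
    [
        "optimum-cli", "export", "onnx",
        "--model", "gpt2",
        "--task", "text-generation-with-past",
        "--legacy",  # restore the previous three-file decoder export
        "gpt2_onnx/",
    ],
    check=True,
)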
21 changes: 11 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -68,7 +68,7 @@ def _get_submodels_and_onnx_configs(
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
legacy: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -82,8 +82,8 @@
model=model, exporter="onnx", task=task
)
onnx_config_kwargs = {}
if task.startswith("text-generation") and no_position_ids:
onnx_config_kwargs["no_position_ids"] = no_position_ids
if task.startswith("text-generation") and legacy:
onnx_config_kwargs["no_position_ids"] = legacy

onnx_config = onnx_config_constructor(
model.config,
@@ -106,7 +106,7 @@
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("text-generation") and not monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
else:
@@ -184,7 +184,7 @@ def main_export(
use_subprocess: bool = False,
_variant: str = "default",
library_name: Optional[str] = None,
no_position_ids: bool = False,
legacy: bool = False,
**kwargs_shapes,
):
"""
@@ -264,8 +264,8 @@
library_name (`Optional[str]`, defaults to `None`):
The library of the model (`"transformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
the library name for the checkpoint.
no_position_ids (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

@@ -353,9 +353,9 @@
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
logger.warning(
f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
)

if not is_stable_diffusion:
@@ -424,7 +424,7 @@
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
legacy=legacy,
)

if not is_stable_diffusion:
@@ -610,6 +610,7 @@ def main():
pad_token_id=args.pad_token_id,
for_ort=args.for_ort,
library_name=args.library_name,
legacy=args.legacy,
**input_shapes,
)

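For reference, a minimal sketch of the programmatic equivalent through main_export; the argument name follows the diff above, while the checkpoint and output path are placeholders:

# Sketch: programmatic export with the renamed `legacy` argument.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="gpt2",   # placeholder checkpoint
    output="gpt2_onnx/",         # placeholder output directory
    task="text-generation-with-past",
    legacy=False,  # default: single-decoder export; True restores the old three-file layout
)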
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -585,7 +585,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
elif self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
else:
common_outputs = OrderedDict({"logits": {0: "batch_size"}})
common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
if self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
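To make the change concrete, a short sketch of what the patched outputs property now yields for a text-generation config (illustrative, not the exact optimum internals):

from collections import OrderedDict

# logits now declares both axes dynamic, so one decoder graph can serve both
# the prefill step (sequence_length > 1) and the decode step (sequence_length == 1).
common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
# With use_past=True, add_past_key_values() then appends the present.*.key /
# present.*.value entries for each layer to this same mapping.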
7 changes: 2 additions & 5 deletions optimum/exporters/onnx/config.py
@@ -92,7 +92,7 @@ def __init__(
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size"}}
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
else:
@@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs

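Context for the axis change: with left-padded batched generation, correct positions cannot be recovered from the token index alone, which is why position_ids stays a full dynamic-length input. A sketch of the standard computation used by transformers models such as GPT-2 in prepare_inputs_for_generation:

import torch

# Derive positions from the attention mask so that left padding does not
# shift the positions of the real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1],   # left-padded prompt
                               [1, 1, 1, 1]])  # full-length prompt
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1],
#         [0, 1, 2, 3]])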
46 changes: 41 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -56,7 +56,15 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import (
BartModelPatcher,
BloomModelPatcher,
LlamaModelPatcher,
MistralModelPatcher,
OPTModelPatcher,
SAMModelPatcher,
WavLMModelPatcher,
)


if TYPE_CHECKING:
@@ -216,13 +224,23 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return OPTModelPatcher(self, model, model_kwargs=model_kwargs)


class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
@@ -233,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
@@ -241,6 +264,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class BloomOnnxConfig(TextDecoderOnnxConfig):
# Bloom does not require position_ids input.
@@ -274,6 +302,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
1: decoder_sequence_name,
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -413,7 +446,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return int_tensor


class BartOnnxConfig(TextSeq2SeqOnnxConfig):
class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
@@ -537,11 +570,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
)


class MBartOnnxConfig(BartOnnxConfig):
pass
class BartOnnxConfig(M2M100OnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BartModelPatcher(self, model, model_kwargs=model_kwargs)


class M2M100OnnxConfig(BartOnnxConfig):
class MBartOnnxConfig(BartOnnxConfig):
pass


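A simplified sketch of how the exporter consumes these patch_model_for_export overrides; constructor arguments are abridged and the actual export call is elided:

# The export loop enters the patcher as a context manager, so the
# attention-mask helper is swapped only for the duration of the ONNX trace.
onnx_config = MistralOnnxConfig(model.config, task="text-generation", use_past=True)
with onnx_config.patch_model_for_export(model):
    ...  # run the ONNX export here; __exit__ restores the original method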
95 changes: 95 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -19,6 +19,12 @@

from transformers.utils import is_torch_available

from ...utils.modeling_utils import (
_prepare_attn_mask,
_prepare_decoder_attention_mask,
_prepare_decoder_sliding_window_attention_mask,
)


if is_torch_available():
import torch
@@ -342,3 +348,92 @@ def patched_forward(
return {"iou_scores": iou_predictions, "pred_masks": low_res_masks}

self.patched_forward = patched_forward


class CausalAttentionMaskModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
if self.patch:
self._orig_func = getattr(self._model_to_patch, self._orig_func_name)

def __enter__(self):
super().__enter__()
if self.patch:
setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch))

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.patch:
setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch))


class BloomModelPatcher(CausalAttentionMaskModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.transformer
self._patch_func = _prepare_attn_mask
self._orig_func_name = "_prepare_attn_mask"
super().__init__(config, model, model_kwargs)


class OPTModelPatcher(CausalAttentionMaskModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.model.decoder
self._patch_func = _prepare_decoder_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
super().__init__(config, model, model_kwargs)


class LlamaModelPatcher(CausalAttentionMaskModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.model
self._patch_func = _prepare_decoder_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
super().__init__(config, model, model_kwargs)


class MistralModelPatcher(CausalAttentionMaskModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.model
self._patch_func = _prepare_decoder_sliding_window_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
super().__init__(config, model, model_kwargs)


class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.model.decoder
self._patch_func = _prepare_decoder_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
super().__init__(config, model, model_kwargs)
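A self-contained toy illustrating the bind-and-restore that CausalAttentionMaskModelPatcher performs in __enter__/__exit__ (all names here are illustrative, not optimum API):

class Decoder:
    def _prepare_attn_mask(self):
        return "original"

def _onnx_friendly(self):
    return "patched"

decoder = Decoder()
orig = Decoder._prepare_attn_mask

# __enter__: bind the replacement function onto the instance, mirroring
# self._patch_func.__get__(self._model_to_patch) in the patcher above
decoder._prepare_attn_mask = _onnx_friendly.__get__(decoder)
assert decoder._prepare_attn_mask() == "patched"

# __exit__: rebind the original so the model is left untouched after export
decoder._prepare_attn_mask = orig.__get__(decoder)
assert decoder._prepare_attn_mask() == "original"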