Add position ids in ONNX export and ORT (#1381)
* add position_ids support

* fix iobinding

* add position_ids

Co-authored-by: Kunal Vaishnavi <[email protected]>
Co-authored-by: "Feng, Jiqing" <[email protected]>

* remove outdated comment

* backward compatibility

* tests wip

* tests are fine

* fix tests

* test fix bis

* update on review

---------

Co-authored-by: Kunal Vaishnavi <[email protected]>
Co-authored-by: "Feng, Jiqing" <[email protected]>
3 people authored Sep 18, 2023
1 parent 915e182 commit 89d08c4
Showing 13 changed files with 237 additions and 62 deletions.
22 changes: 15 additions & 7 deletions optimum/commands/export/onnx.py
@@ -129,6 +129,20 @@ def parse_args_onnx(parser):
" it."
),
)
optional_group.add_argument(
"--library-name",
type=str,
choices=["transformers", "diffusers", "timm"],
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
optional_group.add_argument(
"--no-position-ids",
action="store_true",
help=(
"Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
@@ -203,13 +217,6 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
optional_group.add_argument(
"--library_name",
type=str,
choices=["transformers", "diffusers", "timm"],
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -248,5 +255,6 @@ def run(self):
use_subprocess=True,
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
**input_shapes,
)
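
For reference, a minimal sketch of driving this from the Python API rather than the CLI (equivalent to `optimum-cli export onnx --model gpt2 --no-position-ids gpt2_onnx/`; the model name and output directory are illustrative, and the flag exists only for backward compatibility):

    from optimum.exporters.onnx import main_export

    # Export GPT-2 for text generation, opting out of the position_ids input.
    # This mirrors the new --no-position-ids CLI flag and triggers the
    # deprecation warning added in __main__.py below.
    main_export(
        "gpt2",
        output="gpt2_onnx",
        task="text-generation",
        no_position_ids=True,
    )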
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/__init__.py
@@ -26,6 +26,7 @@
"get_decoder_models_for_export",
"get_encoder_decoder_models_for_export",
"get_stable_diffusion_models_for_export",
"MODEL_TYPES_REQUIRING_POSITION_IDS",
],
"__main__": ["main_export"],
}
@@ -38,6 +39,7 @@
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
get_stable_diffusion_models_for_export,
MODEL_TYPES_REQUIRING_POSITION_IDS,
)
from .__main__ import main_export
else:
21 changes: 20 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -31,6 +31,7 @@
from .constants import UNPICKABLE_ARCHS
from .convert import export_models, validate_models_outputs
from .utils import (
MODEL_TYPES_REQUIRING_POSITION_IDS,
_get_submodels_for_export_decoder,
_get_submodels_for_export_encoder_decoder,
_get_submodels_for_export_stable_diffusion,
@@ -67,6 +68,7 @@ def _get_submodels_and_onnx_configs(
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -79,8 +81,16 @@
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=task
)
onnx_config_kwargs = {}
if task.startswith("text-generation") and no_position_ids:
onnx_config_kwargs["no_position_ids"] = no_position_ids

onnx_config = onnx_config_constructor(
model.config, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
**onnx_config_kwargs,
)

onnx_config.variant = _variant
@@ -174,6 +184,7 @@ def main_export(
use_subprocess: bool = False,
_variant: str = "default",
library_name: Optional[str] = None,
no_position_ids: bool = False,
**kwargs_shapes,
):
"""
@@ -253,6 +264,8 @@
library_name (`Optional[str]`, defaults to `None`):
The library of the model (`"transformers"`, `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
the library name for the checkpoint.
no_position_ids (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
@@ -340,6 +353,11 @@
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
logger.warning(
f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
)

if not is_stable_diffusion:
if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
@@ -406,6 +424,7 @@
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
)

if not is_stable_diffusion:
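
Consuming the export is unchanged from the user's perspective. A hedged sketch (this assumes the generation-side wiring of position_ids lives in optimum/onnxruntime/modeling_decoder.py, one of the changed files not shown here, so that `generate()` fills position_ids automatically):

    from optimum.onnxruntime import ORTModelForCausalLM
    from transformers import AutoTokenizer

    model = ORTModelForCausalLM.from_pretrained("gpt2_onnx")  # directory from the export above
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Batched generation with padding is exactly the case position_ids was added for.
    inputs = tokenizer(["Hello there", "Hi"], return_tensors="pt", padding=True)
    output_ids = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))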
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -664,7 +664,7 @@ def overwrite_shape_and_generate_input(
self.use_past
and self.use_past_in_inputs
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids"]
and input_name in ["decoder_input_ids", "input_ids", "position_ids"]
):
sequence_length = dummy_input_gen.sequence_length
# Use a sequence length of 1 when the KV cache is already populated.
40 changes: 40 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -66,6 +66,29 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
)
# TODO: remove no_position_ids once optimum is sufficiently above 1.13
self.no_position_ids = no_position_ids

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
@@ -132,6 +155,23 @@ def post_process_exported_models(
return models_and_onnx_configs, onnx_files_subpaths


class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs

# Decoders based on GPT2 require a position_ids input to avoid
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs


class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
"""
Handles encoder-decoder-based text architectures.
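
To make the new `inputs` property concrete, a small sketch (assuming this commit is installed); the expected dynamic axes follow directly from the branch above:

    from transformers import AutoConfig
    from optimum.exporters.onnx.model_configs import GPT2OnnxConfig

    config = AutoConfig.from_pretrained("gpt2")

    # Without the KV cache: position_ids covers the whole prompt.
    onnx_config = GPT2OnnxConfig(config, task="text-generation")
    print(onnx_config.inputs["position_ids"])  # expected: {0: "batch_size", 1: "sequence_length"}

    # With the KV cache in the inputs: only the new token's position is fed.
    onnx_config_with_past = GPT2OnnxConfig(
        config, task="text-generation", use_past=True, use_past_in_inputs=True
    )
    print(onnx_config_with_past.inputs["position_ids"])  # expected: {0: "batch_size"}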
14 changes: 9 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -50,6 +50,7 @@
EncoderDecoderBaseOnnxConfig,
TextAndVisionOnnxConfig,
TextDecoderOnnxConfig,
TextDecoderWithPositionIdsOnnxConfig,
TextEncoderOnnxConfig,
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
@@ -172,7 +173,7 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig):
pass


class GPT2OnnxConfig(TextDecoderOnnxConfig):
class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

@@ -199,34 +200,37 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
pass


class GPTNeoOnnxConfig(TextDecoderOnnxConfig):
class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")


class GPTNeoXOnnxConfig(TextDecoderOnnxConfig):
class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class OPTOnnxConfig(TextDecoderOnnxConfig):
# OPT does not require position_ids input.
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaOnnxConfig(TextDecoderOnnxConfig):
class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)


class BloomOnnxConfig(TextDecoderOnnxConfig):
# Bloom does not require position_ids input.
DUMMY_INPUT_GENERATOR_CLASSES = (
BloomDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
@@ -258,7 +262,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
}


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
GPTBigCodeDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
24 changes: 18 additions & 6 deletions optimum/exporters/onnx/utils.py
@@ -65,6 +65,18 @@
from diffusers import ModelMixin, StableDiffusionPipeline


MODEL_TYPES_REQUIRING_POSITION_IDS = {
"codegen",
"gpt2",
"gpt-bigcode",
"gpt-neo",
"gpt-neox",
"gptj",
"imagegpt",
"llama",
}


def check_onnxruntime_requirements(minimum_version: version.Version):
"""
Checks that ONNX Runtime is installed and if version is recent enough.
@@ -235,24 +247,24 @@ def get_decoder_models_for_export(
"""
models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past)

onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype}
if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
onnx_kwargs["no_position_ids"] = config.no_position_ids

onnx_config = config.__class__(
model.config,
task=config.task,
use_past=config.use_past,
use_past_in_inputs=False,
float_dtype=config.float_dtype,
int_dtype=config.int_dtype,
**onnx_kwargs,
)
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config)

if config.use_past:
onnx_config_with_past = config.__class__(
model.config,
task=config.task,
use_past=True,
use_past_in_inputs=True,
float_dtype=config.float_dtype,
int_dtype=config.int_dtype,
**onnx_kwargs,
)
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
models_for_export[ONNX_DECODER_WITH_PAST_NAME],
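
The registry is keyed on the dashed form of `config.model_type`, matching the normalization done in main_export above; for example:

    from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS

    # config.model_type uses underscores (e.g. "gpt_neox"); the registry uses dashes.
    model_type = "gpt_neox".replace("_", "-")
    print(model_type in MODEL_TYPES_REQUIRING_POSITION_IDS)  # True
    print("opt" in MODEL_TYPES_REQUIRING_POSITION_IDS)       # False: OPT does not take position_ids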
23 changes: 20 additions & 3 deletions optimum/onnxruntime/base.py
@@ -326,6 +326,7 @@ def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: None = None,
@@ -356,15 +357,21 @@ def forward(
past_key_values=past_key_values,
)

# TODO: fix transformers generate to have contiguous input_ids here already
# For an unknown reason, calling `contiguous()` here is necessary to not have errors
# TODO: fix transformers generate to have contiguous input_ids, position_ids here already
# Calling `contiguous()` here is necessary to not have errors
# on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
# I suspect the reason is the contiguous python list that messes something up?
# I suspect the garbage collector to somehow negate `tensor = tensor.contiguous()`
# in modeling_ort.py, which is then never assigned anywhere.
model_inputs = [input_ids.contiguous()]

if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)

if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
model_inputs.append(position_ids.contiguous())

if past_key_values is not None:
model_inputs += past_key_values

@@ -421,6 +428,11 @@ def forward(
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy()

if "labels" in self.input_names:
onnx_inputs["labels"] = labels.cpu().detach().numpy()
else:
@@ -437,6 +449,11 @@ def forward(
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
onnx_inputs["position_ids"] = position_ids

if "labels" in self.input_names:
onnx_inputs["labels"] = labels

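
Since the decoder now raises when the exported model expects position_ids and none are passed, callers that invoke forward() directly need to build the tensor themselves. A common recipe, mirroring what transformers' prepare_inputs_for_generation does for GPT-2-like models (illustrative, not taken from this diff):

    import torch

    attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # left-padded batch
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)  # tensor([[1, 0, 1, 2], [0, 1, 2, 3]])

    # Once the KV cache is populated, only the last position is passed:
    # position_ids = position_ids[:, -1:]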
[Diff of the remaining changed files not shown.]