Add Mistral models ONNX export support #1425

Merged
merged 12 commits on Oct 9, 2023
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -60,6 +60,7 @@ Supported architectures:
 - M2-M100
 - Marian
 - MBart
+- Mistral
 - MobileBert
 - MobileVit
 - MobileNet v1
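With Mistral in the supported list, the export can be driven programmatically as well as from the CLI. A minimal sketch using `optimum.exporters.onnx.main_export` (the checkpoint id is illustrative; any Mistral model should work):

```python
# Minimal export sketch; "mistralai/Mistral-7B-v0.1" is an illustrative checkpoint.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="mistralai/Mistral-7B-v0.1",
    output="mistral_onnx",
    task="text-generation-with-past",  # decoder export with KV cache, as enabled here
)
```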
54 changes: 13 additions & 41 deletions optimum/exporters/onnx/model_configs.py
@@ -34,6 +34,7 @@
     DummyVisionEmbeddingsGenerator,
     DummyVisionInputGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
+    MistralDummyPastKeyValuesGenerator,
     NormalizedConfig,
     NormalizedEncoderDecoderConfig,
     NormalizedSeq2SeqConfig,
@@ -216,52 +217,23 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


-class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedTextConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
-        random_batch_size_range: Optional[Tuple[int, int]] = None,
-        random_sequence_length_range: Optional[Tuple[int, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            sequence_length=sequence_length,
-            random_batch_size_range=random_batch_size_range,
-            random_sequence_length_range=random_sequence_length_range,
-            **kwargs,
-        )
-        self.num_key_value_heads = normalized_config.num_key_value_heads
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        shape = (
-            self.batch_size,
-            self.num_key_value_heads,
-            self.sequence_length,
-            self.hidden_size // self.num_attention_heads,
-        )
-        return [
-            (
-                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
-                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
-            )
-            for _ in range(self.num_layers)
-        ]
-
-
 class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator)
-    DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator
-
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


+class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
+    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        MistralDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
+
 class MPTOnnxConfig(TextDecoderOnnxConfig):
     # MPT does not require position_ids input.
     DEFAULT_ONNX_OPSET = 13
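The deleted `LlamaDummyPastKeyValuesGenerator` reappears below as the shared `MistralDummyPastKeyValuesGenerator`; the point of the `num_key_value_heads` plumbing is grouped-query attention (GQA): the cache holds `num_key_value_heads` heads while the per-head size still derives from `num_attention_heads`. A worked example with Mistral-7B's published dimensions:

```python
# Past key/value cache shape under grouped-query attention (GQA),
# using Mistral-7B's published dimensions.
hidden_size = 4096
num_attention_heads = 32   # query heads
num_key_value_heads = 8    # cached heads: a 4x smaller cache than full MHA
head_dim = hidden_size // num_attention_heads  # 128, from the query-head count

batch_size, past_sequence_length = 2, 16
past_kv_shape = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
print(past_kv_shape)  # (2, 8, 16, 128)
```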
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
@@ -74,6 +74,7 @@
"gptj",
"imagegpt",
"llama",
"mistral",
}


8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
@@ -650,6 +650,14 @@ class TasksManager:
"question-answering",
onnx="MBartOnnxConfig",
),
"mistral": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification",
onnx="MistralOnnxConfig",
),
# TODO: enable once the missing operator is supported.
# "mctct": supported_tasks_mapping(
# "feature-extraction",
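Once registered, the mapping can be queried through `TasksManager`; a sketch (the exact return format may vary across optimum versions):

```python
from optimum.exporters.tasks import TasksManager

# Tasks registered for Mistral above; text-classification stays commented out.
tasks = TasksManager.get_supported_tasks_for_model_type("mistral", exporter="onnx")
print(sorted(tasks))
# ['feature-extraction', 'feature-extraction-with-past',
#  'text-generation', 'text-generation-with-past']
```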
17 changes: 12 additions & 5 deletions optimum/onnxruntime/base.py
@@ -216,8 +216,12 @@ def prepare_inputs_for_merged(
         # Generate dummy past for the first forward if uses a merged decoder
         if self.parent_model.use_merged and past_key_values is None:
             batch_size = input_ids.shape[0]
-            num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+
+            if self.normalized_config.config.model_type in {"mistral", "llama"}:
+                num_attention_heads = self.normalized_config.num_key_value_heads
+            else:
+                num_attention_heads = self.normalized_config.num_attention_heads
+            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

             dtype = constructor.float16 if self.use_fp16 else constructor.float32
             # TODO: find a way to better handle this controlflow, this is EXTREMELY ugly
@@ -277,8 +281,11 @@
             `Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape.
         """
         batch_size = input_ids.size(0)
-        num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+        if self.normalized_config.config.model_type in {"mistral", "llama"}:
+            num_attention_heads = self.normalized_config.num_key_value_heads
+        else:
+            num_attention_heads = self.normalized_config.num_attention_heads
+        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

         sequence_length = input_ids.size(1)
         if past_key_values is not None and use_cache_branch is not False:
@@ -527,7 +534,7 @@ def compute_past_key_values_output_shapes(
     ) -> Dict[str, int]:
         batch_size = input_ids.size(0)
         num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

         sequence_length = input_ids.size(1)
         encoder_sequence_length = encoder_hidden_states.size(1)
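The same branch appears in both hunks above; isolated, the shape logic reads as follows (a sketch mirroring the diff, not a helper that exists in the library):

```python
# For GQA models (mistral, llama) the cache has num_key_value_heads heads;
# the per-head size always derives from the full attention-head count.
def past_kv_dims(normalized_config):
    if normalized_config.config.model_type in {"mistral", "llama"}:
        num_heads = normalized_config.num_key_value_heads
    else:
        num_heads = normalized_config.num_attention_heads
    embed_size_per_head = normalized_config.hidden_size // normalized_config.num_attention_heads
    return num_heads, embed_size_per_head
```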
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -121,6 +121,7 @@ class ORTConfigManager:
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mistral": "gpt2",
"mt5": "bart",
"m2m-100": "bart",
"nystromformer": "bert",
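Mapping `mistral` onto the existing `gpt2` entry means ONNX Runtime's GPT-2 fusion profile is reused for graph optimization. A hedged sketch of how that would be exercised:

```python
# Sketch: graph-optimize an exported Mistral model; the "gpt2" mapping above
# selects the fusion profile applied by ONNX Runtime.
from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForCausalLM.from_pretrained("echarlaix/tiny-random-mistral", export=True)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="mistral_onnx_optimized",
    optimization_config=OptimizationConfig(optimization_level=2),
)
```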
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -61,6 +61,7 @@
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
+    MistralDummyPastKeyValuesGenerator,
 )
 from .modeling_utils import recurse_getattr, recurse_setattr
 from .normalized_config import (
37 changes: 37 additions & 0 deletions optimum/utils/input_generators.py
@@ -908,3 +908,40 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
             )
             for _ in range(self.num_layers)
         ]
+
+
+class MistralDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads = normalized_config.num_key_value_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
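A usage sketch for the new generator, paired with the normalized config registered below (the tiny checkpoint is the one added to the tests in this PR):

```python
# Sketch: produce dummy past_key_values for a Mistral config.
from transformers import AutoConfig
from optimum.utils import MistralDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedConfigManager

config = AutoConfig.from_pretrained("echarlaix/tiny-random-mistral")
normalized_config_class = NormalizedConfigManager.get_normalized_config_class("mistral")
generator = MistralDummyPastKeyValuesGenerator(
    task="text-generation",
    normalized_config=normalized_config_class(config),
    batch_size=2,
    sequence_length=16,
)
past = generator.generate("past_key_values", framework="pt")
# One (key, value) pair per layer, each of shape
# (batch_size, num_key_value_heads, sequence_length, head_dim).
print(len(past), past[0][0].shape)
```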
3 changes: 3 additions & 0 deletions optimum/utils/normalized_config.py
@@ -167,6 +167,8 @@ def __getattr__(self, attr_name):
     allow_new=True,
 )

+MistralNormalizedTextConfig = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+

 class NormalizedConfigManager:
     """
@@ -234,6 +236,7 @@ class NormalizedConfigManager:
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mistral": MistralNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
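`with_args` creates a `NormalizedTextConfig` subclass whose `num_key_value_heads` attribute is looked up on the wrapped model config; `allow_new=True` permits an attribute the base class does not declare. A small sketch (requires a transformers version that ships `MistralConfig`):

```python
from transformers import MistralConfig
from optimum.utils.normalized_config import NormalizedTextConfig

MistralNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_key_value_heads="num_key_value_heads", allow_new=True
)
normalized = MistralNormalizedTextConfig(MistralConfig())
print(normalized.num_key_value_heads)  # 8 with MistralConfig defaults
print(normalized.num_attention_heads)  # 32
```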
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
@@ -91,6 +91,7 @@
"m2m-100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilenet-v1": "google/mobilenet_v1_0.75_192",
8 changes: 2 additions & 6 deletions tests/onnxruntime/test_modeling.py
@@ -1955,6 +1955,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neox",
"gptj",
"llama",
"mistral",
"mpt",
]

@@ -2084,10 +2085,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         transformers_model = transformers_model.eval()
         tokenizer = get_preprocessor(model_id)
-        tokens = tokenizer(
-            "This is a sample output",
-            return_tensors="pt",
-        )
+        tokens = tokenizer("This is a sample output", return_tensors="pt")
         position_ids = None
         if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
             input_shape = tokens["input_ids"].shape
@@ -2146,7 +2144,6 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo
             use_cache=use_cache,
             use_io_binding=use_io_binding,
         )
-
         tokenizer = get_preprocessor(model_id)
         pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer)
         text = "My Name is Philipp and i live"
@@ -2210,7 +2207,6 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
         )

         tokenizer = get_preprocessor(model_id)
-
         # build engine for a short sequence
         text = ["short"]
         encoded_input = tokenizer(
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -74,6 +74,7 @@
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
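End to end, the new tests exercise roughly the following flow; a sketch using the tiny checkpoint registered above:

```python
# Sketch: export the tiny Mistral checkpoint on the fly and generate through
# ONNX Runtime, mirroring the causal-LM integration tests touched in this PR.
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "echarlaix/tiny-random-mistral"
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("My Name is Philipp and i live")[0]["generated_text"])
```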