Skip to content

Commit

Permalink
Add ONNX export support for granite models (#2043)
Browse files Browse the repository at this point in the history
* feat(exporters/onnx): Add GraniteOnnxConfig and task support list

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add granite's normalized config for inference

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(onnx opt): Add onnx optimization support for granite

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(onnx/granite): Use LlamaOnnxConfig as the base for GraniteOnnxConfig

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(onnxruntime): Add "granite" to list of model types with grouped attention

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Add granite to the list of models that require position_ids

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(granite): Add MIN_TORCH_VERSION for recently fixed torch bug

#2043 (comment)

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* test(granite): Add tiny random granite test for onnx exporter

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

* tests(onnxruntime): Add granite to onnxruntime tests

Branch: OnnxGranite

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart authored Oct 31, 2024
1 parent 6802a0c commit 7e8d857
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 1 deletion.
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
pass


class GraniteOnnxConfig(LlamaOnnxConfig):
    """ONNX export configuration for IBM Granite decoder models.

    Granite is exported with the same inputs/outputs and dummy-input
    generation as Llama (this commit uses ``LlamaOnnxConfig`` as the base
    deliberately); only minimum dependency versions are pinned here.
    """

    # Granite architecture support first shipped in transformers 4.45.0.
    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
    # Export needs torch >= 2.5.0 — a torch bug affecting Granite was only
    # recently fixed there (see the PR #2043 review discussion).
    MIN_TORCH_VERSION = version.parse("2.5.0")


class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"phi",
"phi3",
"qwen2",
"granite",
}


Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,13 @@ class TasksManager:
"text-classification",
onnx="LlamaOnnxConfig",
),
"granite": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
onnx="GraniteOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2"}:
elif self.model_type in {"mistral", "llama", "qwen2", "granite"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ORTConfigManager:
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
"granite": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"qwen2": NormalizedTextConfig,
"granite": NormalizedTextConfigWithGQA,
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neo",
"gpt_neox",
"gptj",
"granite",
"llama",
"mistral",
"mpt",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
Expand Down

0 comments on commit 7e8d857

Please sign in to comment.