diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 9e57128c272..cc752779d30 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -298,6 +298,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig): pass +class GraniteOnnxConfig(LlamaOnnxConfig): + MIN_TRANSFORMERS_VERSION = version.parse("4.45.0") + MIN_TORCH_VERSION = version.parse("2.5.0") + + class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedTextConfig diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 56249bbf5c3..19e24f88743 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -86,6 +86,7 @@ "phi", "phi3", "qwen2", + "granite", } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index a489f34fb06..fdc8bfcb539 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -915,6 +915,13 @@ class TasksManager: "text-classification", onnx="LlamaOnnxConfig", ), + "granite": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + onnx="GraniteOnnxConfig", + ), "pegasus": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 984d7f22ebf..8f1d062221a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -340,7 +340,7 @@ def prepare_past_key_values( if self.model_type == "gemma": num_attention_heads = self.normalized_config.num_key_value_heads embed_size_per_head = self.normalized_config.head_dim - elif self.model_type in {"mistral", "llama", "qwen2"}: + elif self.model_type in {"mistral", "llama", "qwen2", "granite"}: num_attention_heads = self.normalized_config.num_key_value_heads else: num_attention_heads = self.normalized_config.num_attention_heads diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 128e2406f11..9e92e0bd325 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -128,6 +128,7 @@ class ORTConfigManager: "gpt-neo": "gpt2", "gpt-neox": "gpt2", "gptj": "gpt2", + "granite": "gpt2", # longt5 with O4 results in segmentation fault "longt5": "bert", "llama": "gpt2", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 81207b76496..9ceed24c2dd 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -281,6 +281,7 @@ class NormalizedConfigManager: "xlm-roberta": NormalizedTextConfig, "yolos": NormalizedVisionConfig, "qwen2": NormalizedTextConfig, + "granite": NormalizedTextConfigWithGQA, } @classmethod diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c8a33b0be35..ccccb5510bf 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -100,6 +100,7 @@ "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "ibert": "hf-internal-testing/tiny-random-IBertModel", "imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 597eb581e2a..a335e014478 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2324,6 +2324,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "gpt_neo", "gpt_neox", "gptj", + "granite", "llama", "mistral", "mpt", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index e3d54237857..9f200e69b3d 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -104,6 +104,7 @@ "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel",