Skip to content

Commit

Permalink
Support qwen2 family model (qwen1.5) (#1746)
Browse files Browse the repository at this point in the history
* Support qwen2 family model (qwen1.5)

* update docs

* add tests for qwen2

* fix test

* ordering

---------

Co-authored-by: fxmarty <[email protected]>
  • Loading branch information
uniartisan and fxmarty authored Mar 20, 2024
1 parent cf82249 commit e6641b0
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Phi
- Pix2Struct
- PoolFormer
- Qwen2 (Qwen1.5)
- RegNet
- ResNet
- Roberta
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class Qwen2OnnxConfig(LlamaOnnxConfig):
    """ONNX export configuration for Qwen2 (Qwen1.5) models.

    Defines nothing of its own: every setting is inherited unchanged from
    ``LlamaOnnxConfig``.
    """


class GemmaOnnxConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@
"gptj",
"imagegpt",
"llama",
- "phi",
"mistral",
+ "phi",
+ "qwen2",
}


Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,14 @@ class TasksManager:
"text-classification",
onnx="OPTOnnxConfig",
),
"qwen2": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="Qwen2OnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
- elif self.model_type in {"gemma", "mistral", "llama"}:
+ elif self.model_type in {"mistral", "llama", "qwen2"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class NormalizedConfigManager:
"whisper": WhisperLikeNormalizedTextConfig,
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"qwen2": NormalizedTextConfig,
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
# "rembert": "google/rembert",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen2": "fxmarty/tiny-dummy-qwen2",
"regnet": "hf-internal-testing/tiny-random-RegNetModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,6 +2258,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"llama",
"mistral",
"mpt",
"qwen2",
]

FULL_GRID = {
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"perceiver_vision": "hf-internal-testing/tiny-random-vision_perceiver_conv",
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen2": "fxmarty/tiny-dummy-qwen2",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
Expand Down

0 comments on commit e6641b0

Please sign in to comment.