From 1b566831d3b1813385b0df3d84c91125638a3eb5 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Fri, 23 Aug 2024 13:10:55 +0200
Subject: [PATCH] fix bloom modeling

---
 optimum/exporters/onnx/model_configs.py | 37 +++++++++++++------------
 optimum/onnxruntime/modeling_decoder.py |  9 +++---
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index dc578facaec..d4b15b2968b 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -341,25 +341,28 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
 
     def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
-        if direction not in ["inputs", "outputs"]:
-            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
-
-        if direction == "inputs":
-            decoder_sequence_name = "past_sequence_length"
-            name = "past_key_values"
+        if check_if_transformers_greater("4.44"):
+            super().add_past_key_values(inputs_or_outputs, direction)
         else:
-            decoder_sequence_name = "past_sequence_length + 1"
-            name = "present"
+            if direction not in ["inputs", "outputs"]:
+                raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
 
-        for i in range(self._normalized_config.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {
-                0: "batch_size x num_heads",
-                2: decoder_sequence_name,
-            }
-            inputs_or_outputs[f"{name}.{i}.value"] = {
-                0: "batch_size x num_heads",
-                1: decoder_sequence_name,
-            }
+            if direction == "inputs":
+                decoder_sequence_name = "past_sequence_length"
+                name = "past_key_values"
+            else:
+                decoder_sequence_name = "past_sequence_length + 1"
+                name = "present"
+
+            for i in range(self._normalized_config.num_layers):
+                inputs_or_outputs[f"{name}.{i}.key"] = {
+                    0: "batch_size x num_heads",
+                    2: decoder_sequence_name,
+                }
+                inputs_or_outputs[f"{name}.{i}.value"] = {
+                    0: "batch_size x num_heads",
+                    1: decoder_sequence_name,
+                }
 
 
 class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 6a0dcbba2f0..612017a8bb7 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -336,8 +336,7 @@ def prepare_past_key_values(
         dtype = constructor.float16 if self.use_fp16 else constructor.float32
 
         # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
-        # "1" is the dummy sequence length
-        if self.model_type == "bloom":
+        if self.model_type == "bloom" and not check_if_transformers_greater("4.44"):
             shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
             shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
             key = constructor.zeros(shape_key, dtype=dtype)
@@ -354,9 +353,9 @@ def prepare_past_key_values(
             for name, value in zip(self.key_value_output_names, past_key_values):
                 shape = [*value.shape]
                 index = 1 if "value" in name else 2
-
                 shape[index] += sequence_length
                 pkv_output_shape[name] = shape
+
         elif self.model_type == "gpt_bigcode":
             # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor.
             shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -371,9 +370,9 @@ def prepare_past_key_values(
                 shape = [*value.shape]
                 shape[1] += sequence_length
                 pkv_output_shape[name] = shape
+
         else:
             num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
-
             shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
             key_or_value = constructor.zeros(shape, dtype=dtype)
 
@@ -568,7 +567,7 @@ def _from_pretrained(
             provider_options=provider_options,
         )
 
-        if config.model_type == "bloom":
+        if config.model_type == "bloom" and not check_if_transformers_greater("4.44"):
            init_cls = ORTBloomForCausalLM
         elif config.model_type == "falcon":
            init_cls = ORTFalconForCausalLM
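
Note (not part of the patch): the change gates Bloom's legacy fused KV-cache handling on the installed transformers version, since transformers 4.44 moved Bloom to the standard (batch, num_heads, seq_len, head_dim) cache layout. Below is a minimal, hypothetical Python sketch of that gate, assuming check_if_transformers_greater is importable from optimum.utils as in the modules the patch touches; the helper name empty_bloom_kv_shapes and the example dimensions are illustrative only.

    # Hypothetical standalone sketch, not taken from the patch: it only
    # illustrates the transformers>=4.44 version gate around Bloom's
    # legacy KV-cache layout versus the standard one.
    from typing import Tuple

    import torch

    from optimum.utils import check_if_transformers_greater


    def empty_bloom_kv_shapes(
        batch_size: int, num_attention_heads: int, embed_size_per_head: int
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Return (key_shape, value_shape) for an empty Bloom past-key-values entry."""
        if check_if_transformers_greater("4.44"):
            # Recent transformers: Bloom uses the standard cache layout, so no
            # special-casing is needed (same shape for keys and values).
            shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
            return shape, shape
        # Legacy Bloom layout: batch and heads are fused, and the key is transposed.
        return (
            (batch_size * num_attention_heads, embed_size_per_head, 0),
            (batch_size * num_attention_heads, 0, embed_size_per_head),
        )


    if __name__ == "__main__":
        key_shape, value_shape = empty_bloom_kv_shapes(
            batch_size=2, num_attention_heads=32, embed_size_per_head=64
        )
        key, value = torch.zeros(key_shape), torch.zeros(value_shape)
        print(key.shape, value.shape)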