
fix bloom modeling
IlyasMoutawwakil committed Aug 23, 2024
1 parent e6d7c13 commit 1b56683
Showing 2 changed files with 24 additions and 22 deletions.
37 changes: 20 additions & 17 deletions optimum/exporters/onnx/model_configs.py
@@ -341,25 +341,28 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
 
     def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
-        if direction not in ["inputs", "outputs"]:
-            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
-
-        if direction == "inputs":
-            decoder_sequence_name = "past_sequence_length"
-            name = "past_key_values"
-        else:
-            decoder_sequence_name = "past_sequence_length + 1"
-            name = "present"
-
-        for i in range(self._normalized_config.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {
-                0: "batch_size x num_heads",
-                2: decoder_sequence_name,
-            }
-            inputs_or_outputs[f"{name}.{i}.value"] = {
-                0: "batch_size x num_heads",
-                1: decoder_sequence_name,
-            }
+        if check_if_transformers_greater("4.44"):
+            super().add_past_key_values(inputs_or_outputs, direction)
+        else:
+            if direction not in ["inputs", "outputs"]:
+                raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+            if direction == "inputs":
+                decoder_sequence_name = "past_sequence_length"
+                name = "past_key_values"
+            else:
+                decoder_sequence_name = "past_sequence_length + 1"
+                name = "present"
+
+            for i in range(self._normalized_config.num_layers):
+                inputs_or_outputs[f"{name}.{i}.key"] = {
+                    0: "batch_size x num_heads",
+                    2: decoder_sequence_name,
+                }
+                inputs_or_outputs[f"{name}.{i}.value"] = {
+                    0: "batch_size x num_heads",
+                    1: decoder_sequence_name,
+                }
 
 
 class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
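
The gist of the change above: up to transformers 4.43, BloomForCausalLM kept its past key/value cache in a fused layout with batch and heads merged into one axis, which is what the legacy else branch describes; starting with transformers 4.44, Bloom uses the standard decoder cache layout, so the config can simply defer to the parent TextDecoderOnnxConfig.add_past_key_values. A minimal sketch of the two layouts follows, with illustrative tensor names and sizes that are not part of the commit:

import torch

batch_size, num_heads, head_dim, past_len = 2, 16, 64, 8

# Legacy Bloom cache (transformers < 4.44), matching the dynamic axes in the else branch:
# batch and heads are fused, keys carry the sequence on axis 2, values on axis 1.
legacy_key = torch.zeros(batch_size * num_heads, head_dim, past_len)
legacy_value = torch.zeros(batch_size * num_heads, past_len, head_dim)

# Standard decoder cache (transformers >= 4.44), covered by the parent implementation:
new_key = torch.zeros(batch_size, num_heads, past_len, head_dim)
new_value = torch.zeros(batch_size, num_heads, past_len, head_dim)
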
9 changes: 4 additions & 5 deletions optimum/onnxruntime/modeling_decoder.py
@@ -336,8 +336,7 @@ def prepare_past_key_values(
         dtype = constructor.float16 if self.use_fp16 else constructor.float32
 
         # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
-        # "1" is the dummy sequence length
-        if self.model_type == "bloom":
+        if self.model_type == "bloom" and not check_if_transformers_greater("4.44"):
             shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
             shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
             key = constructor.zeros(shape_key, dtype=dtype)
@@ -354,9 +353,9 @@
             for name, value in zip(self.key_value_output_names, past_key_values):
                 shape = [*value.shape]
                 index = 1 if "value" in name else 2
-
                 shape[index] += sequence_length
                 pkv_output_shape[name] = shape
+
         elif self.model_type == "gpt_bigcode":
             # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor.
             shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -371,9 +370,9 @@
                 shape = [*value.shape]
                 shape[1] += sequence_length
                 pkv_output_shape[name] = shape
+
         else:
             num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
-
             shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
             key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -568,7 +567,7 @@ def _from_pretrained(
             provider_options=provider_options,
         )
 
-        if config.model_type == "bloom":
+        if config.model_type == "bloom" and not check_if_transformers_greater("4.44"):
             init_cls = ORTBloomForCausalLM
         elif config.model_type == "falcon":
             init_cls = ORTFalconForCausalLM
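
Both files gate the legacy behaviour on check_if_transformers_greater("4.44"), a helper from optimum's import utilities: with transformers >= 4.44 the Bloom-specific cache shapes and the dedicated ORTBloomForCausalLM class are skipped, and Bloom checkpoints go through the same export and inference path as other decoders. A rough, illustrative stand-in for such a version guard (not the actual helper) could look like:

import transformers
from packaging import version

def transformers_at_least(target: str) -> bool:
    # True when the installed transformers version is >= target (illustrative only).
    return version.parse(transformers.__version__) >= version.parse(target)

use_legacy_bloom_path = not transformers_at_least("4.44")
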
