Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 6, 2023
1 parent c3b584a commit 4df6252
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ def prepare_inputs_for_merged(
  batch_size = input_ids.shape[0]

  if self.normalized_config.config.model_type in {"mistral", "llama"}:
-     num_attention_heads = self.normalized_config.num_attention_heads
- else:
      num_attention_heads = self.normalized_config.num_key_value_heads
+ else:
+     num_attention_heads = self.normalized_config.num_attention_heads
  embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

  dtype = constructor.float16 if self.use_fp16 else constructor.float32
Expand Down Expand Up @@ -282,9 +282,9 @@ def compute_past_key_values_output_shapes(
"""
  batch_size = input_ids.size(0)
  if self.normalized_config.config.model_type in {"mistral", "llama"}:
-     num_attention_heads = self.normalized_config.num_attention_heads
- else:
      num_attention_heads = self.normalized_config.num_key_value_heads
+ else:
+     num_attention_heads = self.normalized_config.num_attention_heads
  embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

  sequence_length = input_ids.size(1)
Expand Down

0 comments on commit 4df6252

Please sign in to comment.