Skip to content

Commit

Permalink
MixtralFlashAttention2: put "plus 1" inside parentheses when calculating rotary_seq_len, allowing None position_ids input. (#31500)
Browse files Browse the repository at this point in the history

* Mixtral: remove unnecessary plus 1 when calculating rotary_seq_len, allowing position_ids=None (no auto position_ids generation could be unsafe)

* fix typo [:-1] to [:, -1]

* to meet formatting requirement

* to meet formatting requirement

* remove white space

* MixtralFlashAttention2: put "+ 1" inside parentheses when calculating rotary_seq_len, allowing None position_ids input. Fix format/style issue.

* propagate to starcoder2, phi3, mixtral and qwen2

* update qwen2_moe
  • Loading branch information
xenshinu authored Aug 3, 2024
1 parent 7c31d05 commit 621fb3c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,10 @@ def forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,11 @@ def forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,10 @@ def forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,10 @@ def forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,10 @@ def forward(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand Down

0 comments on commit 621fb3c

Please sign in to comment.