Commit
Remove redundant transposes for rope rotation (#807)
* ..

* ..

* Update llmfoundry/models/layers/attention.py

Co-authored-by: Daniel King <[email protected]>

* ..

---------

Co-authored-by: Daniel King <[email protected]>
ShashankMosaicML and dakinggg authored Dec 20, 2023
1 parent 2ba9224 commit 289536b
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions llmfoundry/models/layers/attention.py
@@ -9,6 +9,7 @@
 
 import torch
 import torch.nn as nn
+import transformers
 from einops import rearrange
 from packaging import version
 from torch import nn
@@ -34,6 +35,10 @@ def is_flash_v1_installed():
     return version.parse(flash_attn.__version__) < version.parse('2.0.0')
 
 
+def is_transformers_version_gte(hf_version: str) -> bool:
+    return version.parse(transformers.__version__) >= version.parse(hf_version)
+
+
 # Before importing any transformers models, we need to disable transformers flash attention if
 # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
 # gated import otherwise.
@@ -627,14 +632,20 @@ def forward(
                 value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
             elif rotary_emb_w_meta_info['impl'] == 'hf':
                 (cos, sin) = rotary_emb(value, seq_len)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
-                query, key = apply_rotary_pos_emb(query, key, cos, sin,
-                                                  offset_info)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
+                if is_transformers_version_gte('4.36'):
+                    query, key = apply_rotary_pos_emb(query,
+                                                      key,
+                                                      cos,
+                                                      sin,
+                                                      offset_info,
+                                                      unsqueeze_dim=2)
+                else:
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
+                    query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                      offset_info)
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
 
             query = query.view(bsz, seqlen, self.d_model)
             key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
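
For readers outside the diff context: this path keeps query/key as (batch, seq, heads, head_dim), while older transformers versions of apply_rotary_pos_emb assumed a (batch, heads, seq, head_dim) layout, hence the transpose round-trip. Transformers 4.36 exposes an unsqueeze_dim argument, so the rotation can be applied with the heads axis left where it is. The following is a minimal sketch, not the commit's code: apply_rope, the shapes, and the frequency construction are simplified stand-ins for the Hugging Face helpers, used only to check that the two paths agree.

# Sketch only: simplified stand-ins for the Hugging Face RoPE helpers, showing why
# rotating with unsqueeze_dim=2 on (b, s, h, d) tensors matches the old
# transpose -> rotate -> transpose round-trip. Shapes and names are made up.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) along the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin come in as (seq, head_dim); unsqueeze_dim says which axis of q/k
    # holds the heads, so broadcasting works for (b, h, s, d) or (b, s, h, d).
    cos = cos.unsqueeze(0).unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(0).unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


b, s, h, d = 2, 8, 4, 16
q, k = torch.randn(b, s, h, d), torch.randn(b, s, h, d)
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
freqs = torch.outer(torch.arange(s, dtype=torch.float32), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()  # (s, d)

# Old path: transpose to (b, h, s, d), rotate with heads on dim 1, transpose back.
q_old, k_old = apply_rope(q.transpose(1, 2), k.transpose(1, 2), cos, sin, unsqueeze_dim=1)
q_old, k_old = q_old.transpose(1, 2), k_old.transpose(1, 2)

# New path: rotate directly on (b, s, h, d) with cos/sin unsqueezed at the heads dim.
q_new, k_new = apply_rope(q, k, cos, sin, unsqueeze_dim=2)

assert torch.allclose(q_old, q_new) and torch.allclose(k_old, k_new)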
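The new is_transformers_version_gte helper leans on packaging.version rather than comparing version strings directly. A small illustrative check (the version strings below are arbitrary examples, not versions the repo pins):

from packaging import version

# version.parse orders releases numerically, so patch releases of 4.36 pass the
# gate while e.g. '4.9.1' (which a plain string comparison would rank above
# '4.36') correctly fails it.
for installed in ('4.35.2', '4.36.0', '4.36.2', '4.9.1'):
    print(installed, version.parse(installed) >= version.parse('4.36'))
# 4.35.2 False
# 4.36.0 True
# 4.36.2 True
# 4.9.1 False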
