Adding a child class of hf's rotary embedding to make hf generate work on multiple gpus. (#1334)

* ..

* adding comment

* improving test

* lint

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Daniel King <[email protected]>

* addressing comments

---------

Co-authored-by: Daniel King <[email protected]>
ShashankMosaicML and dakinggg authored Jul 3, 2024
1 parent 1641b21 commit 73f267c
Showing 2 changed files with 49 additions and 1 deletion.
16 changes: 15 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -109,7 +109,7 @@ def gen_rotary_embedding(
        )
    elif rope_impl == 'hf':
        if rope_hf_config['type'] == 'no_scaling':
-            return HFRotaryEmbedding(
+            return HFRotaryEmbeddingFoundry(
                rope_head_dim,
                max_position_embeddings=max_seq_len,
                base=rope_theta,
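
Not part of the diff: a minimal sketch of how the branch above is exercised. Only the names visible in the hunk (rope_head_dim, rope_impl, rope_theta, rope_hf_config, max_seq_len) are confirmed by the diff; the rope_dail_config parameter and the exact signature of gen_rotary_embedding are assumptions.

# Hedged sketch (not from the PR): building the rotary embedding through the
# 'hf' / 'no_scaling' path, which now returns HFRotaryEmbeddingFoundry instead
# of the stock HFRotaryEmbedding.
from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rotary_emb = gen_rotary_embedding(
    rope_head_dim=64,
    rope_impl='hf',
    rope_theta=10000,
    rope_dail_config={},                    # assumed parameter, unused on this path
    rope_hf_config={'type': 'no_scaling'},
    max_seq_len=2048,
)
print(type(rotary_emb).__name__)            # HFRotaryEmbeddingFoundry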
@@ -306,6 +306,20 @@ def apply_sequence_id(
    return attn_bias


class HFRotaryEmbeddingFoundry(HFRotaryEmbedding):

    @torch.no_grad()
    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # In this subclass, we move `inv_freq` to the same device as `position_ids`. This operation should be a no-op during training.
        # This is done to fix pipeline-parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1334#issue-2387337525
        self.inv_freq = self.inv_freq.to(position_ids.device)
        return super().forward(x=x, position_ids=position_ids)


class MPTPreTrainedModel(PreTrainedModel):
    config_class = MPTConfig
    base_model_prefix = 'model'
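Not part of the diff: a minimal runnable sketch of what the override buys during pipeline-parallel generation. It assumes at least two CUDA devices; on a single device the extra .to() is a no-op, which is why training is unaffected.

# Hedged illustration, assuming two GPUs: with a sharded model, hf.generate can
# hand the rotary module position_ids that live on a later pipeline stage while
# its inv_freq buffer was created on an earlier one. The subclass moves inv_freq
# to position_ids.device before delegating to the parent forward.
import torch
from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry

rope = HFRotaryEmbeddingFoundry(
    64,                              # rope_head_dim
    max_position_embeddings=2048,
    base=10000,
    device='cuda:0',                 # buffer starts on the first GPU
)
x = torch.rand(1, 4, 16, 64, device='cuda:1')                  # activations on a later stage
position_ids = torch.arange(16, device='cuda:1').unsqueeze(0)

cos, sin = rope(x, position_ids)     # no cross-device matmul error
assert rope.inv_freq.device == position_ids.device
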
34 changes: 34 additions & 0 deletions tests/models/test_model.py
@@ -35,6 +35,8 @@
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from transformers.models.llama.modeling_llama import \
    LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry import ComposerHFCausalLM
from llmfoundry.layers_registry import norms
@@ -46,6 +48,7 @@
)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel
from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry
from llmfoundry.utils import build_tokenizer
from llmfoundry.utils.builders import build_composer_model
from llmfoundry.utils.config_utils import to_dict_container
@@ -2908,3 +2911,34 @@ def _validate_helper(b_idx: int) -> int:
        'The relative index of kv layer to reuse, override_attn_config\[\"reuse_kv_layer_idx\"\]=0, should be negative\.', # type: ignore
    ):
        _validate_helper(b_idx=2)


def test_hf_rotary_child_class_builds():
    rope_head_dim = 32
    num_heads = 4
    max_seq_len = 128
    rope_theta = 10000
    bsz = 4
    value = torch.rand([bsz, num_heads, max_seq_len, rope_head_dim])
    position_ids = torch.Tensor([
        list(range(max_seq_len)),
    ] * bsz)

    rot_emb_mp = HFRotaryEmbeddingFoundry(
        rope_head_dim,
        max_seq_len,
        rope_theta,
        device='cpu',
    )
    cos_mp, sin_mp = rot_emb_mp(value, position_ids)

    rot_emb = HFRotaryEmbedding(
        rope_head_dim,
        max_seq_len,
        rope_theta,
        device='cpu',
    )
    cos, sin = rot_emb(value, position_ids)

    assert torch.all(cos == cos_mp)
    assert torch.all(sin == sin_mp)
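
For reference, the added test needs only CPU (both embeddings are built with device='cpu') and can be run on its own with pytest tests/models/test_model.py::test_hf_rotary_child_class_builds.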
