Removing the extra LlamaRotaryEmbedding import #1394

Merged
llmfoundry/models/mpt/modeling_mpt.py (12 changes: 6 additions & 6 deletions)
@@ -49,10 +49,10 @@
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
+from transformers.models.llama.modeling_llama import (
+    LlamaConfig,
+    LlamaRotaryEmbedding,
+)
 
 from llmfoundry.layers_registry import norms, param_init_fns
 from llmfoundry.models.layers.attention import (
@@ -166,7 +166,7 @@ def gen_rotary_embedding(
         num_attention_heads=n_heads,
     )
     if rope_hf_config['type'] == 'no_scaling':
-        return HFRotaryEmbeddingFoundry(config=partial_llama_config)
+        return LlamaRotaryEmbeddingFoundry(config=partial_llama_config)
     elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}:
         return LlamaRotaryEmbedding(config=partial_llama_config)
     raise ValueError('rope_impl needs to be either dail or hf')
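For context, the branch touched here hands a partial LlamaConfig to the rotary-embedding classes and relies on the config-based constructor that recent transformers releases expose for LlamaRotaryEmbedding. A minimal sketch of that construction-and-call pattern; the concrete config values and tensor shapes below are illustrative assumptions, not the ones llmfoundry builds:

    # Sketch only: config-driven construction of LlamaRotaryEmbedding, as used
    # by the dispatch above. All concrete values here are illustrative.
    import torch
    from transformers.models.llama.modeling_llama import (
        LlamaConfig,
        LlamaRotaryEmbedding,
    )

    partial_llama_config = LlamaConfig(
        hidden_size=512,
        num_attention_heads=8,      # the diff shows num_attention_heads=n_heads
        max_position_embeddings=2048,
    )
    rope = LlamaRotaryEmbedding(config=partial_llama_config)

    x = torch.zeros(1, 16, 512)                   # used only for dtype/device
    position_ids = torch.arange(16).unsqueeze(0)  # (batch, seq_len)
    cos, sin = rope(x, position_ids)              # rotary tables per position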
@@ -341,7 +341,7 @@ def apply_sequence_id(
     return attn_bias
 
 
-class HFRotaryEmbeddingFoundry(HFRotaryEmbedding):
+class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding):
 
     @torch.no_grad()
     def forward(
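The rename above keeps the class a thin subclass of LlamaRotaryEmbedding with a forward override under @torch.no_grad(); the override body is elided in this diff. A generic sketch of that subclassing pattern, with a placeholder body rather than the PR's actual override:

    # Sketch of the subclass pattern shown above. The real override body is
    # not visible in this diff; delegating to the parent is a placeholder.
    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

    class RotaryEmbeddingFoundrySketch(LlamaRotaryEmbedding):

        @torch.no_grad()
        def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
            # Placeholder behavior: reuse the parent computation unchanged.
            return super().forward(x, position_ids)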
tests/models/test_model.py (9 changes: 4 additions & 5 deletions)
@@ -35,8 +35,7 @@
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 
 from llmfoundry import ComposerHFCausalLM
 from llmfoundry.layers_registry import norms
@@ -48,7 +47,7 @@
 )
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel
-from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry
+from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
 from llmfoundry.utils.config_utils import to_dict_container
@@ -2924,15 +2923,15 @@ def test_hf_rotary_child_class_builds():
         list(range(max_seq_len)),
     ] * bsz)
 
-    rot_emb_mp = HFRotaryEmbeddingFoundry(
+    rot_emb_mp = LlamaRotaryEmbeddingFoundry(
         rope_head_dim,
         max_seq_len,
         rope_theta,
         device='cpu',
     )
     cos_mp, sin_mp = rot_emb_mp(value, position_ids)
 
-    rot_emb = HFRotaryEmbedding(
+    rot_emb = LlamaRotaryEmbedding(
         rope_head_dim,
         max_seq_len,
         rope_theta,
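The remainder of test_hf_rotary_child_class_builds is truncated in this diff; presumably it compares the cos/sin tables produced by the subclass and the stock class. A hedged sketch of such a comparison, using the config-based constructor and names invented here purely for illustration:

    # Sketch only: a generic cos/sin comparison between a subclass and its
    # parent. Shapes, names, and assertions are illustrative assumptions.
    import torch
    from transformers.models.llama.modeling_llama import (
        LlamaConfig,
        LlamaRotaryEmbedding,
    )

    config = LlamaConfig(hidden_size=256, num_attention_heads=4)  # head dim 64
    parent = LlamaRotaryEmbedding(config=config)

    class ChildRotary(LlamaRotaryEmbedding):
        pass                                # stands in for the Foundry subclass

    child = ChildRotary(config=config)

    value = torch.zeros(2, 8, 256)          # used only for dtype/device
    position_ids = torch.arange(8).repeat(2, 1)

    cos_p, sin_p = parent(value, position_ids)
    cos_c, sin_c = child(value, position_ids)

    assert torch.allclose(cos_p, cos_c)
    assert torch.allclose(sin_p, sin_c)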