Support rope scaling (#1391)
* support rope scaling

* use rope scaling

* update to use rope config

* update config args

* use allowlist for config to enforce hygiene

* allow llama3 rope config

* add unit test

* documented allowed llama config keys

* Update llmfoundry/models/mpt/modeling_mpt.py

* Address comments 1

* Apply suggestions from code review

Co-authored-by: Daniel King <[email protected]>

* Apply suggestions from code review

Co-authored-by: Daniel King <[email protected]>

* use same codepath for all the hf rotary embeddings

* fix

* update

* test WIP but fix get/pop

* change the thing being popped

* give up on testing hf

---------

Co-authored-by: Daniel King <[email protected]>
milocress and dakinggg authored Jul 24, 2024
1 parent af37ea0 commit 70586c4
Showing 5 changed files with 111 additions and 35 deletions.
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -306,6 +306,7 @@ def _validate_config(self) -> None:
'no_scaling',
'linear',
'dynamic',
'llama3',
]:
raise ValueError(
'If using hf implementation of rope, the type should be one of "no_scaling", "linear", "dynamic", or "llama3".',
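
For reference, the kind of rope configuration that now passes this check might look like the sketch below. This is not part of the commit; the rope_hf_config values mirror the unit test added in tests/models/test_rope_scaling.py further down, and the surrounding attn_config keys ('rope', 'rope_impl', 'rope_theta') are assumed from the existing MPT rope options.

# Sketch only (not from the diff): an attn_config fragment exercising the new
# 'llama3' rope type; values follow the unit test added in this commit.
attn_config = {
    'rope': True,
    'rope_impl': 'hf',
    'rope_theta': 500000.0,
    'rope_hf_config': {
        'type': 'llama3',
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
}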
98 changes: 67 additions & 31 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -49,10 +49,8 @@
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import \
LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
from transformers.models.llama.modeling_llama import \
LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
from transformers.models.llama.modeling_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

@@ -88,14 +86,62 @@
log = logging.getLogger(__name__)


class InvalidConfigAccessError(KeyError):
pass


_ALLOWED_LLAMA_CONFIG_KEYS = {
# These are the only config keys that are set and are safe to read
'rope_scaling',
'rope_theta',
'max_position_embeddings',
'hidden_size',
'num_attention_heads',

# Not set but llama modeling code tries to read this attribute
'partial_rotary_factor',

# Benign transformers attributes needed for __init__
'_get_generation_defaults',
'label2id',
'id2label',
'torch_dtype',
'problem_type',
'__class__',
}


class PartialLlamaConfig(LlamaConfig):
"""Holds the rope config for Llama models and throws.
an `InvalidConfigAccessError` if any other config elements are read. This
class is necessary because the `LlamaRotaryEmbedding` class takes a full
`LlamaConfig` now instead of the old keyword arguments.
"""

def __getattribute__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getattribute__(key)

def __getitem__(self, key: str):
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
raise InvalidConfigAccessError(key)

return super().__getitem__(key)


def gen_rotary_embedding(
rope_head_dim: int,
rope_impl: str,
rope_theta: int,
rope_dail_config: dict,
rope_hf_config: dict,
max_seq_len: int,
d_model: int,
n_heads: int,
):
rope_head_dim = d_model // n_heads
if rope_impl == 'dail':
return DAILRotaryEmbedding(
dim=rope_head_dim,
@@ -108,32 +154,21 @@ def gen_rotary_embedding(
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif rope_impl == 'hf':
llama_rope_config = {**rope_hf_config}
llama_rope_config['rope_type'] = llama_rope_config.pop('type')
if llama_rope_config['rope_type'] == 'no_scaling':
llama_rope_config['rope_type'] = 'default'
partial_llama_config = PartialLlamaConfig(
rope_scaling=llama_rope_config,
rope_theta=rope_theta,
max_position_embeddings=max_seq_len,
hidden_size=d_model,
num_attention_heads=n_heads,
)
if rope_hf_config['type'] == 'no_scaling':
return HFRotaryEmbeddingFoundry(
rope_head_dim,
max_position_embeddings=max_seq_len,
base=rope_theta,
device=
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif rope_hf_config['type'] == 'linear':
return HFLinearScalingRotaryEmbedding(
rope_head_dim,
max_position_embeddings=max_seq_len,
base=rope_theta,
scaling_factor=rope_hf_config['factor'],
device=
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif rope_hf_config['type'] == 'dynamic':
return HFDynamicNTKScalingRotaryEmbedding(
rope_head_dim,
max_position_embeddings=max_seq_len,
base=rope_theta,
scaling_factor=rope_hf_config['factor'],
device=
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
return HFRotaryEmbeddingFoundry(config=partial_llama_config)
elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}:
return LlamaRotaryEmbedding(config=partial_llama_config)
raise ValueError('rope_impl needs to be either dail or hf')
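
To illustrate how the allowlist above behaves at runtime, here is a minimal sketch (not part of the diff). It assumes `PartialLlamaConfig` and `InvalidConfigAccessError` are importable from llmfoundry.models.mpt.modeling_mpt as defined above, and reuses the llama3 rope values from the new unit test.

# Minimal sketch, not part of the commit: only allowlisted keys of
# PartialLlamaConfig can be read; any other access raises
# InvalidConfigAccessError.
from llmfoundry.models.mpt.modeling_mpt import (
    InvalidConfigAccessError,
    PartialLlamaConfig,
)

partial_config = PartialLlamaConfig(
    rope_scaling={
        'rope_type': 'llama3',
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
    rope_theta=500000.0,
    max_position_embeddings=65536,
    hidden_size=128,
    num_attention_heads=32,
)

print(partial_config.rope_theta)  # allowlisted, reads normally

try:
    _ = partial_config.vocab_size  # defined on LlamaConfig, but not allowlisted
except InvalidConfigAccessError as err:
    print(f'blocked config access: {err}')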


@@ -399,12 +434,13 @@ def __init__(self, config: MPTConfig):
if self.rope:
self.rope_impl = config.attn_config['rope_impl']
self.rotary_embedding = gen_rotary_embedding(
rope_head_dim=config.d_model // config.n_heads,
rope_impl=self.rope_impl,
rope_theta=config.attn_config['rope_theta'],
rope_dail_config=config.attn_config['rope_dail_config'],
rope_hf_config=config.attn_config['rope_hf_config'],
max_seq_len=self.config.max_seq_len,
d_model=config.d_model,
n_heads=config.n_heads,
)

if config.init_device != 'meta':
6 changes: 4 additions & 2 deletions tests/models/layers/test_flash_torch.py
@@ -251,12 +251,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=pos_emb_config['rope_impl'],
rope_theta=pos_emb_config['rope_theta'],
rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
max_seq_len=s,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
@@ -664,12 +665,13 @@ def gen_bias(attn_impl: str):
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
rope_head_dim=cfg['d_model'] // cfg['n_heads'],
rope_impl=pos_emb_config['rope_impl'],
rope_theta=pos_emb_config['rope_theta'],
rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
max_seq_len=s,
d_model=cfg['d_model'],
n_heads=cfg['n_heads'],
).to(device)
pos = torch.arange(s).unsqueeze(0).to(device=device)
# adjust the position indices to account for padding tokens
6 changes: 4 additions & 2 deletions tests/models/test_rope_dail_vs_hf.py
@@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

dail_rope = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=dail_rope_config['rope_impl'],
rope_theta=dail_rope_config['rope_theta'],
rope_dail_config=dail_rope_config['rope_dail_config'],
rope_hf_config={},
max_seq_len=seq_len,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to('cuda')
dail_rope_w_meta_info = {
'impl': 'dail',
@@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
}

hf_rope = gen_rotary_embedding(
rope_head_dim=cfg.d_model // cfg.n_heads,
rope_impl=hf_rope_config['rope_impl'],
rope_theta=hf_rope_config['rope_theta'],
rope_dail_config={},
rope_hf_config=hf_rope_config['rope_hf_config'],
max_seq_len=seq_len,
d_model=cfg.d_model,
n_heads=cfg.n_heads,
).to('cuda')
pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
# adjust the position indices to account for padding tokens
35 changes: 35 additions & 0 deletions tests/models/test_rope_scaling.py
@@ -0,0 +1,35 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rope_config = {
'rope_theta': 500000.0,
'rope_impl': 'hf',
'rope_hf_config': {
'factor': 8.0,
'low_freq_factor': 1.0,
'high_freq_factor': 4.0,
'original_max_position_embeddings': 8192,
'type': 'llama3',
},
}

rope_dail_config = {}


def test_rope_scaling():
d_model = 128
n_heads = 32
max_seq_len = 65536

embedding = gen_rotary_embedding(
d_model=d_model,
n_heads=n_heads,
rope_dail_config=rope_dail_config,
max_seq_len=max_seq_len,
**rope_config,
)

assert isinstance(embedding, LlamaRotaryEmbedding)

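
As a further usage sketch (not part of the commit), the embedding returned by gen_rotary_embedding can be exercised directly. This assumes the forward(x, position_ids) signature that transformers' LlamaRotaryEmbedding exposes in the transformers version this commit targets, where the call returns the cos/sin tables.

# Hypothetical follow-on to the test above, not part of the commit. Assumes
# LlamaRotaryEmbedding.forward(x, position_ids) -> (cos, sin), as in the
# transformers version pinned at the time of this commit.
import torch


def demo_rope_scaling_forward():
    d_model, n_heads, seq_len = 128, 32, 16
    embedding = gen_rotary_embedding(
        d_model=d_model,
        n_heads=n_heads,
        rope_dail_config=rope_dail_config,
        max_seq_len=65536,
        **rope_config,
    )
    x = torch.randn(1, seq_len, d_model)      # forward only uses x for dtype/device
    pos = torch.arange(seq_len).unsqueeze(0)  # position ids, shape (1, seq_len)
    cos, sin = embedding(x, pos)
    # Under the assumed signature, cos/sin have shape (1, seq_len, d_model // n_heads)
    print(cos.shape, sin.shape)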