Add YaRN and Dynamic-YaRN RoPE Scaling Methods (huggingface#30910)
* Add YaRN and Dynamic-YaRN RoPE Scaling Methods

YaRN (Yet another RoPE extension method) combines NTK-by-Parts interpolation
with attention scaling, improving on existing RoPE interpolation approaches
for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <[email protected]>

* Refactor YaRN implementation for LLaMA

Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <[email protected]>

* Refactor Tensor Building Logic for YaRN

- Comply with the tensor building logic introduced in huggingface#30743
- Add a reference to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <[email protected]>

* remove unwanted file

---------

Co-authored-by: Miguel Almeida <[email protected]>
Co-authored-by: mig-mfreitas <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
4 people authored and zucchini-nlp committed Jul 24, 2024
1 parent c1c7547 commit c51802e
Showing 11 changed files with 162 additions and 15 deletions.
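Before the per-file diffs, here is a minimal usage sketch of the configuration format this commit documents for the `yarn` strategy. The field values are illustrative placeholders, and, following the docstring added below, `max_position_embeddings` is left at the checkpoint's original value:

```python
from transformers import LlamaConfig

# Illustrative values only; any optional YaRN field may be omitted,
# in which case the defaults from the modeling code take over.
config = LlamaConfig(
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,                             # required, float > 1
        "original_max_position_embeddings": 4096,  # optional, int
        "beta_fast": 32.0,                         # optional, float
        "beta_slow": 1.0,                          # optional, float
    }
)
# Attention layers built from this config pick up LlamaYarnScalingRotaryEmbedding
# through `_init_rope` (see the modeling_llama.py changes below).
```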
1 change: 0 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -283,7 +283,6 @@ def __init__(self, config: FalconConfig):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = FalconRotaryEmbedding(
1 change: 0 additions & 1 deletion src/transformers/models/fuyu/configuration_fuyu.py
@@ -188,7 +188,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -154,7 +154,6 @@ def __init__(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
58 changes: 52 additions & 6 deletions src/transformers/models/llama/configuration_llama.py
@@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length the model was pretrained with. It is used to determine which RoPE dimensions are interpolated and which are extrapolated.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the scaling `factor` (the ratio between the target and original maximum sequence lengths).
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -178,15 +187,52 @@ def _rope_scaling_validation(self):
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

if rope_scaling_type != "yarn":
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)

if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
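As a quick numeric aid (not part of the diff), the default attention factor `0.1 ln(s) + 1` documented above grows only slowly with the scaling factor. The helper below is a sketch that simply mirrors that default:

```python
import math

def default_attention_factor(scaling_factor: float) -> float:
    # Mirrors the default applied when `attention_factor` is absent from `rope_scaling`.
    return 0.1 * math.log(scaling_factor) + 1.0

for s in (2.0, 4.0, 8.0, 16.0):
    print(f"factor={s:>4}: attention_factor ~= {default_attention_factor(s):.3f}")
# factor= 2.0: attention_factor ~= 1.069
# factor= 4.0: attention_factor ~= 1.139
# factor= 8.0: attention_factor ~= 1.208
# factor=16.0: attention_factor ~= 1.277
```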
88 changes: 88 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -132,6 +132,77 @@ def forward(self, x, position_ids):
return cos, sin


class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)

self.original_max_position_embeddings = original_max_position_embeddings
self.attention_factor = attention_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

if self.attention_factor is None:
# Recommended attention factor for LLaMA models.
# For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0

self.compute_yarn_scaling(device)

# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

# Find dimension range bounds based on rotations
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(self, min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

def forward(self, x, position_ids=None):
# Difference to the original RoPE: applies a scaling factor computed with
# the YaRN method (NTK-by-Parts + Attn Scaling)
# x: [bs, num_attention_heads, seq_len, head_size]
cos, sin = super().forward(x, position_ids)
cos = cos * self.mscale
sin = sin * self.mscale
return cos, sin

def compute_yarn_scaling(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)

low, high = self.find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

self.register_buffer("inv_freq", inv_freq)
# Get n-dimensional magnitude scaling corrected for interpolation
self.mscale = self.attention_factor


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -258,6 +329,15 @@ def _init_rope(self):
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
# Yarn parameters
kwargs = {
"dim": self.config.rope_scaling.get("original_max_position_embeddings", None),
"max_position_embeddings": self.config.rope_scaling.get("attention_factor", None),
"base": self.config.rope_scaling.get("beta_fast", None),
"scaling_factor": self.config.rope_scaling.get("beta_slow", None),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
@@ -272,6 +352,14 @@ def _init_rope(self):
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

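To make the NTK-by-parts step easier to follow in isolation, here is a standalone sketch (the helper name `yarn_inv_freq` is ours, not part of the commit) that restates what `compute_yarn_scaling` above does: low-frequency dimensions are interpolated, high-frequency dimensions are extrapolated, and a linear ramp blends the two regimes. It is a simplified re-derivation under the class defaults, not a drop-in replacement:

```python
import math
import torch

def yarn_inv_freq(
    dim: int,
    base: float = 10000.0,
    scaling_factor: float = 4.0,
    original_max_position_embeddings: int = 2048,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
) -> torch.Tensor:
    """Blend interpolated and extrapolated RoPE inverse frequencies (NTK-by-parts)."""
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    # Dimension index at which a given number of rotations fits in the original context
    # (inverse of the RoPE wavelength formula; see Eq. 17-18 of the YaRN paper).
    def correction_dim(num_rotations: float) -> float:
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)

    # Linear ramp over the half-dimension: 0 keeps extrapolation, 1 switches to interpolation.
    ramp = torch.clamp(
        (torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0.0, 1.0
    )
    inv_freq_mask = 1.0 - ramp
    return inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
```

The attention-scaling half of YaRN is then just the constant multiplier (`attention_factor`, defaulting to `0.1 ln(s) + 1`) applied to `cos` and `sin` in `forward` above.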
1 change: 0 additions & 1 deletion src/transformers/models/olmo/configuration_olmo.py
@@ -160,7 +160,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -236,7 +236,6 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = OlmoRotaryEmbedding(
@@ -138,7 +138,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/phi/configuration_phi.py
@@ -165,7 +165,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/stablelm/configuration_stablelm.py
@@ -164,7 +164,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
23 changes: 22 additions & 1 deletion tests/models/llama/test_modeling_llama.py
@@ -55,6 +55,7 @@
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)


@@ -397,7 +398,7 @@ def test_llama_token_classification_model(self):
def test_save_load_fast_init_from_base(self):
pass

@parameterized.expand([("linear",), ("dynamic",)])
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +492,26 @@ def test_model_rope_scaling(self):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())

# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
