From 34b43211d782c00da6fef778dbfaff69bbf3f115 Mon Sep 17 00:00:00 2001 From: mig-mfreitas <132093787+mig-mfreitas@users.noreply.github.com> Date: Tue, 23 Jul 2024 10:07:58 +0100 Subject: [PATCH] Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910) * Add YaRN and Dynamic-YaRN RoPE Scaling Methods YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments. We implement YaRN and Dynamic-YaRN for the following list of models: - LLaMA - Falcon - GPT-NeoX - Olmo - Persimmon - Phi - StableLM - OpenLLaMA New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071. Co-authored-by: Miguel Almeida * Refactor YaRN implementation for LLaMA Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. This commit includes the following changes: - Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries - Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes - Inherit 'forward' method in YaRN classes from superclass - Rename 'yarn' method to 'compute_yarn_scaling' - Extend YaRN tests with further assertions - Fix style inconsistencies Co-authored-by: Miguel Monte e Freitas * Refactor Tensor Building Logic for YaRN - Comply with the the tensor building logic introduced in #30743 - Add referencing to the optimized Attention Factor equation - Remove Dynamic YaRN for a more agile deployment Co-authored-by: mig-mfreitas * remove unwanted file --------- Co-authored-by: Miguel Almeida Co-authored-by: mig-mfreitas Co-authored-by: Joao Gante --- .../models/falcon/modeling_falcon.py | 1 - .../models/fuyu/configuration_fuyu.py | 1 - .../models/gpt_neox/configuration_gpt_neox.py | 1 - .../models/llama/configuration_llama.py | 58 ++++++++++-- .../models/llama/modeling_llama.py | 88 +++++++++++++++++++ .../models/olmo/configuration_olmo.py | 1 - src/transformers/models/olmo/modeling_olmo.py | 1 - .../persimmon/configuration_persimmon.py | 1 - .../models/phi/configuration_phi.py | 1 - .../models/stablelm/configuration_stablelm.py | 1 - tests/models/llama/test_modeling_llama.py | 23 ++++- 11 files changed, 162 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d1050d542a2f38..663582c8a72a83 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -283,7 +283,6 @@ def __init__(self, config: FalconConfig): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = FalconRotaryEmbedding( diff --git a/src/transformers/models/fuyu/configuration_fuyu.py b/src/transformers/models/fuyu/configuration_fuyu.py index ffcdd2b61750a6..03d2aecc02b6c9 100644 --- a/src/transformers/models/fuyu/configuration_fuyu.py +++ b/src/transformers/models/fuyu/configuration_fuyu.py @@ -188,7 +188,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 8e4c94692e0537..944dbb5e02f098 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -154,7 +154,6 @@ def __init__( "The hidden size is not divisble by the number of attention heads! Make sure to update them!" ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 1a059101e42492..843731eeffc8ee 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. + For the `yarn` strategy, the dictionary may also contain the following fields: + `original_max_position_embeddings` (`int`, *optional*): + The original maximum sequence length. This is used to scale the RoPE embeddings. + `attention_factor` (`float`, *optional*): + The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `original_max_position_embeddings/max_position_embeddings` ratio. + `beta_fast` (`float`, *optional*): + Parameter to set the boundary for extrapolation (only) in the linear ramp function. + `beta_slow` (`float`, *optional*): + Parameter to set the boundary for interpolation (only) in the linear ramp function. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -178,15 +187,52 @@ def _rope_scaling_validation(self): if self.rope_scaling is None: return - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2: raise ValueError( - "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + "`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, " + f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]: raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + if rope_scaling_type != "yarn": + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6: + raise ValueError( + "`rope_scaling` with type " + f"{rope_scaling_type}" + " must be a dictionary with a maximum of six fields, `type`, `factor`," + "`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, " + f"got {self.rope_scaling}" + ) + original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) + attention_factor = self.rope_scaling.get("attention_factor", None) + beta_fast = self.rope_scaling.get("beta_fast", None) + beta_slow = self.rope_scaling.get("beta_slow", None) + + if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): + raise ValueError( + f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}" + ) + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + b_fast = beta_fast if beta_fast is not None else 32 + b_slow = beta_slow if beta_slow is not None else 1 + if b_fast < b_slow: + raise ValueError( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}" + ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5c0c57f3effe86..b624a2d92d0970 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -132,6 +132,77 @@ def forward(self, x, position_ids): return cos, sin +class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1, + original_max_position_embeddings=2048, + attention_factor=None, + beta_fast=32, + beta_slow=1, + device=None, + ): + super().__init__(dim, max_position_embeddings, base, device, scaling_factor) + + self.original_max_position_embeddings = original_max_position_embeddings + self.attention_factor = attention_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + + if self.attention_factor is None: + # Recommended attention factor for LLaMA models. + # For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22. + self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0 + + self.compute_yarn_scaling(device) + + # Inverse dimension formula to find the dimension based on the number of rotations + def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dimension range bounds based on rotations + def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_mask(self, min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def forward(self, x, position_ids=None): + # Difference to the original RoPE: applies a scaling factor computed with + # the YaRN method (NTK-by-Parts + Attn Scaling) + # x: [bs, num_attention_heads, seq_len, head_size] + cos, sin = super().forward(x, position_ids) + cos = cos * self.mscale + sin = sin * self.mscale + return cos, sin + + def compute_yarn_scaling(self, device): + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs) + + low, high = self.find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device) + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.register_buffer("inv_freq", inv_freq) + # Get n-dimensional magnitude scaling corrected for interpolation + self.mscale = self.attention_factor + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -258,6 +329,15 @@ def _init_rope(self): else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] + # Yarn parameters + kwargs = { + "dim": self.config.rope_scaling.get("original_max_position_embeddings", None), + "max_position_embeddings": self.config.rope_scaling.get("attention_factor", None), + "base": self.config.rope_scaling.get("beta_fast", None), + "scaling_factor": self.config.rope_scaling.get("beta_slow", None), + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, @@ -272,6 +352,14 @@ def _init_rope(self): scaling_factor=scaling_factor, base=self.rope_theta, ) + elif scaling_type == "yarn": + self.rotary_emb = LlamaYarnScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index a25ccd8cc09def..77a3b18e364ecf 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -160,7 +160,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 4fd0c92686834b..59c9b3bf1b66a4 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -236,7 +236,6 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self._init_rope() - # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = OlmoRotaryEmbedding( diff --git a/src/transformers/models/persimmon/configuration_persimmon.py b/src/transformers/models/persimmon/configuration_persimmon.py index b8e02256de808a..11f4c66d73e6b3 100644 --- a/src/transformers/models/persimmon/configuration_persimmon.py +++ b/src/transformers/models/persimmon/configuration_persimmon.py @@ -138,7 +138,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index d1e3464ee48271..e54d400ae6e72e 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -165,7 +165,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/stablelm/configuration_stablelm.py b/src/transformers/models/stablelm/configuration_stablelm.py index abea7483a67de6..c05ac9f036d62b 100644 --- a/src/transformers/models/stablelm/configuration_stablelm.py +++ b/src/transformers/models/stablelm/configuration_stablelm.py @@ -164,7 +164,6 @@ def __init__( **kwargs, ) - # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index e0311b7cea4a0e..de7eb7e44156c1 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -55,6 +55,7 @@ LlamaDynamicNTKScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding, + LlamaYarnScalingRotaryEmbedding, ) @@ -397,7 +398,7 @@ def test_llama_token_classification_model(self): def test_save_load_fast_init_from_base(self): pass - @parameterized.expand([("linear",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -491,6 +492,26 @@ def test_model_rope_scaling(self): torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + # Sanity check Yarn RoPE scaling + yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + @require_flash_attn @require_torch_gpu @require_bitsandbytes