diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index bee0269ff82e9a..61315e18d41245 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -324,18 +324,17 @@ def _compute_llama3_parameters( low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in inv_freq: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - new_freqs.append((1 - smooth) * freq / factor + smooth * freq) - inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new) + return inv_freq, attention_factor @@ -501,7 +500,7 @@ def _validate_llama3_parameters(config: PretrainedConfig): logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") if high_freq_factor is None or not isinstance(high_freq_factor, float): logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") - if high_freq_factor < low_freq_factor: + if high_freq_factor <= low_freq_factor: logger.warning( "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" f"{high_freq_factor} and low_freq_factor={low_freq_factor}"