Skip to content

Commit

Permalink
Llama 3.1: replace for loop by tensor ops at inv_freq initialization (h…
Browse files Browse the repository at this point in the history
…uggingface#32244)

* replace for loop by tensor ops

* rm assert; readability
  • Loading branch information
gante authored Jul 27, 2024
1 parent 8da9068 commit 44f6fdd
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,17 @@ def _compute_llama3_parameters(

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)

wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_new = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_new / factor + smooth_factor * inv_freq_new
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_new = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_new)

return inv_freq, attention_factor


Expand Down Expand Up @@ -501,7 +500,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
if high_freq_factor <= low_freq_factor:
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
Expand Down

0 comments on commit 44f6fdd

Please sign in to comment.