Llama: device/type-invariant RoPE sin/cos computation, eager attention matches original implementation #28837
Conversation
```diff
@@ -505,6 +506,120 @@ def test_eager_matches_sdpa_generate(self):
        res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        self.assertTrue(torch.allclose(res_eager, res_sdpa))

+    @require_torch_gpu
+    def test_rope_cast_strategy_invariant(self):
```
This test fails on main, because inv_freq was being cast with .to().
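(The test body is not shown in this excerpt. As an illustration only — hypothetical checkpoint name and inputs, not the actual test from the diff — the kind of cast-strategy invariance being checked can be sketched as follows.)

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint, used only for illustration
model_id = "meta-llama/Llama-2-7b-hf"

# Strategy 1: cast while loading
model_a = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
# Strategy 2: load in float32, then cast the whole module
# (this path used to downcast the inv_freq buffer as well)
model_b = AutoModelForCausalLM.from_pretrained(model_id).to("cuda", dtype=torch.float16)

input_ids = torch.tensor([[1, 2, 3, 4, 5]], device="cuda")
with torch.no_grad():
    logits_a = model_a(input_ids).logits
    logits_b = model_b(input_ids).logits

# Before this PR the two casting strategies could diverge; after it they should match
torch.testing.assert_close(logits_a, logits_b)
```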
```diff
        )

+    @require_torch_gpu
+    def test_rope_initialization_invariant(self):
```
This test fails on main, as initialization is device-dependent there.
What does this PR do?
This PR fixes the following problems, all related to RoPE:

1. Casting the model with .from_pretrained(..., torch_dtype=...) or with .to(dtype=...) would produce different sin/cos tensors at recomputation time. The underlying cause was inv_freq being a buffer, which means it was subject to buffer manipulation (like a .to() operation in the wrapping module). Note that the original repo assumed it was always a torch.float32 tensor. In some models, there was a visible performance degradation when doing inference with seq_len > max_position_embeddings (see here);
2. The inv_freq tensor was being loaded from the state dict, due to a previous version of the code where it was a persistent buffer.

Overall, these changes result in:

a. Smaller modeling performance differences across devices, as CPUs are ubiquitous (as opposed to accelerators, which may change);
b. Prevention of loss spikes at train time, possibly due to the more accurate sin/cos computation (see this comment and the whole issue);
c. Slightly slower throughput when recomputing the sin/cos tensors, i.e. when going beyond self.max_seq_len_cached.

See additional data and experiments below for the impact of this PR. Most of the diff in this PR is tests, to ensure we don't regress 🤗
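(As an aside, not code from this PR: the buffer behavior behind item 1 above is easy to reproduce — a float buffer registered on a module is downcast by a module-level .to(dtype=...) call, so a cached fp32 inv_freq silently becomes fp16 when the wrapping model is cast.)

```python
import torch
import torch.nn as nn

class ToyRope(nn.Module):
    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        # inv_freq stored as a (non-persistent) buffer, mirroring the situation described above
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

module = ToyRope()
print(module.inv_freq.dtype)  # torch.float32

# Casting the wrapping module also casts the buffer...
module.to(dtype=torch.float16)
print(module.inv_freq.dtype)  # torch.float16

# ...so any sin/cos recomputed from it happens in half precision,
# while the original Llama code assumed float32 throughout.
```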
Suggested review order:
(Other RoPE models will follow in a future PR)
Related GH issues
Fixes #28685
Fixes #25681
Fixes #28596
Fixes #27179
Should fix/help microsoft/DeepSpeed#4932
Additional data and experiments
Perplexity, memory, and latency results before/after this PR
NOTE: using the .to() casting method. The torch_dtype casting method sees no differences, as inv_freq is not cast.

Llama 2 -- very little ppl differences
Dtype: bfloat16 (ignore the vram -- the latest commit has the same GPU memory footprint as main)

Dtype: float16 (ignore the vram -- the latest commit has the same GPU memory footprint as main)

TinyLlama -- visible ppl upgrade

Dtype: bfloat16 (ignore the vram -- the latest commit has the same GPU memory footprint as main)

Dtype: float16 (ignore the vram -- the latest commit has the same GPU memory footprint as main)

How sensitive is the sin/cos creation to the device placement?
Consider the following script:
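(The original script is collapsed in this view; the snippet below is a minimal stand-in that probes the same question by computing the sin/cos with the Llama-style recipe directly, rather than through the library internals.)

```python
import torch

TEST_DTYPE = torch.float32  # also try torch.float16 / torch.float64

def sincos(device, dtype, dim=128, seq_len=4096, base=10000.0):
    # Same recipe as the Llama rotary embedding: inv_freq -> outer product -> cos/sin
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim))
    t = torch.arange(seq_len, device=device, dtype=dtype)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

cos_cpu, sin_cpu = sincos("cpu", TEST_DTYPE)
cos_gpu, sin_gpu = sincos("cuda", TEST_DTYPE)

print("max |Δcos|:", (cos_cpu - cos_gpu.cpu()).abs().max().item())
print("max |Δsin|:", (sin_cpu - sin_gpu.cpu()).abs().max().item())
```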
On main, before this PR, we can see differences as large as ~1e-3, regardless of TEST_DTYPE (even in torch.float64!). After this PR, the difference is 0.0.

Original Llama codebase vs our codebase after this PR?
Key takeaways:
👉 sin/cos are created on the available device (and not on CPU)
👉 sin/cos are not only kept in FP32, but also applied in FP32!
Consider the following script, which compares Hugging Face's implementation against Meta's repo:
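(That script is not reproduced here; the sketch below only illustrates the idea: an HF-style sin/cos computed in the model dtype on GPU, checked against a float32 complex-exponential reference in the style of Meta's repo, re-implemented inline rather than imported, so the numbers are indicative only.)

```python
import torch

dim, seq_len, base = 128, 2048, 10000.0
device, dtype = "cuda", torch.float16

# Hugging Face-style sin/cos, computed in the model dtype on the device (as on main)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim))
t = torch.arange(seq_len, device=device, dtype=dtype)
freqs_hf = torch.outer(t, inv_freq)
cos_hf, sin_hf = freqs_hf.cos(), freqs_hf.sin()

# Meta-style reference: everything in float32, via complex exponentials (polar form)
inv_freq_ref = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
t_ref = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs_ref = torch.outer(t_ref, inv_freq_ref)
freqs_cis = torch.polar(torch.ones_like(freqs_ref), freqs_ref)  # cos + i*sin

print("max |Δcos|:", (cos_hf.float() - freqs_cis.real).abs().max().item())
print("max |Δsin|:", (sin_hf.float() - freqs_cis.imag).abs().max().item())
```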
On main + GPU + FP16, before this PR, we can see sin/cos and logits differences as large as 2e-4 and 6e-2 (respectively). After this PR, the difference is 0.0.