[BUG] Considerable decrease in policy performance after PyTorch 2.5.0 update #228
Hi @alopezrivera, thanks for reporting this bug, as it may be critical.
Hey @alopezrivera, I could not reproduce your results on my end. I am running the quickstart notebook with both the previous version and the newer one, and I get the same result.

PS: note that due to several recent updates to PyTorch, there may be some incompatibilities. I installed the nightly versions with the following command: `pip3 install torchrl-nightly`
@alopezrivera the most recent TorchRL and TensorDict versions have been released, and I still cannot reproduce the bug. When you have time, please let us know how to reproduce the result!
In that case it is highly likely that this has to do with either numerical precision settings or some operation conducted inside my space vehicle routing environment. I'll look into those two possibilities ASAP and try to construct a minimal environment that reproduces the issue.
Good. You may also try testing it on another system. For example, if the problem is reproducible in Google Colab, then it is likely an important issue in RL4CO; otherwise it might be outside of our control.
So far I was able to reproduce this behavior on the four systems below (all on RunPod), which makes me think the environment is the culprit here.
**Important:** see possible fixes at the end of this message!

I tried manually updating all dependencies, and it turns out that, on a 3090, the bug may be reproducible. In my case, the loss explodes during Epoch 2 of this notebook 🤔 Here is my current env (note that you may install the dev version):
**Possible fixes**

So far I have managed to fix the bug (temporarily) in the following ways:

1. **Change the precision.** If precision is increased to "32", then there is no more bug (i.e., pass `precision="32"` to the trainer).
2. **Install a previous PyTorch version (<2.4.0).** If you install an earlier release (say, 2.3.x with pip), the bug does not appear.
3. **Use another GPU.** Also suboptimal, as it seems newer GPUs are affected. Training in Google Colab did not yield this result.

I strongly suspect it is due to precision issues and to the SDPA implementation used on certain devices and precision settings. The above are still temporary fixes, and I am not 100% sure why they happen, as the logic did not change - so the problem should be outside of RL4CO, but somehow we may need to adapt. Most likely it is due to a precision error - why that is, that's the question. Precision is handled by PyTorch Lightning, so one option is to check for updates on their side. Another option is to dig deep into SDPA and see whether the problem persists when changing it to the manual implementation / the FlashAttention repo.

CC: @cbhua

**UPDATE 1**

If the training call is wrapped so that SDPA is forced to use the math backend:

```python
with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
    trainer.fit(model)
```

this "simple trick" seems to work well for me (at times only though - pretty weird), indicating that the SDPA / precision-setting direction was indeed correct.

**UPDATE 2 (narrowed down the issue and possible fix!)**

I think we finally found the main culprit. The issue appears to be in the SDPA implementation used in the decoder.
The easiest fix for affected devices is to simply change:

```diff
- self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention
+ self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention_simple
```

i.e., replacing the SDPA implementation with our own. This appears to solve the issue without changing too much code! Minor note: for single-query attention (as in AM), this appears to speed up performance a little, while for multi-query attention (such as multi-start, e.g., in POMO), it seems to be slightly slower.

**Reasons why this happens**

There can be a few. What I suspect is that there is some wrong conversion in PyTorch between fp32 -> fp16 and vice versa.

**Reproducing the problem**

The snippet below should reproduce the issue in your case. Normally, the two implementations should be exactly the same:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from rl4co.models.nn.attention import scaled_dot_product_attention_simple
# Make some random data
bs, n, d = 32, 100, 128
q = torch.rand(bs, n, d)
k = torch.rand(bs, n, d)
v = torch.rand(bs, n, d)
attn_mask = torch.rand(bs, n, n) < 0.1
# to float16
q = q.half().cuda()
k = k.half().cuda()
v = v.half().cuda()
attn_mask = attn_mask.cuda()
# Run the two implementations
with torch.amp.autocast("cuda"):
    out_pytorch = scaled_dot_product_attention(q, k, v, attn_mask)
    out_ours = scaled_dot_product_attention_simple(q, k, v, attn_mask)
# If the two outputs are not close, raise an error with the maximum difference
if not torch.allclose(out_ours, out_pytorch, atol=1e-3):
    raise ValueError(f"Outputs are not close. Max diff: {torch.max(torch.abs(out_ours - out_pytorch))}")
```
@alopezrivera in the new RL4CO version, you may replace the decoder's scaled dot product attention (SDPA) implementation by passing:

```python
policy = AttentionModelPolicy(env_name=env.name, sdpa_fn_decoder="simple")
```
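For a fuller picture, a minimal setup along these lines should work (the imports and constructor arguments below follow the quickstart notebook and are assumptions; adjust them to your installed RL4CO version):

```python
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel, AttentionModelPolicy

# NOTE: import paths and constructor arguments are assumptions based on the
# quickstart notebook; they may differ slightly across RL4CO versions.
env = TSPEnv(generator_params={"num_loc": 50})

# Use the "simple" (manual) SDPA in the decoder instead of PyTorch's fused SDPA
policy = AttentionModelPolicy(env_name=env.name, sdpa_fn_decoder="simple")

# The policy is then passed to the model as usual
model = AttentionModel(env, policy=policy, baseline="rollout")
```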
Describe the bug
I have observed a considerable decrease in policy performance after the recent PyTorch 2.5.0 update. The performance decrease reproduces when training with A2C, REINFORCE, and PPO.
Training curves, before: brown; after: purple. Same environment model, same random seeds.
To Reproduce
Install RL4CO and other dependencies using the following Conda `environment.yaml`:

**Previous result when creating the environment**

Approximately 3 days ago this would've installed the following dependencies:

This environment can be replicated with the following `environment.yaml`, where `requirements.txt` must be stored in the same directory as `environment.yaml` and contain:

**Current result when creating the environment**
As of today it installs the following dependencies, including PyTorch 2.5.0:
Detailed list of dependencies
The following is a detailed list of all dependencies that differ between the environment created 3 days ago and the current one. I believe PyTorch 2.5.0 is the main culprit here.
System info
NVIDIA L40 system
pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
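The exact versions and enabled SDPA backends on an affected machine can be checked with something like:

```python
import torch

# Versions and device involved in the regression
print("torch:", torch.__version__)
print("CUDA:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")

# Which fused SDPA backends are currently enabled
print("flash SDP:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient SDP:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math SDP:", torch.backends.cuda.math_sdp_enabled())
```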
Reason and Possible fixes
No idea as to the reason. A temporary fix could be to lock the PyTorch version required by RL4CO to PyTorch 2.4.1.
Checklist