KTO training produces NaN rewards #1447
Comments
cc also @kashif
@claralp depending on the batch size it could be that some of the metrics are nan; this should not affect the training etc., and special attention has been paid to make sure the loss etc. is robust to these nans when doing back-prop.
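To illustrate the point above, here is a hypothetical sketch (not the actual TRL implementation) of why a nan in the logged metrics need not break training: if a batch holds no "chosen" examples, the mean chosen reward is nan, but a nan-robust aggregate only sums over the examples that actually exist, so back-prop never sees the nan.

```python
import math

# Toy rewards for one batch; the chosen/rejected split is illustrative.
rewards = [0.3, -0.1, 0.7]
is_chosen = [False, False, False]   # no chosen examples landed in this batch

chosen = [r for r, c in zip(rewards, is_chosen) if c]

# Naive mean of an empty slice -> nan (this is what shows up in the logs).
metric = sum(chosen) / len(chosen) if chosen else math.nan

# A nan-robust aggregate skips the empty slice instead of dividing by zero.
safe_metric = sum(chosen) / max(len(chosen), 1)

print(math.isnan(metric), safe_metric)
```

So a nan "rewards/chosen" entry in the log can simply mean the batch contained no desired completions.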
@claralp I do not think nans in a dict should cause this to crash... do you have some crash back-traces?
@kashif there are no errors or warnings in the stdout/stderr, it just stops at some point after the nan rewards appear, so I cannot provide a stack trace here.
lifecycler log shows only a Preemption signal:
I think this could be the "normal" low-priority Azure preemption? :-(
Important note here: The crash only appears after the training shows nan values. Otherwise it doesn't.
Could there be anything wrong with the hyperparameter choice, @kashif ?
@claralp so the main hyperparameter that could affect this is the batch size, as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny, so that should be fine... what is your batch size when you get all nans? Also, does this happen if you run locally, outside of Azure?
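A quick back-of-the-envelope check of the "good mix per batch" point: assuming batches are drawn roughly uniformly from data where a fraction p of examples is desired, the chance that a batch of size b contains zero desired examples is about (1 - p) ** b. The function below is purely illustrative, not part of TRL.

```python
def prob_one_sided(p: float, batch_size: int) -> float:
    """Approximate probability a uniformly drawn batch misses the desired class."""
    return (1.0 - p) ** batch_size

# With 2k desired vs 10k undesired completions, p = 2000 / 12000.
p = 2000 / 12000
print(round(prob_one_sided(p, 8), 3))
```

With these numbers roughly one batch in four would contain no desired completion at all, so a nan mean chosen reward would be expected quite often.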
The output below is from a test with very unbalanced data, namely 2k desired completions and 10k undesired ones.
batch size is 8 and gradient accumulation steps is 2, as in the config above
currently checking this |
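The imbalance above can also be checked empirically. The simulation below uses the numbers from this thread (2k desired, 10k undesired, batch size 8) to count how many shuffled batches contain no desired completion; those are the batches that would log a nan mean chosen reward. It is an illustrative sketch, not the KTOTrainer's actual sampler.

```python
import random

random.seed(0)  # fixed seed so the simulation is reproducible

# 1 = desired completion, 0 = undesired, matching the 2k/10k split above.
labels = [1] * 2000 + [0] * 10000
random.shuffle(labels)

batch_size = 8
batches = [labels[i:i + batch_size] for i in range(0, len(labels), batch_size)]
no_desired = sum(1 for b in batches if sum(b) == 0)
print(f"{no_desired} of {len(batches)} batches have no desired completion")
```

This matches the analytic estimate of roughly a quarter of batches being one-sided at this level of imbalance.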
During training with the KTO Trainer I occasionally get nan values as rewards. I am running the training as a job on MS Azure with one GPU (NVIDIA A100 80GB PCIe).
Ultimately these issues cause my Azure job to crash and retry...
The log output I get from the KTOTrainer:
my pip freeze:
the training script I use:
the call arguments
Maybe @lewtun can help