-
Notifications
You must be signed in to change notification settings - Fork 471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add support for DPO #556
base: main
Are you sure you want to change the base?
Conversation
9c57623
to
f57ae81
Compare
f57ae81
to
deb71c1
Compare
The WANDB job : https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch (updated the link to point to a run with a larger batch size) |
@PhungVanDuy @maxreciprocate I cannot seem to request a review (probably due to permission issues / first contribution reasons). Could you please advise? |
Thank you so much for the great PR. We will review this PR asap! |
Thank you so much @PhungVanDuy for reviewing 🙏 Yes it's the same wandb run i shared above. Here you go : https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch |
from_fn = AutoModelForCausalLM.from_pretrained | ||
if issubclass(type(config.model.model_path), PretrainedConfig): | ||
from_fn = AutoModelForCausalLM.from_config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AutoModelForSeq2SeqLM
support is missing here
Any update? |
@PhungVanDuy were you able to review this? |
I saw your wandb but actually the chart quite messup, seems reward/accuracies and reward/margin not clearly increase. I guess because you used gpt2 instead of an SFT model on HH to do DPO. Can you use this SFT model and this preference dataset to train with this branch? |
That's indeed what I did for a quick iteration and because I was limited on the compute i had. I will run it with the mistral-7b on the ultrafeedback dataset and get back asap. |
6404f83
to
506fbbd
Compare
506fbbd
to
6d63004
Compare
@PhungVanDuy sorry for the delay, the gpus aren't always available. Here is a dpo run (ongoing) of 1 epoch with Note :
|
Thank you for your information, I will use SFT-beta, to check this. Let me help you to run on my cluster. |
@sandeepchittilla can you add my discord with the handle: |
I'm excited about DPO support and I hope it'll be added soon! |
Closes #504
This PR adds Direct Policy Optimization as introduced in https://arxiv.org/abs/2305.18290
Loss calculation and concatenated forward pass implementations are adapted from the original TRL library