feat: Add support for DPO #556

Open · wants to merge 11 commits from 504-dpo-trainer into main
Conversation

@sandeepchittilla commented Sep 7, 2023

Closes #504

This PR adds Direct Preference Optimization (DPO) as introduced in https://arxiv.org/abs/2305.18290.

The loss calculation and concatenated forward pass implementations are adapted from the original TRL library.
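For context, the DPO objective reduces to a logistic loss on the gap between the policy and reference log-ratios over (chosen, rejected) pairs, and the implicit rewards it defines are what drive the reward/accuracies and reward/margin metrics discussed later in this thread. The sketch below follows the TRL-style formulation for illustration only; it is not the exact code added by this PR, and the function name and beta default are placeholders.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Illustrative DPO loss; inputs are per-sequence sums of token log-probs, shape (batch,)."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # -log sigmoid(beta * ((log pi_c - log pi_r) - (log ref_c - log ref_r)))
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # Implicit rewards; their ordering and gap give the reward/accuracies and reward/margin charts.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards
```

The "concatenated forward pass" mentioned above refers to stacking the chosen and rejected sequences into a single batch so that both sets of log-probabilities come out of one model call.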

@sandeepchittilla (Author) commented Sep 13, 2023

The wandb run: https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch

(updated the link to point to a run with a larger batch size)

@sandeepchittilla marked this pull request as ready for review on September 13, 2023 15:02
@sandeepchittilla (Author) commented:

@PhungVanDuy @maxreciprocate I cannot seem to request a review (probably due to permission issues / first contribution reasons). Could you please advise?

@PhungVanDuy (Collaborator) commented Sep 14, 2023

> @PhungVanDuy @maxreciprocate I cannot seem to request a review (probably due to permission issues / first contribution reasons). Could you please advise?

Thank you so much for the great PR. We will review it asap! Can you share any wandb runs?

@sandeepchittilla (Author) commented Sep 14, 2023

Thank you so much @PhungVanDuy for reviewing 🙏

Yes, it's the same wandb run I shared above. Here you go: https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch

Comment on lines +60 to +62
```python
from_fn = AutoModelForCausalLM.from_pretrained
if issubclass(type(config.model.model_path), PretrainedConfig):
    from_fn = AutoModelForCausalLM.from_config
```

AutoModelForSeq2SeqLM support is missing here
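One way this could be handled, sketched here purely as a suggestion (the `model_arch_type` field is an assumption about the trlx config and may not match what this PR uses; `config` comes from the surrounding trlx code), is to pick the auto class before the pretrained/config dispatch:

```python
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, PretrainedConfig

# Choose the auto class from the architecture type declared in the config
# (assumes a `model_arch_type` field with "seq2seq" for encoder-decoder models).
auto_cls = (
    AutoModelForSeq2SeqLM
    if getattr(config.model, "model_arch_type", None) == "seq2seq"
    else AutoModelForCausalLM
)

# Keep the existing pretrained-vs-config dispatch, just on the selected class.
from_fn = auto_cls.from_pretrained
if issubclass(type(config.model.model_path), PretrainedConfig):
    from_fn = auto_cls.from_config
```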

@sandeepchittilla marked this pull request as draft on September 27, 2023 08:44
@sandeepchittilla marked this pull request as ready for review on September 27, 2023 08:44
@LouisCastricato (Contributor) commented:

Any update?

@sandeepchittilla (Author) commented Nov 13, 2023

> Thank you so much for the great PR. We will review it asap! Can you share any wandb runs?

@PhungVanDuy were you able to review this?

@PhungVanDuy (Collaborator) commented Nov 13, 2023

> @PhungVanDuy were you able to review this?

I saw your wandb run, but the charts look quite messy; reward/accuracies and reward/margin don't clearly increase. I suspect it's because you used gpt2 instead of an SFT model on HH to do DPO. Can you use this SFT model and this preference dataset to train with this branch?

@sandeepchittilla (Author) commented:

> I saw your wandb run, but the charts look quite messy; reward/accuracies and reward/margin don't clearly increase. I suspect it's because you used gpt2 instead of an SFT model on HH to do DPO. Can you use this SFT model and this preference dataset to train with this branch?

That's indeed what I did, for a quick iteration and because I was limited on compute. I will run it with mistral-7b on the ultrafeedback dataset and get back asap.

@sandeepchittilla force-pushed the 504-dpo-trainer branch 2 times, most recently from 6404f83 to 506fbbd on November 20, 2023 16:15
@sandeepchittilla (Author) commented Nov 24, 2023

@PhungVanDuy sorry for the delay, the GPUs aren't always available. Here is a DPO run (ongoing) of 1 epoch with mistral-7b-sft-beta on the ultrafeedback_binarized dataset: https://wandb.ai/sharma-sandeepch/trlx/runs/kfpmeonf?workspace=user-sharma-sandeepch

Notes:

  • Ultrafeedback is a challenging dataset for DPO because the rejected responses are randomly sampled.
  • I have not done an SFT pass on the data, so we see some fluctuating plots.
  • I have limited memory and the GPUs are not best in class, so I've chosen only a subset of test_prefs for evaluation (a sketch of the setup follows this list).
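For reference, here is a minimal sketch of the setup described above, assuming the Hugging Face ids HuggingFaceH4/mistral-7b-sft-beta and HuggingFaceH4/ultrafeedback_binarized with its standard train_prefs/test_prefs splits; the eval subset size is illustrative and the actual run may differ.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceH4/mistral-7b-sft-beta"   # SFT starting point for DPO
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prefs = load_dataset("HuggingFaceH4/ultrafeedback_binarized")
train_prefs = prefs["train_prefs"]                    # (prompt, chosen, rejected) pairs
eval_prefs = prefs["test_prefs"].select(range(512))   # small slice to fit limited GPU memory
```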

@PhungVanDuy (Collaborator) commented Nov 26, 2023

> @PhungVanDuy sorry for the delay, the GPUs aren't always available. Here is a DPO run (ongoing) of 1 epoch with mistral-7b-sft-beta on the ultrafeedback_binarized dataset: https://wandb.ai/sharma-sandeepch/trlx/runs/kfpmeonf?workspace=user-sharma-sandeepch

Thank you for the information; I will use the SFT-beta model to check this. Let me help you run it on my cluster.

@PhungVanDuy (Collaborator) commented:

> Thank you for the information; I will use the SFT-beta model to check this. Let me help you run it on my cluster.

@sandeepchittilla can you add me on Discord (handle: duyphung.ai)? It will be easier to discuss this there. Thank you so much.

@StellaAthena (Contributor) commented:

I'm excited about DPO support and I hope it'll be added soon!
