
Commit

fix bugs in KTO implementation (#1380)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
3 people authored Feb 29, 2024
1 parent b32656f commit 14e0d78
Showing 4 changed files with 228 additions and 99 deletions.
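Two of the bullets in the commit message above describe the core fix: for every batch, the trainer now scores batch_size target (prompt, completion) pairs plus batch_size mismatched pairs that are used only to estimate the reference KL term, yielding batch_size losses, and the KL forward pass runs under no_grad since it needs no gradients. A minimal sketch of that idea follows; this is not the actual trl/trainer/kto_trainer.py code, and the tensor names, the beta default, and the omission of the desirable/undesirable weights are illustrative.

```python
import torch

def kto_batch_loss(policy_target_logps, ref_target_logps,
                   policy_kl_logps, ref_kl_logps,
                   labels, beta=0.1):
    """Sketch of a KTO-style loss for one batch of unpaired feedback.

    *_target_logps: log-probs of the batch_size (prompt, completion) pairs being trained on.
    *_kl_logps:     log-probs of batch_size mismatched pairs, used only for the KL estimate.
    labels:         bool tensor, True for desirable completions, False for undesirable ones.
    """
    # Reference-point KL estimate from the mismatched pairs; it carries no gradient
    # (in the trainer, the forward pass producing these log-probs itself runs under no_grad).
    with torch.no_grad():
        kl = (policy_kl_logps - ref_kl_logps).mean().clamp(min=0)

    # Implicit reward for each of the batch_size target examples.
    rewards = policy_target_logps - ref_target_logps

    # Desirable completions are pushed above the KL reference point, undesirable ones below,
    # giving one loss per target example (batch_size losses in total).
    losses = torch.where(
        labels,
        1.0 - torch.sigmoid(beta * (rewards - kl)),
        1.0 - torch.sigmoid(beta * (kl - rewards)),
    )
    return losses.mean(), kl
```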
4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -29,14 +29,14 @@
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: text_environments
title: Text Environments
- local: kto_trainer
title: KTO Trainer
title: API
- sections:
- local: example_overview
9 changes: 7 additions & 2 deletions docs/source/kto_trainer.mdx
@@ -1,6 +1,10 @@
# KTO Trainer

-TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for training language models from unpaired preference data, as described in the [report](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
+TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
For a full example have a look at [`examples/scripts/kto.py`].

+Depending on how good your base model is, you may or may not need to do SFT before KTO.
+This is different from standard RLHF and DPO, which always require SFT.
+
## Expected dataset format

@@ -44,7 +48,8 @@ kto_dataset_dict = {
}
```

-where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired or undesired. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
+A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
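For context, the `kto_dataset_dict` whose closing brace appears a few lines up follows exactly this layout. A small made-up instance (the values are illustrative, not the ones from the doc):

```python
# Illustrative unpaired-feedback dataset: one entry per (prompt, completion) pair,
# with a boolean flag marking the completion as desired (True) or undesired (False).
kto_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "Hey, hello",        # the same prompt repeats once per completion it received
        "How are you",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I am fine",
    ],
    "label": [True, False, True],
}
```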

## Expected model format
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
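In practice that means the trainer is fed a plain causal LM. A rough usage sketch, assuming the `KTOConfig`/`KTOTrainer` API of this TRL version; the model name, hyperparameter values, and dataset variable are placeholders:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")        # plain causal LM, no value head
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")    # frozen reference copy
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(kto_dataset_dict)         # prompt / completion / label columns

training_args = KTOConfig(
    output_dir="./kto_out",
    per_device_train_batch_size=4,
    beta=0.1,  # strength of the implicit KL penalty
)

trainer = KTOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```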
53 changes: 38 additions & 15 deletions examples/scripts/kto.py
@@ -13,24 +13,44 @@
# limitations under the License.

"""
-Run the KTO training script with the following command with some example arguments:
+Run the KTO training script with the following command with some example arguments.
+In general, the optimal configuration for KTO will be similar to that of DPO:
+
+# regular:
python examples/scripts/kto.py \
-    --model_name_or_path "gpt2" \
-    --per_device_train_batch_size 8 \
-    --gradient_accumulation_steps 2 \
-    --learning_rate 1e-4 \
+    --model_name_or_path=gpt2 \
+    --per_device_train_batch_size 4 \
    --max_steps 1000 \
-    --report_to "wandb" \
-    --gradient_checkpointing True \
-    --output_dir="./test" \
-    --use_peft True \
-    --lora_r 64 \
-    --lora_alpha 16 \
-    --evaluation_strategy "steps" \
-    --logging_first_step True \
+    --learning_rate 1e-3 \
+    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
-    --eval_steps 500
+    --eval_steps 500 \
+    --output_dir="kto_anthropic_hh" \
+    --warmup_steps 150 \
+    --report_to wandb \
+    --bf16 \
+    --logging_first_step \
+    --no_remove_unused_columns
+
+# peft:
+python examples/scripts/kto.py \
+    --model_name_or_path=gpt2 \
+    --per_device_train_batch_size 4 \
+    --max_steps 1000 \
+    --learning_rate 1e-3 \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 10 \
+    --eval_steps 500 \
+    --output_dir="kto_anthropic_hh" \
+    --optim rmsprop \
+    --warmup_steps 150 \
+    --report_to wandb \
+    --bf16 \
+    --logging_first_step \
+    --no_remove_unused_columns \
+    --use_peft \
+    --lora_r=16 \
+    --lora_alpha=16
"""

from dataclasses import dataclass, field
@@ -57,7 +77,10 @@ def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
-    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
+
+    if search_term_idx == -1:
+        raise ValueError(f"Prompt and response does not contain '{search_term}'")
+
    return prompt_and_response[: search_term_idx + len(search_term)]
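A quick illustration of the behaviour the change above enforces; the input strings are hypothetical:

```python
# Well-formed pair: everything up to and including the last "\n\nAssistant:" is the prompt.
sample = "\n\nHuman: What does KTO optimize?\n\nAssistant: A prospect-theoretic utility."
prompt = extract_anthropic_prompt(sample)
# prompt == "\n\nHuman: What does KTO optimize?\n\nAssistant:"

# Malformed pair: previously tripped an assert, now raises ValueError with the same message.
extract_anthropic_prompt("no assistant turn here")  # raises ValueError
```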


261 changes: 181 additions & 80 deletions trl/trainer/kto_trainer.py
(diff not rendered)
