generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial orpo skeleton * typos * calculate orpo loss * fix class name * fix tests * fix typo * Update docs/source/orpo_trainer.md Co-authored-by: lewtun <[email protected]> * Update docs/source/orpo_trainer.md Co-authored-by: lewtun <[email protected]> * Update docs/source/orpo_trainer.md Co-authored-by: lewtun <[email protected]> * rename max_target_length * Update examples/scripts/orpo.py Co-authored-by: lewtun <[email protected]> * Update examples/scripts/orpo.py Co-authored-by: lewtun <[email protected]> * Update examples/scripts/orpo.py Co-authored-by: lewtun <[email protected]> * more docs * log log_odds_ratio and log_odds * average_log_prob as per paper * added logging section * add nll_loss * fix typo * more verbose * rename log_odds to log_odds_chosen * allow datasets to be loaded * remove dup debug arg * tokenizer exists * fix typo * use trl-internal-testing/hh-rlhf-trl-style dataset * formatting * add missing imports * fix output dir name * Update examples/scripts/orpo.py Co-authored-by: lewtun <[email protected]> * move dataset_num_proc to configs * Update trl/trainer/orpo_config.py Co-authored-by: Alvaro Bartolome <[email protected]> * Update trl/trainer/orpo_trainer.py Co-authored-by: Alvaro Bartolome <[email protected]> * add ORPOTrainer to readme * fix typo --------- Co-authored-by: lewtun <[email protected]> Co-authored-by: Alvaro Bartolome <[email protected]>
- Loading branch information
1 parent
d1df79f
commit 2ce8e45
Showing
9 changed files
with
1,424 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# ORPO Trainer | ||
|
||
[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT. | ||
|
||
Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory. | ||
|
||
The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo). | ||
|
||
## Expected dataset format | ||
|
||
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows: | ||
|
||
- `prompt` | ||
- `chosen` | ||
- `rejected` | ||
|
||
for example: | ||
|
||
```py | ||
orpo_dataset_dict = { | ||
"prompt": [ | ||
"hello", | ||
"how are you", | ||
"What is your name?", | ||
"What is your name?", | ||
"Which is the best programming language?", | ||
"Which is the best programming language?", | ||
"Which is the best programming language?", | ||
], | ||
"chosen": [ | ||
"hi nice to meet you", | ||
"I am fine", | ||
"My name is Mary", | ||
"My name is Mary", | ||
"Python", | ||
"Python", | ||
"Java", | ||
], | ||
"rejected": [ | ||
"leave me alone", | ||
"I am not fine", | ||
"Whats it to you?", | ||
"I dont have a name", | ||
"Javascript", | ||
"C++", | ||
"C++", | ||
], | ||
} | ||
``` | ||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. | ||
|
||
## Expected model format | ||
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. | ||
|
||
## Using the `ORPOTrainer` | ||
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT. | ||
|
||
```py | ||
orpo_config = ORPOConfig( | ||
beta=0.1, # the lambda/alpha hyperparameter in the paper/code | ||
) | ||
|
||
orpo_trainer = ORPOTrainer( | ||
model, | ||
args=orpo_config, | ||
train_dataset=train_dataset, | ||
tokenizer=tokenizer, | ||
) | ||
``` | ||
After this one can then call: | ||
|
||
```py | ||
orpo_trainer.train() | ||
``` | ||
|
||
## Logging | ||
|
||
While training and evaluating we record the following reward metrics: | ||
|
||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta | ||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta | ||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards | ||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards | ||
|
||
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses | ||
|
||
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))` | ||
|
||
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses | ||
|
||
## ORPOTrainer | ||
|
||
[[autodoc]] ORPOTrainer | ||
|
||
|
||
## ORPOConfig | ||
|
||
[[autodoc]] ORPOConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Run the ORPO training script with the following command with some example arguments. | ||
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model: | ||
# regular: | ||
python examples/scripts/orpo.py \ | ||
--model_name_or_path=gpt2 \ | ||
--per_device_train_batch_size 4 \ | ||
--max_steps 1000 \ | ||
--learning_rate 8e-6 \ | ||
--gradient_accumulation_steps 1 \ | ||
--logging_steps 10 \ | ||
--eval_steps 500 \ | ||
--output_dir="gpt2-aligned-orpo" \ | ||
--warmup_steps 150 \ | ||
--report_to wandb \ | ||
--bf16 \ | ||
--logging_first_step \ | ||
--no_remove_unused_columns | ||
# peft: | ||
python examples/scripts/orpo.py \ | ||
--model_name_or_path=gpt2 \ | ||
--per_device_train_batch_size 4 \ | ||
--max_steps 1000 \ | ||
--learning_rate 8e-5 \ | ||
--gradient_accumulation_steps 1 \ | ||
--logging_steps 10 \ | ||
--eval_steps 500 \ | ||
--output_dir="gpt2-lora-aligned-orpo" \ | ||
--optim rmsprop \ | ||
--warmup_steps 150 \ | ||
--report_to wandb \ | ||
--bf16 \ | ||
--logging_first_step \ | ||
--no_remove_unused_columns \ | ||
--use_peft \ | ||
--lora_r=16 \ | ||
--lora_alpha=16 | ||
""" | ||
|
||
import multiprocessing | ||
from dataclasses import dataclass, field | ||
|
||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser | ||
|
||
from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config | ||
|
||
|
||
@dataclass | ||
class ScriptArguments: | ||
dataset: str = field( | ||
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."} | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) | ||
args, orpo_args, model_config = parser.parse_args_into_dataclasses() | ||
|
||
################ | ||
# Model & Tokenizer | ||
################ | ||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) | ||
peft_config = get_peft_config(model_config) | ||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) | ||
if tokenizer.pad_token is None: | ||
tokenizer.pad_token = tokenizer.eos_token | ||
|
||
################ | ||
# Dataset | ||
################ | ||
ds = load_dataset(args.dataset) | ||
if orpo_args.debug: | ||
for key in ds: | ||
ds[key] = ds[key].select(range(50)) | ||
if tokenizer.chat_template is None: | ||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" | ||
|
||
def process(row): | ||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) | ||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) | ||
return row | ||
|
||
ds = ds.map( | ||
process, | ||
num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(), | ||
load_from_cache_file=False, | ||
) | ||
train_dataset = ds["train"] | ||
eval_dataset = ds["test"] | ||
|
||
################ | ||
# Training | ||
################ | ||
trainer = ORPOTrainer( | ||
model, | ||
args=orpo_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
tokenizer=tokenizer, | ||
peft_config=get_peft_config(model_config), | ||
) | ||
|
||
# train and save the model | ||
trainer.train() | ||
trainer.save_model(orpo_args.output_dir) |
Oops, something went wrong.