Skip to content

Commit

Permalink
ORPO trainer (#1435)
Browse files Browse the repository at this point in the history
* initial orpo skeleton

* typos

* calculate orpo loss

* fix class name

* fix tests

* fix typo

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <[email protected]>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <[email protected]>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <[email protected]>

* rename max_target_length

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <[email protected]>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <[email protected]>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <[email protected]>

* more docs

* log log_odds_ratio and log_odds

* average_log_prob as per paper

* added logging section

* add nll_loss

* fix typo

* more verbose

* rename log_odds to log_odds_chosen

* allow datasets to be loaded

* remove dup debug arg

* tokenizer exists

* fix typo

* use trl-internal-testing/hh-rlhf-trl-style dataset

* formatting

* add missing imports

* fix output dir name

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <[email protected]>

* move dataset_num_proc to configs

* Update trl/trainer/orpo_config.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update trl/trainer/orpo_trainer.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* add ORPOTrainer to readme

* fix typo

---------

Co-authored-by: lewtun <[email protected]>
Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
3 people authored Mar 22, 2024
1 parent d1df79f commit 2ce8e45
Show file tree
Hide file tree
Showing 9 changed files with 1,424 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The library is built on top of the [`transformers`](https://github.com/huggingfa
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), and [`CPOTrainer`]((https://huggingface.co/docs/trl/trainer#trl.CPOTrainer).
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).

Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: text_environments
Expand Down
98 changes: 98 additions & 0 deletions docs/source/orpo_trainer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# ORPO Trainer

[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.

Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.

The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo).

## Expected dataset format

The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:

- `prompt`
- `chosen`
- `rejected`

for example:

```py
orpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.

## Expected model format
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.

## Using the `ORPOTrainer`
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT.

```py
orpo_config = ORPOConfig(
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
)

orpo_trainer = ORPOTrainer(
model,
args=orpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:

```py
orpo_trainer.train()
```

## Logging

While training and evaluating we record the following reward metrics:

* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards

* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses

* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`

* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses

## ORPOTrainer

[[autodoc]] ORPOTrainer


## ORPOConfig

[[autodoc]] ORPOConfig
121 changes: 121 additions & 0 deletions examples/scripts/orpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with the following command with some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model:
# regular:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-orpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-orpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""

import multiprocessing
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config


@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
args, orpo_args, model_config = parser.parse_args_into_dataclasses()

################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
ds = load_dataset(args.dataset)
if orpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]

################
# Training
################
trainer = ORPOTrainer(
model,
args=orpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)

# train and save the model
trainer.train()
trainer.save_model(orpo_args.output_dir)
Loading

0 comments on commit 2ce8e45

Please sign in to comment.