Merge branch 'main' into dpo-docs
vwxyzjn authored Mar 19, 2024
2 parents ddd134a + 66b043a commit 73e47f1
Showing 8 changed files with 128 additions and 48 deletions.
115 changes: 82 additions & 33 deletions README.md
@@ -3,7 +3,7 @@
</div>

# TRL - Transformer Reinforcement Learning
> Full stack transformer language models with reinforcement learning.
> Full stack library to fine-tune and align large language models.
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
@@ -20,58 +20,70 @@

## What is it?

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>

`trl` is a full stack library providing a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning (SFT) and Reward Modeling (RM) steps to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face, so pre-trained language models can be loaded directly via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.

**Highlights:**

- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.

The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows you to use any model architecture available there.

## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:

1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.

This process is illustrated in the sketch below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>

## Highlights

- **`Efficient and scalable`**:
  - [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl`, which allows you to scale model training from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
  - [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows you to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
  - [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows you to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs using a single command and a flexible config system, without writing any code.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), and [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model, which allows you to train them with RL algorithms such as PPO (see the short sketch after this list).
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, do full RLHF using only adapters, train GPT-j to be less toxic, reproduce the [StackLlama example](https://huggingface.co/blog/stackllama), and more, by following the [examples](https://github.com/huggingface/trl/tree/main/examples).
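
For a quick look at what the value-head models from the `AutoModels` bullet provide, here is a minimal sketch (the checkpoint and prompt are illustrative):

```python
# Minimal sketch: wrap a causal LM with a value head, as used by PPO.
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("TRL is", return_tensors="pt")
# Besides the usual LM logits, the forward pass returns one scalar value per token,
# which PPO uses as its value-function estimate.
lm_logits, loss, values = model(**inputs)
```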

## Installation

### Python package
Install the library with pip:
Install the library with `pip`:
```bash
pip install trl
```

### From source
If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:
If you want to use the latest features before an official release you can install from source:
```bash
pip install git+https://github.com/huggingface/trl.git
```

### Repository
If you want to use the examples, you can clone the repository with the following command:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
```

If you wish to develop TRL, you should install in editable mode:

```bash
pip install -e .
```

## Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), and to test your aligned model with the chat CLI:

**SFT:**

```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```

**DPO:**

```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
```

**Chat:**

```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```

Read more about the CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.

## How to use

For more flexibility and control over training, you can use the dedicated trainer classes to fine-tune the model in Python.

### `SFTTrainer`

This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
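
A minimal sketch of that pattern (the dataset and checkpoint below are just illustrative choices):

```python
# Minimal SFT sketch: fine-tune a small causal LM on the text column of a dataset.
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",        # model name or an already loaded model
    train_dataset=dataset,
    dataset_text_field="text",  # column that holds the raw text
    max_seq_length=512,
)
trainer.train()
```
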
@@ -161,13 +173,50 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```

### `DPOTrainer`

`DPOTrainer` is a trainer that uses the [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
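
The `...` above stands for loading a preference dataset with `prompt`, `chosen` and `rejected` columns. A toy sketch of that format (the single example is purely illustrative):

```python
# Toy preference dataset with the prompt/chosen/rejected columns the DPOTrainer expects.
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "prompt": ["Why is the sky blue?"],
        "chosen": ["Because sunlight is scattered more strongly at short wavelengths."],
        "rejected": ["Because it reflects the color of the ocean."],
    }
)
```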

## Development

If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```

## References

### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].

### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face. Currently, `trl` only supports `transformers` models **for text**.
### Direct Preference Optimization
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by R. Rafailov et al. \[[paper](https://arxiv.org/abs/2305.18290), [code](https://github.com/eric-mitchell/direct-preference-optimization)].


## Citation

6 changes: 3 additions & 3 deletions docs/source/clis.mdx
@@ -56,7 +56,7 @@ You can pass any of these arguments either to the CLI or the YAML file.
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:

```bash
trl sft --config config.yaml --output_dir your-output-dir
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```

The SFT CLI is based on the `examples/scripts/sft.py` script.
@@ -78,7 +78,7 @@ These datasets always have at least three columns `prompt, chosen, rejected`:
To do a quick start, you can run the following command:

```bash
trl dpo --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
```


@@ -98,7 +98,7 @@ python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
The chat CLI lets you quickly load the model and talk to it. Simply run the following:

```bash
trl chat --model Qwen/Qwen1.5-0.5B-Chat
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```

Note that the chat interface relies on the chat template of the tokenizer to format the inputs for the model. Make sure your tokenizer has a chat template defined.
17 changes: 16 additions & 1 deletion docs/source/dpo_trainer.mdx
@@ -2,9 +2,24 @@

TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).


The first step, as always, is to train your SFT model to ensure the data we train on is in-distribution for the DPO algorithm.

## How DPO works

Fine-tuning a language model via DPO consists of two steps and is easier than PPO:

1. **Data collection**: Gather a preference dataset of chosen and rejected pairs of generations for a given prompt.
2. **Optimization**: Directly optimize the DPO loss, which maximizes the log-likelihood of the preference data.

DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo).

This process is illustrated in the sketch below (from [figure 1 of the original paper](https://arxiv.org/pdf/2305.18290.pdf)):

<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">

Read more about the DPO algorithm in the [original paper](https://arxiv.org/pdf/2305.18290.pdf).
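
For reference, the objective minimized in the optimization step is the DPO loss from the paper (shown here as a sketch, with $\pi_\theta$ the model being trained, $\pi_{\mathrm{ref}}$ the frozen SFT reference, $\beta$ a temperature parameter, and $(x, y_w, y_l)$ a prompt with its chosen and rejected completions):

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) =
-\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$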


## Expected dataset format

The DPO trainer expects a very specific format for the dataset, since the model will be trained to directly optimize the preference between two sentences for a given prompt. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
15 changes: 15 additions & 0 deletions docs/source/ppo_trainer.mdx
@@ -4,6 +4,21 @@ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training la

The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a reward model (see the [RewardTrainer](reward_trainer)), which will be used to optimize the SFT model with the PPO algorithm.

## How PPO works

Fine-tuning a language model via PPO consists of roughly three steps:

1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.

This process is illustrated in the sketch below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
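
A minimal sketch of these three steps with the `PPOTrainer` (the query text is illustrative and the constant reward is a placeholder for a reward model or human feedback):

```python
# Minimal PPO sketch: rollout, (placeholder) evaluation, and one optimization step.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import respond_to_batch

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(PPOConfig(batch_size=1, mini_batch_size=1), model, ref_model, tokenizer)

# 1. Rollout: generate a continuation for a query.
query_tensor = tokenizer.encode("This morning I went to the ", return_tensors="pt")
response_tensor = respond_to_batch(model, query_tensor)

# 2. Evaluation: score the query/response pair (a constant stands in for a real reward).
reward = [torch.tensor(1.0)]

# 3. Optimization: one PPO step on the (query, response, reward) triplet.
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```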

## Expected dataset format

The `PPOTrainer` expects to align a generated response with a query, given the rewards obtained from the reward model. During each step of the PPO algorithm, we sample a batch of prompts from the dataset and use these prompts to generate responses from the SFT model. Next, the reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
17 changes: 9 additions & 8 deletions examples/scripts/chat.py
@@ -156,7 +156,7 @@ def save_chat(chat, args, filename):
    folder = args.save_folder

    if filename is None:
        filename = create_default_filename(args.model)
        filename = create_default_filename(args.model_name_or_path)
    filename = os.path.join(folder, filename)
    os.makedirs(os.path.dirname(filename), exist_ok=True)

@@ -210,7 +210,9 @@ def parse_settings(user_input, current_args, interface):
    return current_args, True


def load_model(args):
def load_model_and_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
    quantization_config = get_quantization_config(args)
    model_kwargs = dict(
@@ -221,12 +223,12 @@
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

    if getattr(model, "hf_device_map", None) is None:
        model = model.to(args.device)
    print(model.device)
    return model

    return model, tokenizer


def chat_cli():
@@ -247,11 +249,10 @@ def chat_cli():
    else:
        user = args.user

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = load_model(args)
    model, tokenizer = load_model_and_tokenizer(args)
    generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    interface = RichInterface(model_name=args.model, user_name=user)
    interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
    interface.clear()
    chat = clear_chat_history(current_args.system_prompt)
    while True:
2 changes: 1 addition & 1 deletion setup.py
@@ -58,7 +58,7 @@
from setuptools import find_packages, setup


__version__ = "0.7.12.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.8.1.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
"torch>=1.4.0",
2 changes: 1 addition & 1 deletion trl/__init__.py
@@ -1,6 +1,6 @@
# flake8: noqa

__version__ = "0.7.12.dev0"
__version__ = "0.8.1.dev0"

from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
2 changes: 1 addition & 1 deletion trl/commands/cli_utils.py
@@ -174,7 +174,7 @@ class DpoScriptArguments:
@dataclass
class ChatArguments:
    # general settings
    model: str = field(metadata={"help": "Name of the pre-trained model"})
    model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
    user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
    system_prompt: str = field(default=None, metadata={"help": "System prompt"})
    save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
