diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 66fdc6b57b..1f793e6ea9 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -5,21 +5,55 @@
title: Installation
- local: quickstart
title: Quickstart
- - local: clis
- title: Get started with Command Line Interfaces (CLIs)
+ title: Getting started
+- sections:
- local: dataset_formats
title: Dataset Formats
- local: how_to_train
- title: PPO Training FAQ
- - local: use_model
- title: Use Trained Models
- - local: customization
- title: Customize the Training
+ title: Training FAQ
- local: logging
title: Understanding Logs
- title: Get started
+ title: Conceptual Guides
- sections:
- - sections: # Sort alphabetically
+ - local: clis
+ title: Command Line Interface (CLI)
+ - local: customization
+ title: Customizing the Training
+ - local: reducing_memory_usage
+ title: Reducing Memory Usage
+ - local: speeding_up_training
+ title: Speeding Up Training
+ - local: use_model
+ title: Using Trained Models
+ title: How-to guides
+- sections:
+ - local: deepspeed_integration
+ title: DeepSpeed
+ - local: liger_kernel_integration
+ title: Liger Kernel
+ - local: peft_integration
+ title: PEFT
+ - local: unsloth_integration
+ title: Unsloth
+ title: Integrations
+- sections:
+ - local: example_overview
+ title: Example Overview
+ - local: community_tutorials
+ title: Community Tutorials
+ - local: sentiment_tuning
+ title: Sentiment Tuning
+ - local: using_llama_models
+ title: Training StackLlama
+ - local: detoxifying_a_lm
+ title: Detoxifying a Language Model
+ - local: learning_tools
+ title: Learning to Use Tools
+ - local: multi_adapter_rl
+ title: Multi Adapter RLHF
+ title: Examples
+- sections:
+ - sections: # Sorted alphabetically
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
@@ -70,21 +104,3 @@
- local: script_utils
title: Script Utilities
title: API
-- sections:
- - local: community_tutorials
- title: Community Tutorials
- - local: example_overview
- title: Example Overview
- - local: sentiment_tuning
- title: Sentiment Tuning
- - local: lora_tuning_peft
- title: Training with PEFT
- - local: detoxifying_a_lm
- title: Detoxifying a Language Model
- - local: using_llama_models
- title: Training StackLlama
- - local: learning_tools
- title: Learning to Use Tools
- - local: multi_adapter_rl
- title: Multi Adapter RLHF
- title: Examples
diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx
index d76b5665da..a4c6b007ef 100644
--- a/docs/source/alignprop_trainer.mdx
+++ b/docs/source/alignprop_trainer.mdx
@@ -7,7 +7,7 @@
If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
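As a toy illustration of why a differentiable reward helps (this is not AlignProp itself, just a minimal PyTorch sketch with stand-in modules):

```python
import torch

# Toy sketch: with a differentiable reward, gradients flow from the reward model
# straight back into the generator's parameters, so no policy-gradient estimator is needed.
generator = torch.nn.Linear(4, 4)     # stand-in for the diffusion model
reward_model = torch.nn.Linear(4, 1)  # stand-in for a differentiable reward model

samples = generator(torch.randn(8, 4))
loss = -reward_model(samples).mean()  # maximize the reward via gradient ascent
loss.backward()
print(generator.weight.grad is not None)  # True: the generator receives reward gradients
```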
-
+
## Getting started with `examples/scripts/alignprop.py`
diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md
index 4192012045..4b2b9a6e54 100644
--- a/docs/source/community_tutorials.md
+++ b/docs/source/community_tutorials.md
@@ -10,6 +10,7 @@ Community tutorials are made by active members of the Hugging Face community tha
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
+| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
@@ -18,9 +19,11 @@ Community tutorials are made by active members of the Hugging Face community tha
| Task | Class | Description | Author | Tutorial | Colab |
| --------------- | -------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------ | -------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
+| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
+| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
## Contributing
-If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
\ No newline at end of file
+If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
diff --git a/docs/source/ddpo_trainer.mdx b/docs/source/ddpo_trainer.mdx
index 20dbbe82b1..0682144edb 100644
--- a/docs/source/ddpo_trainer.mdx
+++ b/docs/source/ddpo_trainer.mdx
@@ -6,9 +6,9 @@
| Before | After DDPO finetuning |
| --- | --- |
-| | |
-| | |
-| | |
+| | |
+| | |
+| | |
## Getting started with Stable Diffusion finetuning with reinforcement learning
diff --git a/docs/source/deepspeed_integration.md b/docs/source/deepspeed_integration.md
new file mode 100644
index 0000000000..cbb3b53a37
--- /dev/null
+++ b/docs/source/deepspeed_integration.md
@@ -0,0 +1,7 @@
+# DeepSpeed Integration
+
+
+
+Section under construction. Feel free to contribute!
+
+
\ No newline at end of file
diff --git a/docs/source/detoxifying_a_lm.mdx b/docs/source/detoxifying_a_lm.mdx
index 4fb3741f43..fe97422889 100644
--- a/docs/source/detoxifying_a_lm.mdx
+++ b/docs/source/detoxifying_a_lm.mdx
@@ -83,7 +83,7 @@ As a compromise between the two we took for a context window of 10 to 15 tokens
-
+
### How to deal with OOM issues
@@ -101,7 +101,7 @@ and the optimizer will take care of computing the gradients in `bfloat16` precis
- Use shared layers: Since the PPO algorithm requires both the active and the reference model to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
-
+
```python
@@ -124,13 +124,13 @@ We have decided to keep 3 models in total that correspond to our best models:
We used different learning rates for each model and found that the largest models were quite hard to train and can easily collapse if the learning rate is not chosen correctly (i.e., if it is too high):
-
+
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
-
+
As you can see, the model converges nicely, but we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic content.
@@ -138,7 +138,7 @@ As you can see the model converges nicely, but obviously we don't observe a very
We have also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:
-
+
## Results
@@ -159,7 +159,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand
@@ -167,7 +167,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand
Below are a few generation examples from the `gpt-j-6b-detox` model:
-
+
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx
index 64ce87672f..103326fd9b 100644
--- a/docs/source/dpo_trainer.mdx
+++ b/docs/source/dpo_trainer.mdx
@@ -59,7 +59,7 @@ accelerate launch train_dpo.py
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/dpo-qwen2-reward-margin.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
diff --git a/docs/source/how_to_train.md b/docs/source/how_to_train.md
index bac324e43b..6ac55079b7 100644
--- a/docs/source/how_to_train.md
+++ b/docs/source/how_to_train.md
@@ -18,7 +18,7 @@ When training RL models, optimizing solely for reward may lead to unexpected beh
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
# TRL - Transformer Reinforcement Learning
diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx
index 7b79268410..05de7a026d 100644
--- a/docs/source/kto_trainer.mdx
+++ b/docs/source/kto_trainer.mdx
@@ -51,7 +51,7 @@ accelerate launch train_kto.py
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kto-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
diff --git a/docs/source/learning_tools.mdx b/docs/source/learning_tools.mdx
index 7d693dd2c9..add4844e2b 100644
--- a/docs/source/learning_tools.mdx
+++ b/docs/source/learning_tools.mdx
@@ -69,7 +69,7 @@ The rough idea is as follows:
)
```
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
- ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)
+ ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools.png)
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`, as sketched below.
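Putting the two calls together, a minimal sketch of the resulting loop (assuming `env`, `ppo_trainer`, and `tasks` are set up as above; the number of iterations is arbitrary):

```python
for _ in range(100):  # number of PPO updates is arbitrary in this sketch
    queries, responses, masks, rewards, histories = env.run(tasks)
    train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```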
## Experiment results
@@ -102,7 +102,7 @@ python -m openrlbenchmark.rlops_multi_metrics \
--scan-history
```
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools_chart.png)
As we can see, while one or two experiments crashed for some reason, most of the runs obtained near-perfect proficiency in the calculator task.
@@ -147,7 +147,7 @@ The frame of rackets for all sports was traditionally made of solid wood (later
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pyserini.png)
### Experiment settings
@@ -181,7 +181,7 @@ Q: """
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly go up, but one of the experiments did crash.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/triviaqa_learning_curves.png)
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
@@ -191,13 +191,13 @@ Note that the correct rate of the trained model is on the low end, which could b
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should be "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
- ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)
+ ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/real_first_name.png)
* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act"
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
- ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
+ ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with python interpreter
@@ -230,4 +230,4 @@ Q: """
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/gms8k_learning_curve.png)
diff --git a/docs/source/liger_kernel_integration.md b/docs/source/liger_kernel_integration.md
new file mode 100644
index 0000000000..6a2cdf0e7e
--- /dev/null
+++ b/docs/source/liger_kernel_integration.md
@@ -0,0 +1,7 @@
+# Liger Kernel Integration
+
+
+
+Section under construction. Feel free to contribute!
+
+
\ No newline at end of file
diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md
index 881e57e69c..58fcca8c36 100644
--- a/docs/source/nash_md_trainer.md
+++ b/docs/source/nash_md_trainer.md
@@ -111,7 +111,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
-![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
+![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md
index 49e40957c1..8ba147e780 100644
--- a/docs/source/online_dpo_trainer.md
+++ b/docs/source/online_dpo_trainer.md
@@ -51,7 +51,7 @@ accelerate launch train_online_dpo.py
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in the rewards for both rejected and chosen completions indicates that the model is improving and generating better responses over time.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
@@ -110,7 +110,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
-![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
+![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
@@ -265,7 +265,7 @@ plt.tight_layout()
plt.show()
```
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online_dpo_scaling.png)
The online DPO checkpoint gets an increasingly higher win rate as we scale up the model size. This is a good sign that the online DPO implementation is working as intended.
diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md
index 78a68e077c..bea95c485b 100644
--- a/docs/source/orpo_trainer.md
+++ b/docs/source/orpo_trainer.md
@@ -54,7 +54,7 @@ accelerate launch train_orpo.py
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
diff --git a/docs/source/lora_tuning_peft.mdx b/docs/source/peft_integration.md
similarity index 98%
rename from docs/source/lora_tuning_peft.mdx
rename to docs/source/peft_integration.md
index 8906107c8e..01bef1e9e0 100644
--- a/docs/source/lora_tuning_peft.mdx
+++ b/docs/source/peft_integration.md
@@ -118,7 +118,7 @@ The `trl` library also supports naive pipeline parallelism (NPP) for large model
This paradigm, termed "Naive Pipeline Parallelism" (NPP), is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs, and the activations and gradients are naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
-
+
### How to use NPP?
diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md
index a1cdc6529b..1e0faf663f 100644
--- a/docs/source/ppo_trainer.md
+++ b/docs/source/ppo_trainer.md
@@ -66,7 +66,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), they look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` sample completions during training, but you can customize the number of generations.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif?download=true)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif)
In the logs the sampled generations look like
@@ -210,7 +210,7 @@ The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of t
Metrics:
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/ppov2.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2.png)
```bash
diff --git a/docs/source/prm_trainer.mdx b/docs/source/prm_trainer.mdx
index 012b8ec071..51813ca8d2 100644
--- a/docs/source/prm_trainer.mdx
+++ b/docs/source/prm_trainer.mdx
@@ -1,5 +1,7 @@
# PRM Trainer
+[![](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl)
+
PRM Trainer is an experimental API which is subject to change at any time.
diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx
index 6d653ef5f3..f310a101d8 100644
--- a/docs/source/quickstart.mdx
+++ b/docs/source/quickstart.mdx
@@ -9,7 +9,7 @@ Fine-tuning a language model via PPO consists of roughly three steps:
3. **Optimization**: This is the most complex part. In the optimization step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is being trained and with a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO (a rough sketch of this reward shaping follows the figure below).
The full process is illustrated in the following figure:
-
+
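As a rough, illustrative sketch of the KL-shaped reward described in step 3 (this is not TRL's exact implementation; the function and argument names are made up for this example):

```python
import torch

def kl_shaped_rewards(logprobs_active, logprobs_ref, task_rewards, beta=0.1):
    # logprobs_*: (batch, seq_len) log-probabilities of the generated tokens under each model
    # task_rewards: (batch,) scalar reward for each full response
    kl = logprobs_active - logprobs_ref  # per-token KL estimate
    rewards = -beta * kl                 # penalize drifting away from the reference model
    rewards[:, -1] += task_rewards       # add the scalar task reward at the final token
    return rewards
```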
## Minimal example
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
new file mode 100644
index 0000000000..dfe6dc5a7a
--- /dev/null
+++ b/docs/source/reducing_memory_usage.md
@@ -0,0 +1,54 @@
+# Reducing Memory Usage
+
+
+
+Section under construction. Feel free to contribute!
+
+
+
+## Truncation
+
+Sequence lengths in the dataset can vary widely, and by default, TRL does not modify the data. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
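+
+As a back-of-the-envelope illustration (the sequence lengths below are made up), padding a batch to its longest sequence can mean that most of the allocated tokens are padding:
+
+```python
+# Illustrative only: padding a batch to its longest sequence allocates memory for pad tokens.
+lengths = [12, 15, 2048]              # hypothetical sequence lengths in one batch
+padded = max(lengths) * len(lengths)  # tokens allocated once the batch is padded
+useful = sum(lengths)                 # tokens that actually carry data
+print(f"{100 * (1 - useful / padded):.0f}% of the batch is padding")  # ~66%
+```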
+
+
+
+
+
+To reduce memory usage, it’s important to truncate sequences to a reasonable length. Even discarding just a few tokens from the dataset can result in significant memory savings by minimizing unnecessary padding. Truncation is good practice and should always be applied; the limit doesn’t need to be overly restrictive, but setting a sensible value is essential for efficient use of resources.
+
+
+
+
+In DPO, truncation is first applied to the prompt and the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
+
+
+
+
+
+To set the truncation parameters, use the following code snippet:
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., max_prompt_length=..., max_completion_length=..., max_length=...)
+```
+
+
+
+
+In SFT, truncation is applied to the input sequence via the `max_length` parameter.
+
+
+
+
+
+To set the truncation parameter, use the following code snippet:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., max_length=...)
+```
+
+
+
\ No newline at end of file
diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 71d189be7d..127f297321 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -68,7 +68,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), they look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` sample completions during training, but you can customize the number of generations.
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif)
In the logs the sampled generations look like
@@ -251,7 +251,7 @@ The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of
Metrics:
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/rloo.png)
```bash
diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index e83088ee09..6921946c89 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -502,7 +502,7 @@ NEFTune is a technique to boost the performance of chat models and was introduce
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
-
+
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that, to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
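For instance, a minimal sketch (the alpha value and output directory are illustrative):

```python
from trl import SFTConfig

# Illustrative: NEFTune is enabled by setting neftune_noise_alpha on the SFT config.
training_args = SFTConfig(output_dir="./sft-neftune", neftune_noise_alpha=5)
```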
@@ -527,7 +527,7 @@ trainer.train()
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
-
+
Note, however, that the amount of performance gain is _dataset dependent_; in particular, applying NEFTune to synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
new file mode 100644
index 0000000000..3e026d422e
--- /dev/null
+++ b/docs/source/speeding_up_training.md
@@ -0,0 +1,7 @@
+# Speeding Up Training
+
+
+
+Section under construction. Feel free to contribute!
+
+
\ No newline at end of file
diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md
index 851020e0f5..c7b0bd0cfd 100644
--- a/docs/source/text_environments.md
+++ b/docs/source/text_environments.md
@@ -3,7 +3,7 @@
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but are trivial with the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
-
+
Let's dive into how text environments work and start with tools!
@@ -179,13 +179,13 @@ When the model interacts inside the `TextEnvironment` it can be useful to visual
You can see that the prompt is highlighted in gray, whereas system segments such as the query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output, the reward is displayed as additional text in plum. Here is an example of `show_text`:
-
+
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
-
+
Note that you can turn on the colour legend by passing `show_legend=True`.
diff --git a/docs/source/unsloth_integration.md b/docs/source/unsloth_integration.md
new file mode 100644
index 0000000000..dd071d392d
--- /dev/null
+++ b/docs/source/unsloth_integration.md
@@ -0,0 +1,7 @@
+# Unsloth Integration
+
+
+
+Section under construction. Feel free to contribute!
+
+
\ No newline at end of file
diff --git a/docs/source/using_llama_models.mdx b/docs/source/using_llama_models.mdx
index cf602d2030..420caf1948 100644
--- a/docs/source/using_llama_models.mdx
+++ b/docs/source/using_llama_models.mdx
@@ -19,7 +19,7 @@ Now we can fit very large models into a single GPU, but the training might still
The simplest strategy in this scenario is data parallelism: we replicate the same training setup onto separate GPUs and pass different batches to each GPU.
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
-![chapter10_ddp.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_ddp.png)
+![chapter10_ddp.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/chapter10_ddp.png)
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
@@ -38,7 +38,7 @@ The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-
There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here.
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate many texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.
-![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_preprocessing-clm.png)
+![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/chapter10_preprocessing-clm.png)
With this approach the training is much more efficient, as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that overflow the context, you can also use a classical data loader.
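A minimal sketch of the packing idea (illustrative, not the exact implementation used here):

```python
# Concatenate tokenized texts with an EOS token in between, then slice the stream
# into chunks of the model's context size; the trailing remainder is dropped.
def pack(tokenized_texts, eos_token_id, chunk_size):
    stream = []
    for ids in tokenized_texts:
        stream.extend(ids + [eos_token_id])
    return [stream[i : i + chunk_size] for i in range(0, len(stream) - chunk_size + 1, chunk_size)]
```

For example, `pack([[1, 2, 3], [4, 5]], eos_token_id=0, chunk_size=3)` yields `[[1, 2, 3], [0, 4, 5]]`.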
diff --git a/docs/source/xpo_trainer.mdx b/docs/source/xpo_trainer.mdx
index 7516b9218d..07a76f36dc 100644
--- a/docs/source/xpo_trainer.mdx
+++ b/docs/source/xpo_trainer.mdx
@@ -110,7 +110,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
-![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
+![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
diff --git a/examples/notebooks/gpt2-sentiment.ipynb b/examples/notebooks/gpt2-sentiment.ipynb
index 95f625f4f0..a5b6edc821 100644
--- a/examples/notebooks/gpt2-sentiment.ipynb
+++ b/examples/notebooks/gpt2-sentiment.ipynb
@@ -13,7 +13,7 @@
"metadata": {},
"source": [
"
\n",
- "\n",
+ "\n",
"
Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.
\n",
"
\n",
"\n",
diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py
index 011c00554f..633cd53d05 100644
--- a/examples/research_projects/stack_llama/scripts/rl_training.py
+++ b/examples/research_projects/stack_llama/scripts/rl_training.py
@@ -20,9 +20,9 @@
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
-from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline
+from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler
diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
index d3998c7882..2684d31680 100644
--- a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
+++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
@@ -25,9 +25,10 @@
HfArgumentParser,
RobertaForSequenceClassification,
RobertaTokenizer,
+ set_seed,
)
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
from trl.core import LengthSampler
diff --git a/tests/test_core.py b/tests/test_core.py
index 88ecf38fad..16a6284753 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -16,7 +16,7 @@
import torch
-from trl.core import masked_mean, masked_var, masked_whiten, whiten
+from trl.core import masked_mean, masked_var, masked_whiten
class CoreTester(unittest.TestCase):
@@ -36,6 +36,10 @@ def test_masked_var(self):
self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
def test_masked_whiten(self):
+ def whiten(values: torch.Tensor) -> torch.Tensor:
+ mean, var = torch.mean(values), torch.var(values)
+ return (values - mean) * torch.rsqrt(var + 1e-8)
+
whiten_unmasked = whiten(self.test_input_unmasked)
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
diff --git a/trl/__init__.py b/trl/__init__.py
index 2df981580c..5598230781 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -21,7 +21,6 @@
_import_structure = {
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
- "core": ["set_seed"],
"data_utils": [
"apply_chat_template",
"extract_prompt",
@@ -115,7 +114,6 @@
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
if TYPE_CHECKING:
- from .core import set_seed
from .data_utils import (
apply_chat_template,
extract_prompt,
diff --git a/trl/core.py b/trl/core.py
index f62c8bc414..776ed5bdce 100644
--- a/trl/core.py
+++ b/trl/core.py
@@ -13,62 +13,14 @@
# limitations under the License.
import gc
-import random
import warnings
+from collections.abc import Mapping
from contextlib import contextmanager
from typing import Optional, Union
import numpy as np
import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
-from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available
-
-
-try:
- from collections.abc import Mapping
-except ImportError:
- from collections.abc import Mapping
-
-
-WANDB_PADDING = -1
-
-
-def top_k_top_p_filtering(
- logits: torch.FloatTensor,
- top_k: int = 0,
- top_p: float = 1.0,
- filter_value: float = -float("Inf"),
- min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
- """
- Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
-
- Args:
- logits: logits distribution shape (batch size, vocabulary size)
- top_k (`int`, *optional*, defaults to 0):
- If > 0, only keep the top k tokens with highest probability (top-k filtering)
- top_p (`float`, *optional*, defaults to 1.0):
- If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
- filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751)
- min_tokens_to_keep (`int`, *optional*, defaults to 1):
- Minimumber of tokens we keep per batch example in the output.
-
- From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
- """
-
- if top_k > 0:
- logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
- None, logits
- )
-
- if 0 <= top_p <= 1.0:
- logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
- None, logits
- )
-
- return logits
+from transformers import is_torch_npu_available, is_torch_xpu_available
def flatten_dict(nested: dict, sep: str = "/") -> dict:
@@ -88,52 +40,6 @@ def recurse(nest: dict, prefix: str, into: dict) -> None:
return flat
-def convert_to_scalar(stats: dict) -> dict:
- """
- Converts the stats from a flattened dict to single scalar dicts
- """
- tensorboard_stats = {}
- for k, v in stats.items():
- # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
- # therefore we convert single element tensors to scalars
- if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
- len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
- ):
- v = v.item()
- tensorboard_stats[k] = v
- return tensorboard_stats
-
-
-def stack_dicts(stats_dicts: list[dict]) -> dict:
- """Stack the values of a dict."""
- results = dict()
- for k in stats_dicts[0]:
- stats_list = [torch.flatten(d[k]) for d in stats_dicts]
- results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
- return results
-
-
-def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
- """
- See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
- """
- logp = F.log_softmax(logits, dim=2)
-
- if not gather:
- return logp
- logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
- return logpy
-
-
-def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
- """Whiten values."""
- mean, var = torch.mean(values), torch.var(values)
- whitened = (values - mean) * torch.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
-
-
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
@@ -170,73 +76,6 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
return whitened
-def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
- """
- Tensor extension to torch.clamp
- https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
- """
- clipped = torch.max(torch.min(x, tensor_max), tensor_min)
- return clipped
-
-
-def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
- """Calculate entropy from logits."""
- pd = torch.nn.functional.softmax(logits, dim=-1)
- entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
- return entropy
-
-
-def stats_to_np(stats_dict: dict) -> dict:
- """Cast all torch.tensors in dict to numpy arrays."""
- new_dict = dict()
- for k, v in stats_dict.items():
- if isinstance(v, torch.Tensor):
- new_dict[k] = v.detach().cpu()
- if new_dict[k].dtype == torch.bfloat16:
- new_dict[k] = new_dict[k].float()
- new_dict[k] = new_dict[k].numpy()
- else:
- new_dict[k] = v
- if np.isscalar(new_dict[k]):
- new_dict[k] = float(new_dict[k])
- return new_dict
-
-
-def respond_to_batch(
- model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
-) -> torch.LongTensor:
- """Sample text from language model."""
- input_ids = queries
- for _i in range(txt_len):
- # Get Logits
- outputs = model(input_ids)
- next_token_logits = outputs[0][:, -1, :]
- next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
- # Sample
- probs = F.softmax(next_token_logits, dim=-1)
- next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
- input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
- return input_ids[:, -txt_len:]
-
-
-def set_seed(seed: int) -> None:
- """
- Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
-
- Args:
- seed (`int`): The seed to set.
- """
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- if is_torch_xpu_available():
- torch.xpu.manual_seed_all(seed)
- elif is_torch_npu_available():
- torch.npu.manual_seed_all(seed)
- else:
- torch.cuda.manual_seed_all(seed)
-
-
class LengthSampler:
"""
Samples a length
diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py
index c0f02152c2..cf7b43ff25 100644
--- a/trl/extras/best_of_n_sampler.py
+++ b/trl/extras/best_of_n_sampler.py
@@ -15,9 +15,8 @@
from typing import Any, Callable, Optional, Union
import torch
-from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
-from ..core import set_seed
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 85a2e4d57c..9759609f56 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# There is a circular import in the PPOTrainer if we let isort sort these
from typing import TYPE_CHECKING
from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
@@ -21,7 +20,6 @@
_import_structure = {
"alignprop_config": ["AlignPropConfig"],
"alignprop_trainer": ["AlignPropTrainer"],
- "base": ["BaseTrainer"],
"bco_config": ["BCOConfig"],
"bco_trainer": ["BCOTrainer"],
"callbacks": [
@@ -41,8 +39,8 @@
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
- "BaseJudge",
"BaseBinaryJudge",
+ "BaseJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"HfPairwiseJudge",
@@ -60,23 +58,21 @@
"orpo_trainer": ["ORPOTrainer"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
- "ppov2_config": ["PPOv2Config"],
- "ppov2_trainer": ["PPOv2Trainer"],
"prm_config": ["PRMConfig"],
"prm_trainer": ["PRMTrainer"],
"reward_config": ["RewardConfig"],
- "reward_trainer": ["RewardTrainer", "compute_accuracy"],
+ "reward_trainer": ["RewardTrainer"],
"rloo_config": ["RLOOConfig"],
"rloo_trainer": ["RLOOTrainer"],
"sft_config": ["SFTConfig"],
"sft_trainer": ["SFTTrainer"],
"utils": [
- "AdaptiveKLController",
"ConstantLengthDataset",
"DataCollatorForCompletionOnlyLM",
- "FixedKLController",
"RunningMoments",
+ "compute_accuracy",
"disable_dropout_in_model",
+ "empty_cache",
"peft_module_casting_to_bf16",
],
"xpo_config": ["XPOConfig"],
@@ -93,7 +89,6 @@
if TYPE_CHECKING:
from .alignprop_config import AlignPropConfig
from .alignprop_trainer import AlignPropTrainer
- from .base import BaseTrainer
from .bco_config import BCOConfig
from .bco_trainer import BCOTrainer
from .callbacks import (
@@ -135,17 +130,16 @@
from .prm_config import PRMConfig
from .prm_trainer import PRMTrainer
from .reward_config import RewardConfig
- from .reward_trainer import RewardTrainer, compute_accuracy
+ from .reward_trainer import RewardTrainer
from .rloo_config import RLOOConfig
from .rloo_trainer import RLOOTrainer
from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
from .utils import (
- AdaptiveKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
- FixedKLController,
RunningMoments,
+ compute_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,
diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py
index c9cd718c44..7c5018892e 100644
--- a/trl/trainer/alignprop_trainer.py
+++ b/trl/trainer/alignprop_trainer.py
@@ -22,10 +22,11 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import PyTorchModelHubMixin
from transformers import is_wandb_available
from ..models import DDPOStableDiffusionPipeline
-from . import AlignPropConfig, BaseTrainer
+from .alignprop_config import AlignPropConfig
from .utils import generate_model_card, get_comet_experiment_url
@@ -35,7 +36,7 @@
logger = get_logger(__name__)
-class AlignPropTrainer(BaseTrainer):
+class AlignPropTrainer(PyTorchModelHubMixin):
"""
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
diff --git a/trl/trainer/base.py b/trl/trainer/base.py
deleted file mode 100644
index 7730e6af9a..0000000000
--- a/trl/trainer/base.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from huggingface_hub import PyTorchModelHubMixin
-
-
-class BaseTrainer(PyTorchModelHubMixin):
- r"""
- Base class for all trainers - this base class implements the basic functions that we
- need for a trainer.
-
- The trainer needs to have the following functions:
- - step: takes in a batch of data and performs a step of training
- - loss: takes in a batch of data and returns the loss
- - compute_rewards: takes in a batch of data and returns the rewards
- - _build_models_and_tokenizer: builds the models and tokenizer
- - _build_dataset: builds the dataset
- Each user is expected to implement their own trainer class that inherits from this base
- if they want to use a new training algorithm.
- """
-
- def __init__(self, config):
- self.config = config
-
- def step(self, *args):
- raise NotImplementedError("Not implemented")
-
- def loss(self, *args):
- raise NotImplementedError("Not implemented")
-
- def compute_rewards(self, *args):
- raise NotImplementedError("Not implemented")
-
- def _save_pretrained(self, save_directory):
- raise NotImplementedError("Not implemented")
diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
index 01a1d0d5c5..846ab2730e 100644
--- a/trl/trainer/ddpo_trainer.py
+++ b/trl/trainer/ddpo_trainer.py
@@ -23,10 +23,11 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import PyTorchModelHubMixin
from transformers import is_wandb_available
from ..models import DDPOStableDiffusionPipeline
-from . import BaseTrainer, DDPOConfig
+from .ddpo_config import DDPOConfig
from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url
@@ -37,7 +38,7 @@
logger = get_logger(__name__)
-class DDPOTrainer(BaseTrainer):
+class DDPOTrainer(PyTorchModelHubMixin):
"""
The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index f94522923b..50392526db 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -774,18 +774,12 @@ def cross_entropy_loss(logits, labels):
loss = loss_fct(logits, labels)
return loss
- if self.is_encoder_decoder:
- labels = concatenated_batch["concatenated_labels"].clone()
- else:
- labels = concatenated_batch["concatenated_input_ids"].clone()
- attention_mask = concatenated_batch["concatenated_attention_mask"]
- labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
-
+ labels = concatenated_batch["concatenated_labels"].clone()
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
all_logps = self.get_batch_logps(
all_logits,
- concatenated_batch["concatenated_labels"],
+ labels,
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
@@ -794,8 +788,12 @@ def cross_entropy_loss(logits, labels):
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
- chosen_logits = all_logits[:len_chosen]
- rejected_logits = all_logits[len_chosen:]
+ if not self.is_encoder_decoder:
+ chosen_logits = all_logits[:len_chosen, :-1, :]
+ rejected_logits = all_logits[len_chosen:, :-1, :]
+ else:
+ chosen_logits = all_logits[:len_chosen]
+ rejected_logits = all_logits[len_chosen:]
if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 9d75d6138e..6d8ed96aa6 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -51,7 +51,6 @@
is_torch_xpu_available,
)
-from ..import_utils import is_unsloth_available
from ..trainer.model_config import ModelConfig
@@ -62,34 +61,6 @@
from peft import LoraConfig, PeftConfig
-class AdaptiveKLController:
- """
- Adaptive KL controller described in the paper:
- https://huggingface.co/papers/1909.08593
- """
-
- def __init__(self, init_kl_coef, target, horizon):
- self.value = init_kl_coef
- self.target = target
- self.horizon = horizon
-
- def update(self, current, n_steps):
- target = self.target
- proportional_error = np.clip(current / target - 1, -0.2, 0.2)
- mult = 1 + proportional_error * n_steps / self.horizon
- self.value *= mult
-
-
-class FixedKLController:
- """Fixed KL controller."""
-
- def __init__(self, kl_coef):
- self.value = kl_coef
-
- def update(self, current, n_steps):
- pass
-
-
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
@@ -878,24 +849,6 @@ def peft_module_casting_to_bf16(model):
module = module.to(torch.bfloat16)
-def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
- if is_unsloth_available():
- # Unsloth adds a new attribute in the model config `unsloth_version`
- # to keep track of models that have been patched with unsloth.
- if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
- tag_names.append("unsloth")
-
- if kwargs is not None:
- if "tags" not in kwargs:
- kwargs["tags"] = tag_names
- elif "tags" in kwargs and isinstance(kwargs["tags"], list):
- kwargs["tags"].extend(tag_names)
- elif "tags" in kwargs and isinstance(kwargs["tags"], str):
- tag_names.append(kwargs["tags"])
- kwargs["tags"] = tag_names
- return kwargs
-
-
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(