Commit
* docs(*): update documentation
* docs: update trainers & pipelines
* style: satisfy flake
* style: satisfy isort
* docs: add github workflow
* fix(base_trainer): revert deleting `logit_mask`
* docs: add torchtyping to docs requirements
* docs(workflow): force installation of all dependencies
* docs(workflow): relax dependencies for python3.8
* docs: change project slug to stage
* docs(workflow): set message template
* docs: add installation, api & examples sections
* style: satisfy black
* feat(docs): add NeMo installation instructions and append w&b links
* chore(readthedocs): update build config
* chore(readthedocs): remove deprecated option
* style: satisfy black
* chore(readthedocs): unpin dependencies for python3.8
* chore(readthedocs): remove github workflow
* feat(docs): add distributed instructions
* fix(readthedocs): install torch separately
* fix(readthedocs): install torch inside docs requirements
* style: satisfy black
1 parent db466cb · commit 69b02b0
Showing 19 changed files with 552 additions and 237 deletions.
@@ -1,11 +1,4 @@
accelerate==0.12.0
datasets==2.4.0
deepspeed==0.7.3
einops==0.4.1
numpy==1.23.2
sphinx==4.0.0
sphinx_rtd_theme
torch
torchtyping
tqdm==4.64.0
transformers==4.21.2
wandb==0.13.2
@@ -0,0 +1,48 @@
.. _api:

API
===

trlX uses a single entrypoint for training, which executes the training routine selected by the passed config and the arguments required for that routine. Online training requires `prompts` (a list of strings to prompt the model being trained) and `reward_fn` (a function that scores model outputs sampled from `prompts`), while offline training requires `samples` (a list of environment/model interactions) and `rewards` (precomputed scores for each interaction).

Training
--------

.. autofunction:: trlx.train
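
As a rough sketch of both modes, a call to ``trlx.train`` might look like the following (the model name, prompts, and reward values are illustrative assumptions, not taken from the library's examples):

.. code-block:: python

    import trlx

    # Online training: the model samples continuations of `prompts`,
    # and `reward_fn` returns one scalar score per sampled string.
    trainer = trlx.train(
        "gpt2",
        reward_fn=lambda samples, **kwargs: [float(len(sample)) for sample in samples],
        prompts=["The movie was", "The food tasted"],
    )

    # Offline training: learn from pre-collected samples and their precomputed rewards.
    trainer = trlx.train(
        "gpt2",
        samples=["The movie was great", "The movie was terrible"],
        rewards=[1.0, 0.0],
    )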

Distributed
-----------

Accelerate
^^^^^^^^^^

To launch distributed training with Accelerate, first specify the training configuration; this only has to be done once per training node:

.. code-block:: console

    $ accelerate config
    $ accelerate launch examples/ppo_sentiments.py

You can also use one of the configs provided in the `trlX repository <https://github.com/CarperAI/trlx/tree/main/configs/accelerate>`_:

.. code-block:: console

    $ accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py

NVIDIA NeMo
^^^^^^^^^^^

For training with NeMo you have to use a model stored in the NeMo format. You can convert an existing LLaMA model with the following script:

.. code-block:: console

    $ python examples/llama_nemo/convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b

To start training, run the Python script once per GPU, or launch the following sbatch script, which sets `--ntasks-per-node=8`:

.. code-block:: console

    $ sbatch examples/llama_nemo/dist_train.sh

Run example: `wandb <https://wandb.ai/carperai/trlxnemo/runs/v7592y73?workspace=user-pvduy>`_
@@ -1,41 +1,36 @@
.. _data:

-Data Elements
-************************
+Data Classes
+============

-All of the major Carper projects: trlX, CHEESE, and magiCARP use
-dataclasses corresponding to batches of data to communicate data between models and different
-components. trlX is no different, though it has many different dataclasses for
-different components like training or inference. Currently, we support PPO and ILQL, which
-each demand different kinds of data during training.
+Data Elements contain the necessary information for each individual training sample.

+PPO Data Classes
+----------------

-**Basic Data Elements for Accelerate**

-.. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement
+.. autoclass:: trlx.data.ppo_types.PPORLElement
   :members:

-.. autoclass:: trlx.data.accelerate_base_datatypes.PromptBatch
+.. autoclass:: trlx.data.ppo_types.PPORLBatch
   :members:

-.. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLElement
-   :members:
+ILQL Data Classes
+-----------------

-.. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement
+.. autoclass:: trlx.data.ilql_types.ILQLElement
   :members:

-**Data Elements for PPO**

-.. autoclass:: trlx.data.ppo_types.PPORLElement
+.. autoclass:: trlx.models.modeling_ilql.CausalILQLOutput
   :members:

-.. autoclass:: trlx.data.ppo_types.PPORLBatch
+.. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqElement
   :members:

-**Data Elements for ILQL**

-.. autoclass:: trlx.data.ilql_types.ILQLElement
+.. autoclass:: trlx.models.modeling_ilql.Seq2SeqILQLOutput
   :members:

.. autoclass:: trlx.data.ilql_types.ILQLBatch
   :members:

.. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqBatch
   :members:
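
For orientation, here is a minimal sketch of constructing a single PPO rollout element: the field names follow ``trlx.data.ppo_types.PPORLElement``, while the tensor values are made up for illustration.

.. code-block:: python

    import torch

    from trlx.data.ppo_types import PPORLElement

    # One rollout: a tokenized query, its sampled response, and per-response-token statistics.
    element = PPORLElement(
        query_tensor=torch.tensor([101, 2023, 2003]),     # tokenized prompt
        response_tensor=torch.tensor([1037, 3231, 102]),  # tokenized continuation
        logprobs=torch.tensor([-1.2, -0.8, -0.5]),        # log-probs of the response tokens
        values=torch.tensor([0.1, 0.2, 0.3]),             # value-head estimates per token
        rewards=torch.tensor([0.0, 0.0, 1.0]),            # per-token rewards (e.g. KL penalty plus final score)
    )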
@@ -1,18 +1,128 @@
.. _examples:

Examples
-************************
+========

-In the ``examples`` folder you can find several example training tasks. Check
-the configs folder for the associated configs files. ``examples.randomwalks``
-does offline reinforcement on a set of graph random walks to stitch shortest
-paths to some destination. ``examples.simulacra`` optimizes prompts by using
-prompts-ratings dataset (https://github.com/JD-P/simulacra-aesthetic-captions).
-``examples.architext`` tries to optimize designs represented textually by
-minimazing number of rooms (pretrained model is under a license on hf).
-``examples.ilql_sentiments`` and ``examples.ppo_sentiments`` train to generate
-movie reviews with a positive sentiment, in offline setting – by fitting to IMDB
-dataset sentiment scores, and in online setting – by sampling finetuned on IMDB
-model and rating samples with learned sentiment reward model, You can tweak
-these scripts to your liking and tune hyperparameters to your problem if you
-wish to use trlx for some custom task.

Random Walks
------------

This is a simple toy example described in `Decision Transformer
(Lili Chen et al. 2021) <https://arxiv.org/abs/2106.01345>`_. It is simple enough
to be used for testing with a 1M-parameter LLM, whose training can complete entirely on CPU.

Description
^^^^^^^^^^^

The task is to find the shortest path on a directed graph. The reward is based
on how optimal the path is compared to the shortest possible one. Paths are
represented as strings of letters, where each letter corresponds to a node in
the graph.

Training
^^^^^^^^

For `PPO Training
<https://github.com/CarperAI/trlx/blob/main/examples/randomwalks/ppo_randomwalks.py>`_,
a language model continually samples paths in a graph and directly optimizes
their shortness using a surrogate reward function. For `ILQL Training
<https://github.com/CarperAI/trlx/blob/main/examples/randomwalks/ilql_randomwalks.py>`_,
a language model learns directly from a set of 1000 pre-sampled random walks in
the graph, each paired with a score for the relative shortness of its length.

W&B runs:

- PPO: https://wandb.ai/sorry/trlx-references/runs/sf8ept0l
- ILQL: https://wandb.ai/sorry/trlx-references/runs/g44npaoq
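
For intuition only, a hypothetical reward function for this task might look like the sketch below; this is an illustrative assumption, not the implementation used in ``examples.randomwalks``:

.. code-block:: python

    import networkx as nx

    def make_reward_fn(graph: nx.DiGraph, goal: str):
        """Score path strings such as "acd" by how close they are to the shortest path to `goal`."""
        def reward_fn(samples, **kwargs):
            rewards = []
            for path in samples:
                if not path.endswith(goal):
                    rewards.append(0.0)  # the goal node was never reached
                    continue
                optimal = nx.shortest_path_length(graph, path[0], goal)
                taken = max(len(path) - 1, 1)
                rewards.append(optimal / taken)  # 1.0 for an optimal path, less for detours
            return rewards
        return reward_fn

    # Toy usage on a 4-node graph where the shortest path from "a" to "d" is a -> c -> d.
    graph = nx.DiGraph([("a", "b"), ("b", "c"), ("a", "c"), ("c", "d")])
    reward_fn = make_reward_fn(graph, goal="d")
    print(reward_fn(["acd", "abcd"]))  # [1.0, 0.666...]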

Positive Sentiment
------------------

Description
^^^^^^^^^^^

The task is to optimize a language model to generate responses with positive sentiment for a given prompt.

Training
^^^^^^^^

The training is done by using the `PPO trainer
<https://github.com/CarperAI/trlx/blob/main/examples/ppo_sentiments.py>`_ to
maximize the score given by a sentiment classifier pre-trained on the IMDB review
`dataset <https://huggingface.co/datasets/imdb>`_. For `ILQL Training
<https://github.com/CarperAI/trlx/blob/main/examples/ilql_sentiments.py>`_ the
model is trained directly on the dataset and its labels: `0` for a negative
review and `1` for a positive one. For `SFT Training
<https://github.com/CarperAI/trlx/blob/main/examples/sft_sentiments.py>`_ the
model is trained only on the positive reviews.

W&B runs:

- PPO: https://wandb.ai/sorry/trlx-references/runs/9ohlfd3s
- ILQL: https://wandb.ai/sorry/trlx-references/runs/tplhaji6
- SFT: https://wandb.ai/sorry/trlx-references/runs/vfxfv081
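
A minimal sketch of such a sentiment-based reward function is shown below; the classifier checkpoint ``lvwerra/distilbert-imdb`` and the base model ``lvwerra/gpt2-imdb`` are assumptions chosen for illustration:

.. code-block:: python

    from transformers import pipeline

    import trlx

    # Use the probability of the POSITIVE class from an IMDB sentiment classifier as the reward.
    sentiment_fn = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", top_k=2, truncation=True)

    def reward_fn(samples, **kwargs):
        outputs = sentiment_fn(samples)
        return [next(d["score"] for d in out if d["label"] == "POSITIVE") for out in outputs]

    trainer = trlx.train(
        "lvwerra/gpt2-imdb",
        reward_fn=reward_fn,
        prompts=["I watched this movie and", "The acting in this film was"],
    )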

Helpful & Harmless
------------------

Description
^^^^^^^^^^^

The task is to improve both the helpfulness and the harmlessness of the
model's outputs, following Anthropic's paper `Training a Helpful and Harmless
Assistant with Reinforcement Learning from Human Feedback
<https://arxiv.org/abs/2204.05862>`_.

Training
^^^^^^^^

The training is done either by utilizing a reward model trained on
Anthropic's Helpful & Harmless `dataset
<https://github.com/anthropics/hh-rlhf>`_ with the `PPO trainer
<https://github.com/CarperAI/trlx/blob/main/examples/hh/ppo_hh.py>`_, by
using the dataset directly and labeling each selected and rejected response with
`+1` and `-1` respectively with the `ILQL trainer
<https://github.com/CarperAI/trlx/blob/main/examples/hh/ilql_hh.py>`_, or by using the
`SFT trainer
<https://github.com/CarperAI/trlx/blob/main/examples/hh/sft_hh.py>`_ and
finetuning only on the selected responses.

The setup used for this example assumes a single machine with 8xA100 80GB GPUs, the
last of which is dedicated to hosting a reward model. Optionally you can
use `Triton Inference Server <https://github.com/triton-inference-server>`_ to
host it elsewhere; otherwise the training script will instantiate it (`a
pretrained one <https://huggingface.co/Dahoas/gptj-rm-static>`_) on its own.
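
As a rough sketch, the ILQL-style offline data described above could be assembled as follows; the dataset id ``Anthropic/hh-rlhf`` and its ``chosen``/``rejected`` columns are assumptions for illustration, and the actual example scripts may differ:

.. code-block:: python

    from datasets import load_dataset

    import trlx

    # Label every chosen response +1 and every rejected response -1,
    # then train offline with ILQL on the combined set.
    dataset = load_dataset("Anthropic/hh-rlhf", split="train")
    samples = list(dataset["chosen"]) + list(dataset["rejected"])
    rewards = [1.0] * len(dataset["chosen"]) + [-1.0] * len(dataset["rejected"])

    trainer = trlx.train("EleutherAI/gpt-j-6B", samples=samples, rewards=rewards)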

Launch training of `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`_ on 7
GPUs, with the 8th GPU hosting a reward model:

.. code-block:: console

    accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py

    # or for training from another predefined checkpoint
    CONFIG_NAME=125M accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py

Optional steps to set up a reward model using Triton Server:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: console

    # convert the model and create a config and a folder `model_store` structured for Triton
    python to_triton.py --base_model EleutherAI/gpt-j-6B --checkpoint Dahoas/gptj-rm-static --revision 676bfd4d

    # build the Docker image into a Singularity container (skip this if you run Triton with Docker directly)
    singularity build --sandbox tritonserver-pyt.sif docker://nvcr.io/nvidia/tritonserver:22.08-pyt-python-py3

    # start Triton Server pointing to the `model_store` containing the reward model
    SINGULARITYENV_CUDA_VISIBLE_DEVICES=7 singularity run --nv --bind model_store:/model_store tritonserver-pyt.sif tritonserver --model-repository=/model_store &

Launch training:

.. code-block:: console

    # set the model's URL and replace the name after the slash if you use a different checkpoint
    export TRITON_HOST=localhost:8001/gptj-rm-static

    accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py

W&B runs:

- PPO GPT-J: https://wandb.ai/sorry/trlx/runs/v0bir5s9
- ILQL GPT-J: https://wandb.ai/sorry/trlx/runs/1qqxp72a
- SFT GPT-J: https://wandb.ai/sorry/trlx/runs/a7ng078v
@@ -1,26 +1,15 @@
-.. trlX documentation master file, created by
-   sphinx-quickstart on Mon Oct 3 21:21:33 2022.
-   You can adapt this file completely to your liking, but it should at least
-   contain the root `toctree` directive.

Welcome to trlX's documentation!
================================

-trlX is a library made for training large language models using reinforcement learning. It
-currently supports training using PPO or ILQL for models up to 20B using Accelerate.
+trlX is a library for training large language models with reinforcement learning. Training can be done with two RL algorithms: PPO (`Schulman et al. 2017 <https://arxiv.org/abs/1707.06347>`_) for online training and ILQL (`Snell et al. 2022 <https://arxiv.org/abs/2206.11871>`_) for offline training. For distributed training two backends are supported: `Huggingface 🤗 Accelerate <https://github.com/huggingface/accelerate>`_ and `NVIDIA NeMo <https://nvidia.github.io/NeMo>`_.

.. toctree::
   :maxdepth: 2
   :caption: Contents:

-   data
-   models
-   configs
-   pipeline
+   installation
+   api
+   examples
+   configs
+   trainers
+   pipelines
+   data

-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`