From 69b02b06ea002c1235651b1c0c412fbceeb0857b Mon Sep 17 00:00:00 2001 From: Max <56548574+maxreciprocate@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:08:03 +0300 Subject: [PATCH] docs: update documentation (#557) * docs(*): update documentation * docs: update trainers & pipelines * style: satisfy flake * style: satisfy isort * docs: add github workflow * fix(base_trainer): revert deleting `logit_mask` * docs: add torchtyping to docs requirements * docs(workflow): force installation of all dependencies * docs(workflow): relax dependencies for python3.8 * docs: change project slug to stage * docs(workflow): set message template * docs: add installation, api & examples sections * style: satisfy black * feat(docs): add NeMo installation instructions and append w&b links * chore(readthedocs): update build config * chore(readthedocs): remove depreciated option * style: satisfy black * chore(readthedocs): unpin dependencies for python3.8 * chore(readthedocs): remove github workflow * feat(docs): add distributed instructions * fix(readthedocs): install torch separately * fix(readthedocs): install torch inside docs requirements * style: satisfy black --- .readthedocs.yml | 13 ++- docs/requirements.txt | 9 +- docs/source/api.rst | 48 ++++++++ docs/source/configs.rst | 21 ++-- docs/source/data.rst | 37 +++---- docs/source/examples.rst | 140 +++++++++++++++++++++--- docs/source/index.rst | 25 ++--- docs/source/installation.rst | 56 ++++++++++ docs/source/pipeline.rst | 28 ----- docs/source/pipelines.rst | 32 ++++++ docs/source/trainer.rst | 25 ----- docs/source/trainers.rst | 37 +++++++ trlx/data/ilql_types.py | 72 ++++++++---- trlx/models/modeling_ilql.py | 82 +++++++++++++- trlx/pipeline/offline_pipeline.py | 12 +- trlx/trainer/__init__.py | 49 +-------- trlx/trainer/accelerate_base_trainer.py | 25 ++--- trlx/trainer/accelerate_ppo_trainer.py | 29 ++--- trlx/trlx.py | 49 ++++++--- 19 files changed, 552 insertions(+), 237 deletions(-) create mode 100644 docs/source/api.rst create mode 100644 docs/source/installation.rst delete mode 100644 docs/source/pipeline.rst create mode 100644 docs/source/pipelines.rst delete mode 100644 docs/source/trainer.rst create mode 100644 docs/source/trainers.rst diff --git a/.readthedocs.yml b/.readthedocs.yml index c8f03ab0a..6bfd60692 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,7 +3,16 @@ version: 2 sphinx: configuration: docs/source/conf.py +build: + os: ubuntu-22.04 + tools: + python: "3.8" + nodejs: "18" + rust: "1.64" + golang: "1.19" + python: - version: 3.9 install: - - requirements: docs/requirements.txt + - requirements: docs/requirements.txt + - method: pip + path: . diff --git a/docs/requirements.txt b/docs/requirements.txt index 7a33f300e..2470e301f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,11 +1,4 @@ -accelerate==0.12.0 -datasets==2.4.0 -deepspeed==0.7.3 -einops==0.4.1 -numpy==1.23.2 sphinx==4.0.0 sphinx_rtd_theme +torch torchtyping -tqdm==4.64.0 -transformers==4.21.2 -wandb==0.13.2 diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 000000000..f38797dd8 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,48 @@ +.. _api: + +API +=== + +trlX uses a single entrypoint for training, which will execute training conditioned on the passed config and the necessary arguments for a specific training routine. For the online training `prompts` (a list of strings to prompt the training model) and `reward_fn` (a function which gives reward for model outputs sampled from `prompts`) are necessary, while for offline training `samples` (a list of environment/model interactions) and `rewards` (precomputed scores for each interaction) are required. + +Training +-------- + +.. autofunction:: trlx.train + +Distributed +----------- + +Accelerate +^^^^^^^^^^ + +To launch distributed training with Accelerate, first you have to specify the training configuration. You only have to execute this command once per each training node. + +.. code-block:: console + + $ accelerate config + $ accelerate launch examples/ppo_sentiments.py + +You can also use configs provided in `trlX repository `_): + +.. code-block:: console + + $ accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py + + +NVIDIA NeMo +^^^^^^^^^^^ + +For training with NeMo you have to use a model stored in the NeMo format. You can convert an existing llama model with the following script: + +.. code-block:: console + + $ python examples/llama_nemo/convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b + +To start training you have to execute python script per each GPU, or launch the following sbatch script which has `-ntasks-per-node=8` + +.. code-block:: console + + $ sbatch examples/llama_nemo/dist_train.sh + +Run example: `wandb `_ diff --git a/docs/source/configs.rst b/docs/source/configs.rst index da5e1f2e6..1d84a92db 100644 --- a/docs/source/configs.rst +++ b/docs/source/configs.rst @@ -3,21 +3,26 @@ Configs ************************ -Training a model in TRL will require you to set several configs: -ModelConfig, which contains general info on the model being trained. TrainConfig, which contains things like -training hyperparameters. And finally, MethodConfig, which contains hyperparameters or settings for -the specific method being used (i.e. ILQL or PPO) - +Training requires configuration to be passed through a set of configs: `TrainConfig` with training configuration, `ModelConfig`, `TokenizerConfig`, `OptimizerConfig`, `SchedulerConfig` and a `MethodConfig` for a specific configuration of a particular algorithm (PPO, ILQL or SFT) **General** .. autoclass:: trlx.data.configs.TRLConfig :members: +.. autoclass:: trlx.data.configs.TrainConfig + :members: + .. autoclass:: trlx.data.configs.ModelConfig :members: -.. autoclass:: trlx.data.configs.TrainConfig +.. autoclass:: trlx.data.configs.TokenizerConfig + :members: + +.. autoclass:: trlx.data.configs.OptimizerConfig + :members: + +.. autoclass:: trlx.data.configs.SchedulerConfig :members: .. autoclass:: trlx.data.method_configs.MethodConfig @@ -25,10 +30,10 @@ the specific method being used (i.e. ILQL or PPO) **PPO** -.. autoclass:: trlx.data.method_configs.PPOConfig +.. autoclass:: trlx.models.modeling_ppo.PPOConfig :members: **ILQL** -.. autoclass:: trlx.data.method_configs.ILQLConfig +.. autoclass:: trlx.models.modeling_ilql.ILQLConfig :members: diff --git a/docs/source/data.rst b/docs/source/data.rst index 412e442ba..bb71da8f8 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -1,41 +1,36 @@ .. _data: -Data Elements -************************ +Data Classes +============ -All of the major Carper projects: trlX, CHEESE, and magiCARP use -dataclasses corresponding to batches of data to communicate data between models and different -components. trlX is no different, though it has many different dataclasses for -different components like training or inference. Currently, we support PPO and ILQL, which -each demand different kinds of data during training. +Data Elements contain the necessary information for each individual training sample. +PPO Data Classes +---------------- -**Basic Data Elements for Accelerate** - -.. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement +.. autoclass:: trlx.data.ppo_types.PPORLElement :members: -.. autoclass:: trlx.data.accelerate_base_datatypes.PromptBatch +.. autoclass:: trlx.data.ppo_types.PPORLBatch :members: -.. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLElement - :members: +ILQL Data Classes +----------------- -.. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement +.. autoclass:: trlx.data.ilql_types.ILQLElement :members: -**Data Elements for PPO** - -.. autoclass:: trlx.data.ppo_types.PPORLElement +.. autoclass:: trlx.models.modeling_ilql.CausalILQLOutput :members: -.. autoclass:: trlx.data.ppo_types.PPORLBatch +.. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqElement :members: -**Data Elements for ILQL** - -.. autoclass:: trlx.data.ilql_types.ILQLElement +.. autoclass:: trlx.models.modeling_ilql.Seq2SeqILQLOutput :members: .. autoclass:: trlx.data.ilql_types.ILQLBatch :members: + +.. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqBatch + :members: diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 6f5db49d1..01ec1db66 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -1,18 +1,128 @@ .. _examples: Examples -************************ - -In the ``examples`` folder you can find several example training tasks. Check -the configs folder for the associated configs files. ``examples.randomwalks`` -does offline reinforcement on a set of graph random walks to stitch shortest -paths to some destination. ``examples.simulacra`` optimizes prompts by using -prompts-ratings dataset (https://github.com/JD-P/simulacra-aesthetic-captions). -``examples.architext`` tries to optimize designs represented textually by -minimazing number of rooms (pretrained model is under a license on hf). -``examples.ilql_sentiments`` and ``examples.ppo_sentiments`` train to generate -movie reviews with a positive sentiment, in offline setting – by fitting to IMDB -dataset sentiment scores, and in online setting – by sampling finetuned on IMDB -model and rating samples with learned sentiment reward model, You can tweak -these scripts to your liking and tune hyperparameters to your problem if you -wish to use trlx for some custom task. +======== + +Random Walks +------------ + +This is a simple toy example described in `Decision Transformer +(Lili Chen et al. 2021) `_. It's simple enough that it can be used for testing with a 1M sized LLM, training of which can complete entirely on CPU. + +Description +^^^^^^^^^^^ + +The task is to find the shortest path on a directed graph. The reward is based +on how optimal the path is compared to the shortest possible. Paths are +represented as strings of letters, where each letter corresponds to a node in +the graph. + +Training +^^^^^^^^ + +For `PPO Training +`_, +a language model continually samples paths in a graph and directly optimizes for +their shortness using surrogate reward function. For `ILQL Training +`_ +a language model learns directly from a set of 1000 pre-sampled randomwalks in a +graph paired with their relative lengths' shortness. + +W&B runs: + +- PPO https://wandb.ai/sorry/trlx-references/runs/sf8ept0l +- ILQL https://wandb.ai/sorry/trlx-references/runs/g44npaoq + +Positive Sentiment +------------------ + +Description +^^^^^^^^^^^ +The task is to optimize a language model to generate positive sentiment responses for a given prompt. + +Training +^^^^^^^^ + +The training is done by using `PPO trainer +`_ to +maximize a score from pre-trained sentiment classifier trained on IMDB review +sentiments `dataset `_ . For `ILQL Training +`_ the +model is trained directly on the dataset and its labels: `0` for a negative +review and `1` for a positive one. For `SFT Training +`_ the +model is trained only on the positive reviews. + +W&B runs: + +- PPO: https://wandb.ai/sorry/trlx-references/runs/9ohlfd3s +- ILQL: https://wandb.ai/sorry/trlx-references/runs/tplhaji6 +- SFT: https://wandb.ai/sorry/trlx-references/runs/vfxfv081 + +Helpful & Harmless +------------------- + +Description +^^^^^^^^^^^ + +The task is to improve both helpfulness and harmlessness of the +model's outputs following Anthropic's paper `Training a Helpful and Harmless +Assistant with Reinforcement Learning from Human Feedback +`_ + +Training +^^^^^^^^ + +The training is done by either utilizing a reward model trained on the +Anthropic's Helpful & Harmless `dataset +`_ using `PPO trainer +`_, or by +using the dataset directly by reward labeling each selected and rejected with +`+1` and `-1` respectively using `ILQL trainer +`_, or using +`SFT trainer +`_ and +finetuning only over selected responses. + +The setup used for this example assumes a single machine with 8xA100 80GB, the +last of which will be dedicated to hosting a reward model. Optionally you can +use `Triton Inference Server `_ to +host it elsewhere, otherwise the training script will instantiate it (`a +pretrained one `_) on its own. + +Launch training of `GPT-J `_ on 7 +GPUs with 8th GPU hosting a reward model: + +.. code-block:: console + + accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py + # or for training from other predefined checkpoint + CONFIG_NAME=125M accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py + +Optional steps to setup a reward model using Triton Server: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: console + + # convert the model and create a config and a folder `model_store` structured for Triton + python to_triton.py --base_model EleutherAI/gpt-j-6B --checkpoint Dahoas/gptj-rm-static --revision 676bfd4d + + # convert the docker image (skip this if you use docker instead) + singularity build --sandbox tritonserver-pyt.sif docker://nvcr.io/nvidia/tritonserver:22.08-pyt-python-py3 + + # start Triton Server pointing to the `model_store` containing the reward model + SINGULARITYENV_CUDA_VISIBLE_DEVICES=7 singularity run --nv --bind model_store:/model_store tritonserver-pyt.sif tritonserver --model-repository=/model_store & + +Launch training: + +.. code-block:: console + + # set model's url and replace the name after the slash if you use a different checkpoint + export TRITON_HOST=localhost:8001/gptj-rm-static + accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py + +W&B runs: + +- PPO GPT-J: https://wandb.ai/sorry/trlx/runs/v0bir5s9 +- ILQL GPT-J: https://wandb.ai/sorry/trlx/runs/1qqxp72a +- SFT GPT-J: https://wandb.ai/sorry/trlx/runs/a7ng078v diff --git a/docs/source/index.rst b/docs/source/index.rst index 1b2947593..04afbf272 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,26 +1,15 @@ -.. trlX documentation master file, created by - sphinx-quickstart on Mon Oct 3 21:21:33 2022. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - Welcome to trlX's documentation! ================================ -trlX is a library made for training large language models using reinforcement learning. It -currently supports training using PPO or ILQL for models up to 20B using Accelerate. +trlX is a library for training large language models with reinforcement learning. Training can be done with two RL algorithms: PPO (`Schulman et al. 2017 `_) for online training and ILQL (`Snell et al. 2022 `_) for offline training. For distributed training two backends are supported: `Huggingface 🤗 Accelerate `_ and `NVIDIA NeMo `_. .. toctree:: :maxdepth: 2 :caption: Contents: - data - models - configs - pipeline + installation + api examples - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` + configs + trainers + pipelines + data diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 000000000..29e05b5b3 --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,56 @@ +.. _installation: + +Installation +============ + +trlX is a pure Python library that supports two optional distributed backends: `Huggingface 🤗 Accelerate `_ and `NVIDIA NeMo `_, the latter is optional and can be installed separately. + +Requirements +------------ + +* OS: Linux +* Python: 3.9-3.11 + +Install with pip +---------------- + +You can install trlX using pip: + +.. code-block:: console + + $ pip install -U git+https://github.com/CarperAI/trlx.git + +.. _build_from_source: + +Install from source +------------------- + +You can also install trlX from source: + +.. code-block:: console + + $ git clone https://github.com/CarperAI/trlx.git + $ cd trlx + $ pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 + $ pip install -e . + +Install NeMo +____________ + +Install NeMo version v1.17.0: + +.. code-block:: console + + $ git clone https://github.com/NVIDIA/NeMo/ + $ cd NeMo + $ git checkout d3017e4 + $ pip install -e '.[all]' + +Install Apex: + +.. code-block:: console + + $ git clone https://github.com/NVIDIA/apex + $ cd apex + $ # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... + $ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 68279d889..000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,28 +0,0 @@ -.. _pipeline: - -Pipelines -************************ - -Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created -for them. It is these experiences in their rollout store that they are trained on. - -**General** - -.. autoclass:: trlx.pipeline.BasePipeline - :members: - -.. autoclass:: trlx.pipeline.BaseRolloutStore - :members: - -**PPO** - -.. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage - :members: - -**ILQL** - -.. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline - :members: - -.. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage - :members: diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst new file mode 100644 index 000000000..da0e21a39 --- /dev/null +++ b/docs/source/pipelines.rst @@ -0,0 +1,32 @@ +.. _pipeline: + +Pipelines +========= + +Pipelines are used for accumulation and convertion of the training data to appropriate format. + +.. autoclass:: trlx.pipeline.BasePipeline + :members: + +.. autoclass:: trlx.pipeline.BaseRolloutStore + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.DialogMessage + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.DialogStore + :members: + +.. autofunction:: trlx.pipeline.offline_pipeline.tokenize_dialogue + +.. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage + :members: + +.. autoclass:: trlx.pipeline.offline_pipeline.ILQLSeq2SeqRolloutStorage + :members: diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst deleted file mode 100644 index 6259c8b21..000000000 --- a/docs/source/trainer.rst +++ /dev/null @@ -1,25 +0,0 @@ -.. _trainers: - -RL Trainers -******************* - -RL Trainers are what you're training with trlX. Currently, we support PPO and ILQL. -Note that new trainers must be registered with ``trlx.trainer.register_trainer``. - -**General** - -.. autoclass:: trlx.trainer.BaseRLTrainer - :members: - -.. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer - :members: - -**PPO** - -.. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer - :members: - -**ILQL** - -.. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer - :members: diff --git a/docs/source/trainers.rst b/docs/source/trainers.rst new file mode 100644 index 000000000..7f45b3b40 --- /dev/null +++ b/docs/source/trainers.rst @@ -0,0 +1,37 @@ +.. _trainers: + +Trainers +======== + +Abstract Trainers +----------------- + +.. autoclass:: trlx.trainer.BaseRLTrainer + :members: + +.. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer + :members: + +Accelerate Trainers +------------------- + +.. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer + :members: + +.. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer + :members: + +.. autoclass:: trlx.trainer.accelerate_sft_trainer.AccelerateSFTTrainer + :members: + +NeMo Trainers +------------- + +.. autoclass:: trlx.trainer.nemo_ppo_trainer.NeMoPPOTrainer + :members: + +.. autoclass:: trlx.trainer.nemo_ilql_trainer.NeMoILQLTrainer + :members: + +.. autoclass:: trlx.trainer.nemo_sft_trainer.NeMoSFTTrainer + :members: diff --git a/trlx/data/ilql_types.py b/trlx/data/ilql_types.py index cb83309d3..9d75249e9 100644 --- a/trlx/data/ilql_types.py +++ b/trlx/data/ilql_types.py @@ -1,33 +1,30 @@ -from dataclasses import dataclass, fields +from dataclasses import dataclass from torchtyping import TensorType # type: ignore -def flatten_dataclass(cls: type): - """Return a function that flattens a dataclass into a list""" - cls_fields = [f.name for f in fields(cls)] - return lambda x: [getattr(x, f) for f in cls_fields] - - -def unflatten_dataclass(cls: type): - """Return a function that unflattens a list into a dataclass""" - cls_fields = [f.name for f in fields(cls)] - return lambda x: cls(**dict(zip(cls_fields, x))) - - @dataclass class ILQLElement: """ - Data element for ILQL + A single data item for ILQL training - :param input_ids: Input tokens. Should be a long tensor. + :param input_ids: Long tensor of input tokens. :type input_ids: torch.Tensor - :param attention_mask: Attention mask. Should be a long tensor. + :param attention_mask: Attention mask for input tokens. Should be a long tensor. :type attention_mask: torch.Tensor - :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. + :param rewards: Rewards for each input token. :type rewards: torch.Tensor + + :param states_ixs: Indices of states (user input or environment input for example) in the `input_ids`. + :type states_ixs: torch.Tensor + + :param actions_ixs: Indices of actions (model output) in the `input_ids` tensor. + :type actions_ixs: torch.Tensor + + :param dones: Indicator of for the terminal state (end of episode) in the `input_ids` tensor. + :type dones: torch.Tensor """ input_ids: TensorType["query_size"] @@ -41,16 +38,28 @@ class ILQLElement: @dataclass class ILQLSeq2SeqElement: """ - Data element for ILQL + A single data item for ILQL training - :param input_ids: Input tokens. Should be a long tensor. + :param input_ids: Long tensor of input tokens. :type input_ids: torch.Tensor - :param attention_mask: Attention mask. Should be a long tensor. + :param attention_mask: Attention mask for input tokens. Should be a long tensor. :type attention_mask: torch.Tensor - :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. + :param decoder_input_ids: Long tensor of target input tokens. + :type decoder_input_ids: torch.Tensor + + :param rewards: Rewards for each input token. :type rewards: torch.Tensor + + :param states_ixs: Indices of states (user input or environment input for example) in the `input_ids`. + :type states_ixs: torch.Tensor + + :param actions_ixs: Indices of actions (model output) in the `input_ids` tensor. + :type actions_ixs: torch.Tensor + + :param dones: Indicator of for the terminal state (end of episode) in the `input_ids` tensor. + :type dones: torch.Tensor """ input_ids: TensorType["query_size"] @@ -75,6 +84,15 @@ class ILQLBatch: :param rewards: Batch of rewards for each token in each token batch. :type rewards: torch.Tensor + + :param states_ixs: Batch of indices of states (user input or environment input for example) in the `input_ids`. + :type states_ixs: torch.Tensor + + :param actions_ixs: Batch of indices of actions (model output) in the `input_ids` tensor. + :type actions_ixs: torch.Tensor + + :param dones: Batch of indicators of for the terminal state (end of episode) in the `input_ids` tensor. + :type dones: torch.Tensor """ input_ids: TensorType["batch_size", "query_size"] @@ -96,8 +114,20 @@ class ILQLSeq2SeqBatch: :param attention_mask: Batch of attention masks. :type attention_mask: torch.Tensor + :param decoder_input_ids: Batch of target input tokens. + :type decoder_input_ids: torch.Tensor + :param rewards: Batch of rewards for each token in each token batch. :type rewards: torch.Tensor + + :param states_ixs: Batch of indices of states (user input or environment input for example) in the `input_ids`. + :type states_ixs: torch.Tensor + + :param actions_ixs: Batch of indices of actions (model output) in the `input_ids` tensor. + :type actions_ixs: torch.Tensor + + :param dones: Batch of indicators of for the terminal state (end of episode) in the `input_ids` tensor. + :type dones: torch.Tensor """ input_ids: TensorType["batch_size", "query_size"] diff --git a/trlx/models/modeling_ilql.py b/trlx/models/modeling_ilql.py index 3aa9933ac..e3c0d3f2e 100644 --- a/trlx/models/modeling_ilql.py +++ b/trlx/models/modeling_ilql.py @@ -48,13 +48,46 @@ def batched_index_select( @dataclass @register_method class ILQLConfig(MethodConfig): + """ + Configuration for ILQL method. + + :param tau: Parameter for expectile regression for the value function to q + estimates, \in (0, 1), where tau=0.5 is equivalent to the mean square error + and tau=1 is equivalent to taking a maximum over q estimates + :type tau: float + + :param gamma: Discount factor + :type gamma: float + + :param cql_scale: Scale for the CQL loss (conservative q-learning loss) + :type cql_scale: float + + :param awac_scale: Scale for the AWAC loss (weighted cross-entropy loss) + :type awac_scale: float + + :param alpha: Parameter for Polyak averaging of the target Q-head sync, \in (0, 1) + :type alpha: float + + :param beta: Parameter for magnitude of weighting effect in the AWAC loss, \in (0, 1) + :type beta: float + + :param steps_for_target_q_sync: Number of steps between target Q-head syncs + :type steps_for_target_q_sync: int + + :param two_qs: Whether to use two Q-heads and taking minimum of separate estimates or using only one + :type two_qs: bool + + :param gen_kwargs: Keyword arguments for the generation method + :type gen_kwargs: dict + """ + tau: float gamma: float cql_scale: float awac_scale: float alpha: float beta: float - steps_for_target_q_sync: float + steps_for_target_q_sync: int two_qs: bool gen_kwargs: dict @@ -196,6 +229,28 @@ def sync_target_q_heads(self): @dataclass class CausalILQLOutput(ModelOutput): + """ + Output of the causal model with ILQL heads. + + :param logits: Logits of the causal model. + :type logits: torch.FloatTensor + + :param past_key_values: Tuple of past key values of the causal model. + :type past_key_values: Tuple[Tuple[torch.FloatTensor]] + + :param hidden_states: Last hidden state of the causal model. + :type hidden_states: Tuple[torch.FloatTensor] + + :param value: Value function estimation for each token in the input sequence. + :type value: torch.FloatTensor + + :param qs: Q-function estimations for each token in the input sequence. + :type qs: Tuple[torch.FloatTensor] + + :param target_qs: Q-function estimations from the target Q-head for each token in the input sequence. + :type target_qs: Tuple[torch.FloatTensor] + """ + logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -389,6 +444,31 @@ def post_init(self, state_dict): @dataclass class Seq2SeqILQLOutput(ModelOutput): + """ + Output of the seq2seq model with ILQL heads. + + :param logits: Logits of the seq2seq model. + :type logits: torch.FloatTensor + + :param past_key_values: Tuple of past key values of the seq2seq model. + :type past_key_values: Tuple[Tuple[torch.FloatTensor]] + + :param hidden_states: Last hidden state of the seq2seq model. + :type hidden_states: Tuple[torch.FloatTensor] + + :param value: Value function estimation for each token in the input sequence. + :type value: torch.FloatTensor + + :param qs: Q-function estimations for each token in the input sequence. + :type qs: Tuple[torch.FloatTensor] + + :param target_qs: Q-function estimations from the target Q-head for each token in the input sequence. + :type target_qs: Tuple[torch.FloatTensor] + + :param encoder_outputs: Tuple of encoder outputs of the seq2seq model. + :type encoder_outputs: Tuple[Any] + """ + logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index cee900cfc..fa978da00 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -21,6 +21,16 @@ @dataclass class DialogMessage: + """ + Single message in a dialogue + + :param is_output: Whether the message is a model output or a prompt + :type is_output: bool + + :param tokens: Tokenized message + :type tokens: Tuple[int] + """ + is_output: bool tokens: Tuple[int] @@ -241,7 +251,7 @@ def ilql_seq2seq_collate_fn(elems: Iterable[ILQLElement]): class ILQLSeq2SeqRolloutStorage(BaseRolloutStore): """ - Rollout storage for training ILQL + Rollout storage for training ILQL with Seq2Seq models """ def __init__(self, input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones): diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 8e0d239df..ffb42cf7d 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -46,58 +46,19 @@ def __init__( self.config = config self.reward_fn = reward_fn self.metric_fn = metric_fn - self.train_mode = train_mode self.logit_mask = logit_mask + self.train_mode = train_mode self.stop_sequences = stop_sequences def push_to_store(self, data): - self.store.push(data) - - def add_eval_pipeline(self, eval_pipeline): - """Adds pipeline for validation prompts""" - self.eval_pipeline = eval_pipeline - - @abstractmethod - def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]: """ - Sample from the language. Takes prompts and maximum length to generate. - - :param prompts: List of prompts to tokenize and use as context - - :param length: How many new tokens to genrate for each prompt - :type length: int - - :param n_samples: Default behavior is to take number of prompts as this + Append new data to the rollout store """ - pass + self.store.push(data) @abstractmethod - def learn( - self, - log_fn: Callable = None, - save_fn: Callable = None, - eval_fn: Callable = None, - ): + def learn(self): """ - Use experiences in RolloutStore to learn - - :param log_fn: Optional function that is called when logging and passed a dict of logging relevant values - :type log_fn: Callable[Dict[str, any]] - - :param save_fn: Optional function to call after saving. Is passed the components. - :type save_fn: Callable[Dict[str, any]] - - :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this. - :type eval_fn: Callable[BaseRLTrainer] + Use data in the the rollout store to update the model """ pass - - @abstractmethod - def save(self, directory: Optional[str] = None): - """Creates a checkpoint of training states""" - pass - - @abstractmethod - def load(self, directory=None): - """Loads a checkpoint created from `save`""" - pass diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 9dd1f99a3..e58254ef9 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -40,7 +40,7 @@ @register_trainer class AccelerateRLTrainer(BaseRLTrainer): """ - RL model trainer with an `accelerate` based backend + Asbtract Trainer that uses `accelerate` backend """ def __init__(self, config, **kwargs): # noqa: C901 @@ -204,7 +204,7 @@ def decode( append_eos_token: bool = False, ) -> Tuple[List[str], List[str], List[str]]: """ - Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) + Decodes tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) """ if prompt_sizes is None: # Assuming prompts were left-padded @@ -250,7 +250,7 @@ def decode( return str_samples, str_prompts, str_outputs def generate(self, input_ids, attention_mask=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" + """Generate samples for the experience buffer using method's specific `self.generate_experience_kwargs`""" input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) @@ -265,7 +265,7 @@ def generate(self, input_ids, attention_mask=None, **kwargs): ) def generate_eval(self, input_ids, attention_mask=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" + """Generate samples for evaluation using `self.generate_kwargs`""" input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) @@ -278,8 +278,7 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs): ) def save_pretrained(self, directory: Optional[str] = None, **kwargs): - """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for - later use. + """Save the underlying model, tokenizer, and configuration files to a directory Args: directory (str, *optional*): The directory to save the trainer files to. @@ -304,7 +303,7 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs): self.tokenizer.save_pretrained(directory) def save(self, directory: Optional[str] = None, **kwargs): - """Creates a checkpoint of the optimizer, scheduler and model""" + """Creates a checkpoint for the optimizer, scheduler and the model""" dst_dir = directory or self.config.train.checkpoint_dir self.accelerator.save_state(dst_dir, **kwargs) @@ -317,7 +316,7 @@ def save(self, directory: Optional[str] = None, **kwargs): self.accelerator.unwrap_model(self.model).save_pretrained(dst_dir) def load(self, directory: Optional[str] = None, **kwargs): - """Load checkpoint of optimizer, scheduler and a model""" + """Loads the checkpoint of the optimizer, scheduler and the model""" if self.config.model.peft_config is not None: def load_state_hook(models: List[torch.nn.Module], input_dir: str): @@ -330,11 +329,11 @@ def load_state_hook(models: List[torch.nn.Module], input_dir: str): self.accelerator.load_state(directory or self.config.train.checkpoint_dir, **kwargs) def add_eval_pipeline(self, eval_pipeline): - """Adds pipeline from with validation prompts""" + """Adds a evalution pipeline with validation prompts""" self.eval_pipeline = eval_pipeline def evaluate(self): # noqa: C901 - """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided""" + """Samples model using `eval_prompts`, computes statistics with `reward_fn` and `metric_fn`""" logger.info("Evaluating model") # Do multiple evaluations over a single list in `gen_kwargs` if present @@ -655,12 +654,12 @@ def create_train_dataloader(self): @abstractmethod def get_arch(self, config: TRLConfig): - """Returns a specific wrapper of the decoder architecture""" + """Returns a specific wrapper given a model's architecture""" pass @abstractmethod def loss(self, batch) -> Tuple[float, Dict]: - """Compute loss on a batch from `store` and return some statistics""" + """Computes loss on a batch of data and returns statistics""" pass @abstractmethod @@ -675,5 +674,5 @@ def post_backward_callback(self): @abstractmethod def post_epoch_callback(self): - """Do something after exhausting/single pass over `self.store`""" + """Do something after a single pass over data from `self.store`""" pass diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 27ed4b5aa..cd0b62ab6 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -2,7 +2,7 @@ import os import uuid from time import time -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch @@ -43,7 +43,8 @@ def __init__(self, config: TRLConfig, **kwargs): """PPO Accelerate Trainer initialization Args: - config: Config + config: `TRLConfig` + kwargs: Additional keyword arguments passed to `AccelerateRLTrainer` """ super().__init__(config, **kwargs) @@ -105,7 +106,7 @@ def __init__(self, config: TRLConfig, **kwargs): self.ref_std = self.config.method.ref_std def get_arch(self, config: TRLConfig): - """Get the model""" + """Returns a specific wrapper given a model's architecture""" model_class = AutoModelForCausalLMWithHydraValueHead if config.model.model_arch_type == "seq2seq": model_class = AutoModelForSeq2SeqLMWithHydraValueHead @@ -122,11 +123,15 @@ def get_arch(self, config: TRLConfig): peft_config=self.config.model.peft_config, ) - def loss(self, batch: PPORLBatch): - """Forward pass & loss + def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]: + """Computes loss on a batch of data and returns statistics Args: - batch: Previous batch of episodes + batch: `PPORLBatch` Previous batch of episodes + + Returns: + loss: `Float` Loss value + stats: `Dict[str, Any]` PPO Statistics values """ # Move `batch` data to `accelerator` device query_tensors = batch.query_tensors.to(self.accelerator.device) @@ -198,7 +203,7 @@ def loss(self, batch: PPORLBatch): return loss, stats def setup_rollout_logging(self, config): - # Make rollout logging dir for this run and store config + """Make rollout logging directory to log rollouts to""" exists = os.path.exists(config.train.rollout_logging_dir) isdir = os.path.isdir(config.train.rollout_logging_dir) assert exists and isdir @@ -211,10 +216,7 @@ def setup_rollout_logging(self, config): f.write(json.dumps(config.to_dict(), indent=2)) def post_epoch_callback(self): - """Post epoch callback - - Clears the store and creates `num_rollouts` new episodes. - """ + """Clears the rollout store and creates `num_rollouts` new samples""" if self.log_rollouts: self.store.export_history(location=self.rollout_logging_dir) self.store.clear_history() @@ -246,15 +248,14 @@ def add_prompt_pipeline(self, pipeline: PromptPipeline): self.prompt_iterator = infinite_dataloader(prompt_dataloader) def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: - """Make experiences - + """ Takes `chunk_size` number of prompts from `prompt_iterator`, samples from the model and then computes the KL against a reference model. Finally it then appends PPOElements to trainer's `store`. Args: num_rollouts: Number of rollouts to generate - iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs) + iter_count: Total number of updates for all batches & epochs """ logger.info("Collecting rollouts") tbar = logging.tqdm( diff --git a/trlx/trlx.py b/trlx/trlx.py index d724a9f24..a11286fc4 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -25,35 +25,48 @@ def train( # noqa: C901 stop_sequences: Optional[List[str]] = [], ): """ - Dispatches online, offline reinforcement training or supervised finetuning - depending on whether a reward function or a list of samples & rewards, or only list of samples is given + Runs online, offline reinforcement training or supervised finetuning depending on provided arguments. + `reward_fn` and `prompts` are required for online training, `samples` and `rewards` are required for offline training. Args: - model_path (Optional[str]): Path to either huggingface checkpoint or a local directory - config (Optional[TRLConfig]): TRLX configuration object - reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]): - Function to rate batches of generated samples. Its arguments are - (`samples`, `prompts`, `outputs`) and the return is a list of `rewards` - dataset (List[Union[str, List[str]]], List[float]): + model_path (`Optional[str]`): + Path to either huggingface hub checkpoint or a local directory. + + config (`Optional[TRLConfig]`): + Training configuration object. + + reward_fn (`Optional[Callable[[List[str], List[str], List[str]], List[float]]]`): + A function to rate batches of generated samples. Its required arguments are + (`samples`, `prompts`, `outputs`) and the return is a list of scalar rewards per each sample in batch + + dataset (`List[Union[str, List[str]]], List[float]`): Lists of samples and rewards for offline training. (Use `samples` and `rewards` instead) - samples (List[Union[str, List[str]]]): + + samples (`List[Union[str, List[str]]]`): List of strings or a list of prompts (questions or environment states) and outputs which are meant to be optimized. In the latter case the following form is expected: (prompt_0: str, output_0: str, prompt_1: str, output_1: str ...). Giving a single string `s` for the sample is a shorthand for (`tokenizer.bos_token`, `s`) - rewards (List[float]): - List of real numbers measuring the goodness of each sample - prompts (`List[str]` or `List[Dict[str, Any]]`): Prompts to use for generations during online training. + + rewards (`List[float]`): + List of scalar rewards per each sample in `samples`. + + prompts (`Union[List[str], List[Dict[str, Any]]]`): + Prompts to use for generations during online training. If a dict is passed as prompt, it must have a required key `"prompt"`, all the extra keys would be passed along the generation for that prompt as a keyword argument to reward function. - eval_prompts (List[str] or `List[Dict[str, Any]]`): Prompts to use for periodical validation of training - metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]): + + eval_prompts (`Union[List[str], List[Dict[str, Any]]]`): + Prompts to use for periodical validation of training. + + metric_fn (`Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]`): Function to compute statistics on batches of generated samples. Its arguments are the same - as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is dictionary with keys - as metric's name and values and lists of numeric values per each sample in batch - stop_sequences (Optional[List[str]]): + as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is a dictionary of mapping from + metric's name to a list of scalar values per each sample in batch. + + stop_sequences (`Optional[List[str]]`): String sequences to trim generations (both for generating of experience and evaluation) up to its - encounter in them. Generations will not contain them and also will also be right-stripped + encounter in them. Generations will not contain them and also will also be right-stripped. """ if config is None: warnings.warn(