Clean configs documentation (#1944)
* Clean BCO

* Optional[int]

* fix sft config

* alignprop config

* update tempfile to work with output_dir

* clean kto config

* intro docstring

* style

* reward config

* orpo config

* warning in trainer, not in config

* cpo config

* ppo v2

* model config

* ddpo and per_device_train_batch_size (instead of train_batch_size)

* rloo

* Online config

* tmp_dir in test_ddpo

* style

* remove to_dict and fix post-init

* batch size in test ddpo

* dpo

* style

* `Args` -> `Parameters`

* parameters

* ppo config

* don't overwrite world size

* style

* output_dir in test ppo

* output dir in ppo config

* revert non-core change (1/n)

* revert non-core changes (2/n)

* revert non-core change (3/n)

* uniform max_length

* fix uniform max_length

* beta uniform

* style

* link to `ConstantLengthDataset`

* uniform `dataset_num_proc`

* uniform `disable_dropout`

* `eval_packing` doc

* try latex and α in doc

* try title first

* doesn't work

* reorganize doc

* overview

* better latex

* is_encoder_decoder uniform

* proper ticks

* fix latex

* uniform generate_during_eval

* uniform truncation_mode

* ref_model_mixup_alpha

* ref_model_mixup_alpha and ref_model_sync_steps

* Uniform `model_init_kwargs` and `ref_model_init_kwargs`

* rpo_alpha

* Update maximum length argument names in config files

* Update loss_type descriptions in config files

* Update max_target_length to max_completion_length in CPOConfig and CPOTrainer

* Update padding value in config files

* Update precompute_ref_log_probs flag documentation

* Fix typos and update comments in dpo_config.py and sft_config.py

* post init warning for `max_target_length`
qgallouedec authored Sep 4, 2024
1 parent 7acb9c2 commit fc20db8
Showing 20 changed files with 851 additions and 636 deletions.
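Two items in the commit message above, "Update max_target_length to max_completion_length in CPOConfig and CPOTrainer" and "post init warning for `max_target_length`", describe a rename backed by a deprecation warning. The snippet below is a minimal sketch of that pattern, not TRL's actual implementation; the class name `ExampleConfig` and the warning text are illustrative only.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    """Illustrative config with a renamed field and a deprecation shim."""

    # New, preferred name for the maximum completion length.
    max_completion_length: Optional[int] = None
    # Deprecated alias kept so that existing scripts keep working for now.
    max_target_length: Optional[int] = None

    def __post_init__(self):
        if self.max_target_length is not None:
            warnings.warn(
                "`max_target_length` is deprecated, use `max_completion_length` instead.",
                FutureWarning,
            )
            if self.max_completion_length is None:
                self.max_completion_length = self.max_target_length
```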
68 changes: 35 additions & 33 deletions docs/source/_toctree.yml
@@ -17,44 +17,46 @@
title: Understanding Logs
title: Get started
- sections:
- sections:
- local: trainer
title: Overview
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: orpo_trainer
title: ORPO
- local: kto_trainer
title: KTO
- local: ppo_trainer
title: PPO
- local: ppov2_trainer
title: PPOv2
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: reward_trainer
title: Reward Model
title: Trainers
- local: models
title: Model Classes
- local: trainer
title: Trainer Classes
- local: reward_trainer
title: Reward Model Training
- local: sft_trainer
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: ppov2_trainer
title: PPOv2 Trainer
- local: rloo_trainer
title: RLOO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: online_dpo_trainer
title: Online DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: bco_trainer
title: BCO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: alignprop_trainer
title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: callbacks
title: Callback Classes
- local: judges
title: Judge Classes
title: Judges
- local: callbacks
title: Callbacks
- local: text_environments
title: Text Environments
title: API
6 changes: 2 additions & 4 deletions tests/test_trainers_args.py
@@ -77,7 +77,6 @@ def test_cpo(self):
max_length=256,
max_prompt_length=64,
max_completion_length=64,
max_target_length=64,
beta=0.5,
label_smoothing=0.5,
loss_type="hinge",
@@ -96,7 +95,6 @@ def test_cpo(self):
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.max_prompt_length, 64)
self.assertEqual(trainer.args.max_completion_length, 64)
self.assertEqual(trainer.args.max_target_length, 64)
self.assertEqual(trainer.args.beta, 0.5)
self.assertEqual(trainer.args.label_smoothing, 0.5)
self.assertEqual(trainer.args.loss_type, "hinge")
@@ -127,7 +125,7 @@ def test_dpo(self):
truncation_mode="keep_start",
max_length=256,
max_prompt_length=64,
max_target_length=64,
max_completion_length=64,
is_encoder_decoder=True,
disable_dropout=False,
# generate_during_eval=True, # ignore this one, it requires wandb
@@ -155,7 +153,7 @@ def test_dpo(self):
self.assertEqual(trainer.args.truncation_mode, "keep_start")
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.max_prompt_length, 64)
self.assertEqual(trainer.args.max_target_length, 64)
self.assertEqual(trainer.args.max_completion_length, 64)
self.assertEqual(trainer.args.is_encoder_decoder, True)
self.assertEqual(trainer.args.disable_dropout, False)
# self.assertEqual(trainer.args.generate_during_eval, True)
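As the updated tests above show, configs such as `DPOConfig` and `CPOConfig` now take `max_completion_length` where `max_target_length` was used before. A minimal usage sketch follows; the argument values mirror the test, and `output_dir` is an arbitrary placeholder:

```python
from trl import DPOConfig

# The completion length is now configured with `max_completion_length`
# instead of the former `max_target_length`.
training_args = DPOConfig(
    output_dir="dpo-output",       # placeholder path
    truncation_mode="keep_start",
    max_length=256,
    max_prompt_length=64,
    max_completion_length=64,
)
```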
124 changes: 78 additions & 46 deletions trl/trainer/alignprop_config.py
@@ -2,85 +2,117 @@
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional, Tuple

from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available


@dataclass
class AlignPropConfig:
"""
Configuration class for AlignPropTrainer
r"""
Configuration class for the [`AlignPropTrainer`].
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
Name of this experiment (defaults to the file name without the extension).
run_name (`str`, *optional*, defaults to `""`):
Name of this run.
log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`):
Log with either `"wandb"` or `"tensorboard"`. Check
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
log_image_freq (`int`, *optional*, defaults to `1`):
Frequency for logging images.
tracker_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the tracker (e.g., `wandb_project`).
accelerator_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator.
project_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
Name of project to use for tracking.
logdir (`str`, *optional*, defaults to `"logs"`):
Top-level logging directory for checkpoint saving.
num_epochs (`int`, *optional*, defaults to `100`):
Number of epochs to train.
save_freq (`int`, *optional*, defaults to `1`):
Number of epochs between saving model checkpoints.
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
Number of checkpoints to keep before overwriting old ones.
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
Mixed precision training.
allow_tf32 (`bool`, *optional*, defaults to `True`):
Allow `tf32` on Ampere GPUs.
resume_from (`str`, *optional*, defaults to `""`):
Path to resume training from a checkpoint.
sample_num_steps (`int`, *optional*, defaults to `50`):
Number of sampler inference steps.
sample_eta (`float`, *optional*, defaults to `1.0`):
Eta parameter for the DDIM sampler.
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
Classifier-free guidance weight.
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
Learning rate.
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
Beta1 for Adam optimizer.
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
Beta2 for Adam optimizer.
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
Weight decay for Adam optimizer.
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
Epsilon value for Adam optimizer.
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
Number of gradient accumulation steps.
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
Maximum gradient norm for gradient clipping.
negative_prompts (`Optional[str]`, *optional*, defaults to `None`):
Comma-separated list of prompts to use as negative examples.
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
If `True`, randomized truncation to different diffusion timesteps is used.
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
truncated_rand_backprop_minmax (`Tuple[int, int]`, *optional*, defaults to `(0, 50)`):
Range of diffusion timesteps for randomized truncated backpropagation.
"""

# common parameters
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
"""the name of this experiment (by default is the file name without the extension name)"""
run_name: Optional[str] = ""
"""Run name for wandb logging and checkpoint saving."""
run_name: str = ""
seed: int = 0
"""Seed value for random generations"""
log_with: Optional[Literal["wandb", "tensorboard"]] = None
"""Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
log_image_freq = 1
"""Logging Frequency for images"""
tracker_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. wandb_project)"""
accelerator_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
log_image_freq: int = 1
tracker_kwargs: Dict[str, Any] = field(default_factory=dict)
accelerator_kwargs: Dict[str, Any] = field(default_factory=dict)
project_kwargs: Dict[str, Any] = field(default_factory=dict)
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
logdir: str = "logs"
"""Top-level logging directory for checkpoint saving."""

# hyperparameters
num_epochs: int = 100
"""Number of epochs to train."""
save_freq: int = 1
"""Number of epochs between saving model checkpoints."""
num_checkpoint_limit: int = 5
"""Number of checkpoints to keep before overwriting old ones."""
mixed_precision: str = "fp16"
"""Mixed precision training."""
allow_tf32: bool = True
"""Allow tf32 on Ampere GPUs."""
resume_from: Optional[str] = ""
"""Resume training from a checkpoint."""
resume_from: str = ""
sample_num_steps: int = 50
"""Number of sampler inference steps."""
sample_eta: float = 1.0
"""Eta parameter for the DDIM sampler."""
sample_guidance_scale: float = 5.0
"""Classifier-free guidance weight."""
train_batch_size: int = 1
"""Batch size (per GPU!) to use for training."""
train_use_8bit_adam: bool = False
"""Whether to use the 8bit Adam optimizer from bitsandbytes."""
train_learning_rate: float = 1e-3
"""Learning rate."""
train_adam_beta1: float = 0.9
"""Adam beta1."""
train_adam_beta2: float = 0.999
"""Adam beta2."""
train_adam_weight_decay: float = 1e-4
"""Adam weight decay."""
train_adam_epsilon: float = 1e-8
"""Adam epsilon."""
train_gradient_accumulation_steps: int = 1
"""Number of gradient accumulation steps."""
train_max_grad_norm: float = 1.0
"""Maximum gradient norm for gradient clipping."""
negative_prompts: Optional[str] = ""
"""Comma-separated list of prompts to use as negative examples."""
negative_prompts: Optional[str] = None
truncated_backprop_rand: bool = True
"""Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps"""
truncated_backprop_timestep: int = 49
"""Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False"""
truncated_rand_backprop_minmax: tuple = (0, 50)
"""Range of diffusion timesteps for randomized truncated backprop."""
truncated_rand_backprop_minmax: Tuple[int, int] = (0, 50)

def to_dict(self):
output_dict = {}
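The new `AlignPropConfig` docstring states that [`~transformers.HfArgumentParser`] can expose these fields as command-line arguments. Below is a minimal sketch of that pattern, assuming a hypothetical `train_alignprop.py` script and overriding only simple scalar options:

```python
# train_alignprop.py
# Example invocation: python train_alignprop.py --num_epochs 200 --mixed_precision bf16
from transformers import HfArgumentParser

from trl import AlignPropConfig

parser = HfArgumentParser(AlignPropConfig)
(alignprop_config,) = parser.parse_args_into_dataclasses()
print(alignprop_config.num_epochs, alignprop_config.mixed_precision)
```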
(Diffs for the remaining changed files are not shown.)
