diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index e7abe4bb6e..aa14297202 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 
 from transformers import TrainingArguments
 
@@ -67,7 +67,7 @@ class DPOConfig(TrainingArguments):
             - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
             - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         use_weighting (`bool`, *optional*, defaults to `False`):
-            Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
+            Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
         label_pad_token_id (`int`, *optional*, defaults to `-100`):
             Label pad token id. This argument is required if you want to use the default data collator.
         padding_value (`Optional[int]`, *optional*, defaults to `None`):
@@ -141,54 +141,218 @@ class DPOConfig(TrainingArguments):
             τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
             the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
         use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
-            If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be useful
-            for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
-            when working with very long prompts where labels are
-ignored (-100).
+            If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be
+            useful for saving memory and speeding up training by not computing the logits for all tokens, especially in
+            scenarios when working with very long prompts where labels are ignored (-100).
             [Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
     """
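Context for reviewers (not part of the patch): the practical payoff of moving plain defaults into `field(default=..., metadata={"help": ...})` is that `transformers.HfArgumentParser` turns each `help` entry into command-line documentation. A minimal sketch, with a hypothetical `ToyConfig` standing in for `DPOConfig`:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyConfig:
    # Hypothetical stand-in for DPOConfig, just to show the mechanism.
    beta: float = field(
        default=0.1,
        metadata={"help": "Parameter controlling the deviation from the reference model."},
    )


parser = HfArgumentParser(ToyConfig)
(config,) = parser.parse_args_into_dataclasses()  # e.g. `python train.py --beta 0.05`
print(config.beta)
```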
 
-    learning_rate: float = 1e-6
-    beta: float = 0.1
-    label_smoothing: float = 0.0
-    loss_type: Literal[
-        "sigmoid",
-        "hinge",
-        "ipo",
-        "exo_pair",
-        "nca_pair",
-        "robust",
-        "bco_pair",
-        "sppo_hard",
-        "aot",
-        "aot_pair",
-        "discopop",
-        "apo_zero",
-        "apo_down",
-    ] = "sigmoid"
-    use_weighting: bool = False
-    label_pad_token_id: int = -100
-    padding_value: Optional[int] = None
-    truncation_mode: str = "keep_end"
-    max_length: Optional[int] = None
-    max_prompt_length: Optional[int] = None
-    max_completion_length: Optional[int] = None
-    is_encoder_decoder: Optional[bool] = None
-    disable_dropout: bool = True
-    generate_during_eval: bool = False
-    precompute_ref_log_probs: bool = False
-    precompute_ref_batch_size: Optional[int] = None
-    dataset_num_proc: Optional[int] = None
-    model_init_kwargs: Optional[dict[str, Any]] = None
-    ref_model_init_kwargs: Optional[dict[str, Any]] = None
-    model_adapter_name: Optional[str] = None
-    ref_adapter_name: Optional[str] = None
-    reference_free: bool = False
-    force_use_ref_model: bool = False
-    f_divergence_type: FDivergenceType = FDivergenceType.REVERSE_KL
-    f_alpha_divergence_coef: float = 1.0
-    sync_ref_model: bool = False
-    ref_model_mixup_alpha: float = 0.9
-    ref_model_sync_steps: int = 64
-    rpo_alpha: Optional[float] = None
-    discopop_tau: float = 0.05
-    use_num_logits_to_keep: bool = False
+    learning_rate: float = field(
+        default=1e-6,
+        metadata={
+            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
+            "`transformers.TrainingArguments`."
+        },
+    )
+    beta: float = field(
+        default=0.1,
+        metadata={
+            "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation "
+            "from the reference model."
+        },
+    )
+    label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "Label smoothing factor."},
+    )
+    loss_type: str = field(
+        default="sigmoid",
+        metadata={
+            "help": "Type of loss to use.",
+            "choices": [
+                "sigmoid",
+                "hinge",
+                "ipo",
+                "exo_pair",
+                "nca_pair",
+                "robust",
+                "bco_pair",
+                "sppo_hard",
+                "aot",
+                "aot_pair",
+                "discopop",
+                "apo_zero",
+                "apo_down",
+            ],
+        },
+    )
+    use_weighting: bool = field(
+        default=False,
+        metadata={"help": "Whether to weight the loss as done in the WPO paper."},
+    )
+    label_pad_token_id: int = field(
+        default=-100,
+        metadata={
+            "help": "Label pad token id. This argument is required if you want to use the default data collator."
+        },
+    )
+    padding_value: Optional[int] = field(
+        default=None,
+        metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
+    )
+    truncation_mode: str = field(
+        default="keep_end",
+        metadata={
+            "help": "Truncation mode to use when the prompt is too long. Possible values are `keep_end` or "
+            "`keep_start`. This argument is required if you want to use the default data collator.",
+            "choices": ["keep_end", "keep_start"],
+        },
+    )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
+    )
+    max_prompt_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
+            "collator and your model is an encoder-decoder."
+        },
+    )
+    max_completion_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Maximum length of the completion. This argument is required if you want to use the default data "
+            "collator and your model is an encoder-decoder."
+        },
+    )
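The fields above are set like any other `TrainingArguments`. A usage sketch with illustrative values (every keyword below is defined in the diff, except `output_dir`, which `DPOConfig` inherits from `transformers.TrainingArguments`):

```python
from trl import DPOConfig

config = DPOConfig(
    output_dir="dpo-model",   # inherited from transformers.TrainingArguments
    beta=0.1,                 # deviation from the reference model
    loss_type="sigmoid",      # standard DPO loss
    max_prompt_length=512,    # prompts beyond this are truncated per `truncation_mode`
    max_length=1024,          # total prompt + completion budget
)
```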
+    is_encoder_decoder: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
+            "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
+        },
+    )
+    disable_dropout: bool = field(
+        default=True,
+        metadata={"help": "Whether to disable dropout in the model and reference model."},
+    )
+    generate_during_eval: bool = field(
+        default=False,
+        metadata={
+            "help": "If `True`, generates and logs completions from both the model and the reference model to W&B "
+            "during evaluation."
+        },
+    )
+    precompute_ref_log_probs: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
+            "This is useful when training without the reference model to reduce the total GPU memory needed."
+        },
+    )
+    precompute_ref_batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Batch size to use when precomputing reference model log probabilities. This can be set higher "
+            "than the training batch size to speed up preprocessing. If `None`, defaults to "
+            "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation."
+        },
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of processes to use for processing the dataset."},
+    )
+    model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
+            "model from a string."
+        },
+    )
+    ref_model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
+            "reference model from a string."
+        },
+    )
+    model_adapter_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
+    )
+    ref_adapter_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
+    )
+    reference_free: bool = field(
+        default=False,
+        metadata={
+            "help": "If `True`, ignore the provided reference model and implicitly use a reference model that "
+            "assigns equal probability to all responses."
+        },
+    )
+    force_use_ref_model: bool = field(
+        default=False,
+        metadata={
+            "help": "If you pass a PEFT model as the active model and want to use a different model as the "
+            "reference model, set this flag to `True`."
+        },
+    )
+    f_divergence_type: FDivergenceType = field(
+        default=FDivergenceType.REVERSE_KL,
+        metadata={
+            "help": "Type of f-divergence regularization function to compute divergence between policy and reference "
+            "model."
+        },
+    )
+    f_alpha_divergence_coef: float = field(
+        default=1.0,
+        metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."},
+    )
+    sync_ref_model: bool = field(
+        default=False,
+        metadata={
+            "help": "When set to `True`, the reference model is synchronized with the active model every "
+            "`ref_model_sync_steps` steps, using the `ref_model_mixup_alpha` parameter."
+        },
+    )
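A sketch of the memory-saving path the precompute flags describe: run the reference model over the datasets once up front, so it does not need to stay in GPU memory during training (values are illustrative):

```python
from trl import DPOConfig

config = DPOConfig(
    output_dir="dpo-model",
    precompute_ref_log_probs=True,  # run the reference model once over the datasets
    precompute_ref_batch_size=64,   # larger than the train batch size to speed up the pass
)
```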
+ }, + ) + ref_model_mixup_alpha: float = field( + default=0.9, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`" + }, + ) + ref_model_sync_steps: int = field( + default=64, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy." + }, + ) + rpo_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. " + "If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends " + "`rpo_alpha=1.0`." + }, + ) + discopop_tau: float = field( + default=0.05, + metadata={ + "help": "τ/temperature parameter from the DiscoPOP paper, which controls the shape of log ratio modulated " + "loss. The paper recommends the default value `discopop_tau=0.05`." + }, + ) + use_num_logits_to_keep: bool = field( + default=False, + metadata={ + "help": "If `True`, only a specified number of logits are computed in the forward pass of CausalLM. " + "This can be useful for saving memory and speeding up training by not computing the logits for all " + "tokens, especially in scenarios when working with very long prompts where labels are ignored (-100)." + }, + )