
Commit

dpo
qgallouedec committed Dec 21, 2024
1 parent 835ca04 commit ed6954f
Showing 1 changed file with 215 additions and 51 deletions.
266 changes: 215 additions & 51 deletions trl/trainer/dpo_config.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Literal, Optional
from typing import Any, Optional

from transformers import TrainingArguments
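
This hunk swaps the bare `dataclass` import for `dataclass, field` and drops the now-unused `Literal`: every option below is redefined through `field(..., metadata=...)`. A minimal sketch of what that metadata buys, assuming the config is parsed with `transformers.HfArgumentParser` (which reads each field's `metadata["help"]` and `metadata["choices"]` when building CLI flags); `ToyConfig` is a hypothetical stand-in, not part of this commit:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyConfig:
    # `metadata["help"]` becomes the argparse help text; `metadata["choices"]`
    # restricts the values accepted on the command line.
    beta: float = field(
        default=0.1,
        metadata={"help": "Parameter controlling the deviation from the reference model."},
    )
    loss_type: str = field(
        default="sigmoid",
        metadata={"help": "Type of loss to use.", "choices": ["sigmoid", "hinge", "ipo"]},
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ToyConfig)
    (config,) = parser.parse_args_into_dataclasses(["--beta", "0.05", "--loss_type", "hinge"])
    print(config)  # ToyConfig(beta=0.05, loss_type='hinge')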

@@ -67,7 +67,7 @@ class DPOConfig(TrainingArguments):
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
use_weighting (`bool`, *optional*, defaults to `False`):
Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. This argument is required if you want to use the default data collator.
padding_value (`Optional[int]`, *optional*, defaults to `None`):
@@ -141,54 +141,218 @@ class DPOConfig(TrainingArguments):
τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be useful
for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
when working with very long prompts where labels are -ignored (-100).
If `True`, only a specified number of logits are computed in the forward pass of CausalLM. This can be
useful for saving memory and speeding up training by not computing the logits for all tokens, especially in
scenarios when working with very long prompts where labels are ignored (-100).
[Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
"""

learning_rate: float = 1e-6
beta: float = 0.1
label_smoothing: float = 0.0
loss_type: Literal[
"sigmoid",
"hinge",
"ipo",
"exo_pair",
"nca_pair",
"robust",
"bco_pair",
"sppo_hard",
"aot",
"aot_pair",
"discopop",
"apo_zero",
"apo_down",
] = "sigmoid"
use_weighting: bool = False
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
generate_during_eval: bool = False
precompute_ref_log_probs: bool = False
precompute_ref_batch_size: Optional[int] = None
dataset_num_proc: Optional[int] = None
model_init_kwargs: Optional[dict[str, Any]] = None
ref_model_init_kwargs: Optional[dict[str, Any]] = None
model_adapter_name: Optional[str] = None
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
f_divergence_type: FDivergenceType = FDivergenceType.REVERSE_KL
f_alpha_divergence_coef: float = 1.0
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None
discopop_tau: float = 0.05
use_num_logits_to_keep: bool = False
learning_rate: float = field(
default=1e-6,
metadata={
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
"`transformers.TrainingArguments`."
},
)
beta: float = field(
default=0.1,
metadata={
"help": "Parameter controlling the deviation from the reference model. "
"Higher β means less deviation from the reference model."
},
)
label_smoothing: float = field(
default=0.0,
metadata={"help": "Label smoothing factor."},
)
loss_type: str = field(
default="sigmoid",
metadata={
"help": "Type of loss to use.",
"choices": [
"sigmoid",
"hinge",
"ipo",
"exo_pair",
"nca_pair",
"robust",
"bco_pair",
"sppo_hard",
"aot",
"aot_pair",
"discopop",
"apo_zero",
"apo_down",
],
},
)
use_weighting: bool = field(
default=False,
metadata={"help": "Whether to weight the loss as done in the WPO paper."},
)
label_pad_token_id: int = field(
default=-100,
metadata={
"help": "Label pad token id. This argument is required if you want to use the default data collator."
},
)
padding_value: Optional[int] = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
truncation_mode: str = field(
default="keep_end",
metadata={
"help": "Truncation mode to use when the prompt is too long. Possible values are "
"`keep_end` or `keep_start`. This argument is required if you want to use the "
"default data collator.",
"choices": ["keep_end", "keep_start"],
},
)
max_length: Optional[int] = field(
default=None,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
default=None,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
max_completion_length: Optional[int] = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
is_encoder_decoder: Optional[bool] = field(
default=None,
metadata={
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
"`model` argument, you need to specify if the model returned by the callable is an encoder-decoder model."
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
)
generate_during_eval: bool = field(
default=False,
metadata={
"help": "If `True`, generates and logs completions from both the model and the reference model "
"to W&B during evaluation."
},
)
precompute_ref_log_probs: bool = field(
default=False,
metadata={
"help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
"This is useful when training without the reference model to reduce the total GPU memory needed."
},
)
precompute_ref_batch_size: Optional[int] = field(
default=None,
metadata={
"help": "Batch size to use when precomputing reference model log probabilities. This can be set higher "
"than the training batch size to speed up preprocessing. If `None`, defaults to "
"`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation."
},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"model from a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"reference model from a string."
},
)
model_adapter_name: Optional[str] = field(
default=None,
metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
)
ref_adapter_name: Optional[str] = field(
default=None,
metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
)
reference_free: bool = field(
default=False,
metadata={
"help": "If `True`, we ignore the _provided_ reference model and implicitly use a reference model that "
"assigns equal probability to all responses."
},
)
force_use_ref_model: bool = field(
default=False,
metadata={
"help": "In case one passes a PEFT model for the active model and you want to use a different model for "
"the ref_model, set this flag to `True`."
},
)
f_divergence_type: FDivergenceType = field(
default=FDivergenceType.REVERSE_KL,
metadata={
"help": "Type of f-divergence regularization function to compute divergence between policy and reference "
"model."
},
)
f_alpha_divergence_coef: float = field(
default=1.0,
metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."},
)
sync_ref_model: bool = field(
default=False,
metadata={
"help": "When set to `True`, the reference model is synchronized with the active model every "
"`ref_model_sync_steps` steps, using the `ref_model_mixup_alpha` parameter."
},
)
ref_model_mixup_alpha: float = field(
default=0.9,
metadata={
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
"previous reference policy during updates. The reference policy is updated according to the equation: "
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`"
},
)
ref_model_sync_steps: int = field(
default=64,
metadata={
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
"synchronized with the reference policy."
},
)
rpo_alpha: Optional[float] = field(
default=None,
metadata={
"help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. "
"If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends "
"`rpo_alpha=1.0`."
},
)
discopop_tau: float = field(
default=0.05,
metadata={
"help": "τ/temperature parameter from the DiscoPOP paper, which controls the shape of log ratio modulated "
"loss. The paper recommends the default value `discopop_tau=0.05`."
},
)
use_num_logits_to_keep: bool = field(
default=False,
metadata={
"help": "If `True`, only a specified number of logits are computed in the forward pass of CausalLM. "
"This can be useful for saving memory and speeding up training by not computing the logits for all "
"tokens, especially in scenarios when working with very long prompts where labels are ignored (-100)."
},
)
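
To see where `beta` and `label_smoothing` enter the default "sigmoid" loss, here is a simplified sketch of the DPO objective. It is illustrative only: the actual computation lives in `DPOTrainer`, not in this config file, and the helper name `dpo_sigmoid_loss` is made up for the example.

import torch
import torch.nn.functional as F


def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps,
                     beta=0.1, label_smoothing=0.0):
    # Log-ratios of policy vs. reference model for the chosen and rejected completions.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    # `beta` scales the margin; `label_smoothing` gives the conservative-DPO variant
    # that assumes preference labels are flipped with that probability.
    return (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )


if __name__ == "__main__":
    loss = dpo_sigmoid_loss(
        policy_chosen_logps=torch.tensor([-12.0]),
        policy_rejected_logps=torch.tensor([-15.0]),
        ref_chosen_logps=torch.tensor([-13.0]),
        ref_rejected_logps=torch.tensor([-14.0]),
    )
    print(loss)  # tensor([0.5981])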

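The `sync_ref_model`, `ref_model_mixup_alpha`, and `ref_model_sync_steps` options describe the TR-DPO update `π_ref = α * π_θ + (1 - α) * π_ref_prev`, applied every `ref_model_sync_steps` steps. A rough parameter-wise sketch of that update, assuming plain PyTorch modules (TRL triggers it from a trainer callback; the function name here is illustrative):

import torch


@torch.no_grad()
def sync_reference_model(policy: torch.nn.Module, ref_model: torch.nn.Module, alpha: float = 0.9) -> None:
    # pi_ref <- alpha * pi_theta + (1 - alpha) * pi_ref_prev, parameter by parameter.
    for ref_param, policy_param in zip(ref_model.parameters(), policy.parameters()):
        ref_param.mul_(1.0 - alpha).add_(policy_param, alpha=alpha)

With the defaults above, enabling `sync_ref_model=True` would apply this mix with `alpha=0.9` every 64 steps.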