From 9b7764e1a8a5f82e437d1281e9dbc71cc47c2815 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Sat, 21 Dec 2024 13:22:00 +0000
Subject: [PATCH] kto

---
 trl/trainer/kto_config.py | 135 ++++++++++++++++++++++++++++++++------
 1 file changed, 115 insertions(+), 20 deletions(-)

diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py
index 7fe12a317a..5bf02f143b 100644
--- a/trl/trainer/kto_config.py
+++ b/trl/trainer/kto_config.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
-from typing import Any, Literal, Optional
+from dataclasses import dataclass, field
+from typing import Any, Optional
 
 from transformers import TrainingArguments
 
@@ -80,21 +80,116 @@ class KTOConfig(TrainingArguments):
             Whether to disable dropout in the model.
     """
 
-    learning_rate: float = 1e-6
-    max_length: Optional[int] = None
-    max_prompt_length: Optional[int] = None
-    max_completion_length: Optional[int] = None
-    beta: float = 0.1
-    loss_type: Literal["kto", "apo_zero_unpaired"] = "kto"
-    desirable_weight: float = 1.0
-    undesirable_weight: float = 1.0
-    label_pad_token_id: int = -100
-    padding_value: Optional[int] = None
-    truncation_mode: str = "keep_end"
-    generate_during_eval: bool = False
-    is_encoder_decoder: Optional[bool] = None
-    disable_dropout: bool = True
-    precompute_ref_log_probs: bool = False
-    model_init_kwargs: Optional[dict[str, Any]] = None
-    ref_model_init_kwargs: Optional[dict[str, Any]] = None
-    dataset_num_proc: Optional[int] = None
+    learning_rate: float = field(
+        default=1e-6,
+        metadata={
+            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
+            "`transformers.TrainingArguments`."
+        },
+    )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
+    )
+    max_prompt_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
+            "collator and your model is an encoder-decoder."
+        },
+    )
+    max_completion_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Maximum length of the completion. This argument is required if you want to use the default data "
+            "collator and your model is an encoder-decoder."
+        },
+    )
+    beta: float = field(
+        default=0.1,
+        metadata={
+            "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
+            "the reference model."
+        },
+    )
+    loss_type: str = field(
+        default="kto",
+        metadata={
+            "help": "Type of loss to use.",
+            "choices": ["kto", "apo_zero_unpaired"],
+        },
+    )
+    desirable_weight: float = field(
+        default=1.0,
+        metadata={
+            "help": "Desirable losses are weighed by this factor to counter unequal number of desirable and "
+            "undesirable pairs.",
+        },
+    )
+    undesirable_weight: float = field(
+        default=1.0,
+        metadata={
+            "help": "Undesirable losses are weighed by this factor to counter unequal number of desirable and "
+            "undesirable pairs.",
+        },
+    )
+    label_pad_token_id: int = field(
+        default=-100,
+        metadata={
+            "help": "Label pad token id. This argument is required if you want to use the default data collator."
+        },
+    )
+    padding_value: Optional[int] = field(
+        default=None,
+        metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
+    )
+    truncation_mode: str = field(
+        default="keep_end",
+        metadata={
+            "help": "Truncation mode to use when the prompt is too long.",
+            "choices": ["keep_end", "keep_start"],
+        },
+    )
+    generate_during_eval: bool = field(
+        default=False,
+        metadata={
+            "help": "If `True`, generates and logs completions from both the model and the reference model to W&B "
+            "during evaluation."
+        },
+    )
+    is_encoder_decoder: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
+            "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
+        },
+    )
+    disable_dropout: bool = field(
+        default=True,
+        metadata={"help": "Whether to disable dropout in the model."},
+    )
+    precompute_ref_log_probs: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
+            "This is useful when training without the reference model to reduce the total GPU memory needed."
+        },
+    )
+    model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
+            "from a string."
+        },
+    )
+    ref_model_init_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
+            "reference model from a string."
+        },
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of processes to use for processing the dataset."},
+    )
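Note (not part of the patch): a minimal sketch of how the new `field(..., metadata=...)` entries could surface on the command line, assuming the class in this diff is exported as `trl.KTOConfig` and is parsed with `transformers.HfArgumentParser`; the parser usage below is an illustration, not something this diff adds.

# Illustration only: HfArgumentParser copies each field's metadata into the
# add_argument() kwargs, so "help" becomes the argparse help text and
# "choices" restricts the accepted values (e.g. for --loss_type).
from transformers import HfArgumentParser

from trl import KTOConfig  # assumed export path for the class edited in this patch

parser = HfArgumentParser(KTOConfig)
(config,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "kto-model", "--beta", "0.2", "--loss_type", "apo_zero_unpaired"]
)
print(config.beta, config.loss_type)  # 0.2 apo_zero_unpaired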