Skip to content

Commit

Permalink
kto
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Dec 21, 2024
1 parent 64543ca commit 9b7764e
Showing 1 changed file with 115 additions and 20 deletions.
135 changes: 115 additions & 20 deletions trl/trainer/kto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments

Expand Down Expand Up @@ -80,21 +80,116 @@ class KTOConfig(TrainingArguments):
Whether to disable dropout in the model.
"""

learning_rate: float = 1e-6
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
beta: float = 0.1
loss_type: Literal["kto", "apo_zero_unpaired"] = "kto"
desirable_weight: float = 1.0
undesirable_weight: float = 1.0
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
disable_dropout: bool = True
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[dict[str, Any]] = None
ref_model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
# ----- KTOConfig field declarations -------------------------------------
# Each entry pairs a default value with a `metadata` mapping whose "help"
# string (and, where applicable, "choices" list) documents the option.

learning_rate: float = field(default=1e-6, metadata={"help": (
    "Initial learning rate for `AdamW` optimizer. "
    "The default value replaces that of `transformers.TrainingArguments`."
)})
max_length: Optional[int] = field(default=None, metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."})
max_prompt_length: Optional[int] = field(default=None, metadata={"help": (
    "Maximum length of the prompt. "
    "This argument is required if you want to use the default data collator and your model is an encoder-decoder."
)})
max_completion_length: Optional[int] = field(default=None, metadata={"help": (
    "Maximum length of the completion. "
    "This argument is required if you want to use the default data collator and your model is an encoder-decoder."
)})
beta: float = field(default=0.1, metadata={"help": (
    "Parameter controlling the deviation from the reference model. "
    "Higher β means less deviation from the reference model."
)})
loss_type: str = field(default="kto", metadata={"help": "Type of loss to use.", "choices": ["kto", "apo_zero_unpaired"]})
desirable_weight: float = field(default=1.0, metadata={"help": (
    "Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs."
)})
undesirable_weight: float = field(default=1.0, metadata={"help": (
    "Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs."
)})
label_pad_token_id: int = field(default=-100, metadata={"help": "Label pad token id. This argument is required if you want to use the default data collator."})
padding_value: Optional[int] = field(default=None, metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."})
truncation_mode: str = field(default="keep_end", metadata={"help": "Truncation mode to use when the prompt is too long.", "choices": ["keep_end", "keep_start"]})
generate_during_eval: bool = field(default=False, metadata={"help": (
    "If `True`, generates and logs completions from both the model and the reference model to W&B during evaluation."
)})
is_encoder_decoder: Optional[bool] = field(default=None, metadata={"help": (
    "When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, "
    "you need to specify if the model returned by the callable is an encoder-decoder model."
)})
disable_dropout: bool = field(default=True, metadata={"help": "Whether to disable dropout in the model."})
precompute_ref_log_probs: bool = field(default=False, metadata={"help": (
    "Whether to precompute reference model log probabilities for training and evaluation datasets. "
    "This is useful when training without the reference model to reduce the total GPU memory needed."
)})
model_init_kwargs: Optional[dict[str, Any]] = field(default=None, metadata={"help": (
    "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a string."
)})
ref_model_init_kwargs: Optional[dict[str, Any]] = field(default=None, metadata={"help": (
    "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model from a string."
)})
dataset_num_proc: Optional[int] = field(default=None, metadata={"help": "Number of processes to use for processing the dataset."})

0 comments on commit 9b7764e

Please sign in to comment.