Feat: Add support for APO-zero in KTOTrainer (#1952)
* feat: add kto command

* feat: add support for APO loss in the KTO Trainer

* feat: make the kto script compatible with DPO-formatted datasets

* fix: lint data utils

* add loss_type in kto test

* fix: data utils docstrings

* fix: add dataset reformat test

* fix: lint tests

* fix: only reference kl_logps if needed

---------

Co-authored-by: Karel D'Oosterlinck <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
4 people authored Sep 4, 2024
1 parent 6840380 commit 7acb9c2
Showing 8 changed files with 340 additions and 135 deletions.
13 changes: 10 additions & 3 deletions examples/scripts/kto.py
@@ -59,7 +59,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

-from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_reformat_dpo_to_kto, setup_chat_format


 # Define and parse arguments.
@@ -97,10 +97,17 @@ class ScriptArguments:
 # Load the dataset
 dataset = load_dataset(script_args.dataset_name)

+# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
+dataset = maybe_reformat_dpo_to_kto(dataset, num_proc=kto_args.dataset_num_proc)
+
 # Apply chat template
 def format_dataset(example):
-    example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
-    example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+    if isinstance(example["completion"], str):
+        example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
+        example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+    else:
+        example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
+        example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
     return example

 # Compute that only on the main process for faster data processing.
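
The new else branch in format_dataset handles conversational records, where "completion" is a list of chat messages rather than a plain string: every message except the last is treated as the prompt, and only the final (assistant) message is kept as the completion. A minimal sketch of that splitting logic on a hypothetical record (the tokenizer.apply_chat_template(..., tokenize=False) rendering step is omitted):

# Hypothetical conversational record, shown only to illustrate the split.
example = {
    "completion": [
        {"role": "user", "content": "What is AI?"},
        {"role": "assistant", "content": "AI is artificial intelligence."},
    ],
    "label": True,
}

# Everything except the final message becomes the prompt ...
prompt_messages = example["completion"][:-1]        # [{"role": "user", ...}]
# ... and only the final message remains as the completion.
completion_messages = [example["completion"][-1]]   # [{"role": "assistant", ...}]
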
71 changes: 71 additions & 0 deletions tests/test_dataset_reformat.py
@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto


class MaybeReformatDPOToKTOTester(unittest.TestCase):
    def setUp(self):
        # Create a sample DPO-formatted dataset for testing
        self.dpo_data = {
            "prompt": ["What is AI?", "Define machine learning."],
            "chosen": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
            "rejected": ["AI is a computer.", "Machine learning is a program."],
        }
        self.dpo_dataset = DatasetDict({"train": Dataset.from_dict(self.dpo_data)})

        # Create a sample KTO-formatted dataset for testing
        self.kto_data = {
            "prompt": ["What is AI?", "Define machine learning.", "What is AI?", "Define machine learning."],
            "completion": [
                "AI is artificial intelligence.",
                "Machine learning is a subset of AI.",
                "AI is a computer.",
                "Machine learning is a program.",
            ],
            "label": [True, True, False, False],
        }
        self.kto_dataset = DatasetDict({"train": Dataset.from_dict(self.kto_data)})

    def test_dpo_to_kto_conversion(self):
        # Test that a DPO-formatted dataset is correctly reformatted to KTO format
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The DPO-formatted dataset was not correctly reformatted to KTO format.",
        )

    def test_already_kto_format(self):
        # Test that a KTO-formatted dataset remains unchanged
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.kto_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The KTO-formatted dataset should remain unchanged.",
        )

    def test_invalid_format(self):
        # Test that a dataset with an incompatible format raises a ValueError
        invalid_data = {
            "input": ["What is AI?", "Define machine learning."],
            "output": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
        }
        invalid_dataset = DatasetDict({"train": Dataset.from_dict(invalid_data)})

        with self.assertRaises(ValueError):
            maybe_reformat_dpo_to_kto(invalid_dataset)
17 changes: 10 additions & 7 deletions tests/test_kto_trainer.py
@@ -75,15 +75,17 @@ def _init_dummy_dataset(self):

     @parameterized.expand(
         [
-            ["gpt2", True, True],
-            ["gpt2", True, False],
-            # ["t5", True],
-            ["gpt2", False, True],
-            ["gpt2", False, False],
-            # ["t5", False],
+            ["gpt2", "kto", True, True],
+            ["gpt2", "kto", True, False],
+            ["gpt2", "kto", False, True],
+            ["gpt2", "kto", False, False],
+            ["gpt2", "apo_zero_unpaired", True, True],
+            ["gpt2", "apo_zero_unpaired", True, False],
+            ["gpt2", "apo_zero_unpaired", False, True],
+            ["gpt2", "apo_zero_unpaired", False, False],
         ]
     )
-    def test_kto_trainer(self, name, pre_compute, eval_dataset):
+    def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = KTOConfig(
                 output_dir=tmp_dir,
@@ -95,6 +97,7 @@ def test_kto_trainer(self, name, pre_compute, eval_dataset):
                 eval_strategy="steps",
                 beta=0.1,
                 precompute_ref_log_probs=pre_compute,
+                loss_type=loss_type,
                 report_to="none",
             )

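
Switching to the new loss in practice only requires setting loss_type on KTOConfig, exactly as the parameterized test does. A small end-to-end sketch along the same lines (the toy dataset and hyperparameters below are illustrative, not a recommended training setup):

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny KTO-formatted dataset: unpaired completions with desirable/undesirable labels.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is AI?", "What is AI?"],
        "completion": ["AI is artificial intelligence.", "AI is a sandwich."],
        "label": [True, False],
    }
)

training_args = KTOConfig(
    output_dir="kto-apo-zero",
    per_device_train_batch_size=2,
    max_length=128,
    max_prompt_length=64,
    beta=0.1,
    loss_type="apo_zero_unpaired",  # new option introduced by this commit; the default stays "kto"
    report_to="none",
)

trainer = KTOTrainer(model=model, args=training_args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()
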
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -81,6 +81,7 @@
         "MultitaskPromptTuningConfig",
         "MultitaskPromptTuningInit",
     ],
+    "data_utils": ["maybe_reformat_dpo_to_kto"],
 }

 try:
@@ -162,6 +163,7 @@
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
     from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
+    from .data_utils import maybe_reformat_dpo_to_kto

     try:
         if not is_diffusers_available():
2 changes: 1 addition & 1 deletion trl/commands/cli.py
@@ -21,7 +21,7 @@
 from rich.console import Console


-SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
+SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto"]


 def main():
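
Registering "kto" here makes the example script reachable through the TRL command-line interface. A hypothetical invocation (flag names come from the script's ScriptArguments, KTOConfig, and ModelConfig parsers; the dataset name is a placeholder):

trl kto \
    --model_name_or_path gpt2 \
    --dataset_name <your-dpo-or-kto-formatted-dataset> \
    --output_dir kto-output \
    --per_device_train_batch_size 4 \
    --learning_rate 5e-7
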
74 changes: 74 additions & 0 deletions trl/data_utils.py
@@ -0,0 +1,74 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from datasets import DatasetDict


def _reformat_row_dpo_to_kto(row: dict):
    """Turn a DPO-formatted dataset row into two KTO-formatted rows."""

    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
    rejected_row = {
        "prompt": row["prompt"],
        "completion": row["rejected"],
        "label": [False] * len(row["chosen"]),
    }
    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
    return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
    """
    Reformat a dataset from the DPO format to the KTO format if necessary.

    This function checks whether the input dataset is already in the KTO format (containing "prompt", "completion", and "label" fields).
    If the dataset is in DPO format (with "prompt", "chosen", and "rejected" fields), it converts it to KTO format by:
    - Removing any unnecessary columns.
    - Reformatting each row to create a unified format suitable for KTO training.

    Args:
        dataset (DatasetDict): The dataset to potentially reformat.
        num_proc (int, optional): The number of processes to use for multiprocessing during dataset transformation. Defaults to None.

    Returns:
        DatasetDict: The reformatted dataset, if conversion was needed; otherwise, the original dataset.

    Raises:
        ValueError: If the dataset format is not compatible with KTO or DPO.
    """
    keys = list(dataset["train"].features.keys())

    # check if the dataset is in the KTO format or needs to be reformatted
    if "prompt" in keys and "completion" in keys and "label" in keys:
        return dataset
    elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
        # remove unnecessary fields
        keys_to_remove = deepcopy(keys)
        keys_to_remove.remove("prompt")
        keys_to_remove.remove("chosen")
        keys_to_remove.remove("rejected")
        dataset = dataset.remove_columns(keys_to_remove)

        # turn each DPO-formatted row into two KTO-formatted rows.
        dataset = dataset.map(
            _reformat_row_dpo_to_kto,
            num_proc=num_proc,
            batched=True,
            remove_columns=["chosen", "rejected"],
            desc="Reformatting Dataset from DPO format to KTO format.",
        )
        return dataset
    else:
        raise ValueError("Dataset format not compatible with KTO.")
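
Because the map above runs with batched=True, _reformat_row_dpo_to_kto receives lists of values rather than single examples, and it relies on "chosen" and "rejected" having the same length (which any DPO-formatted dataset guarantees). A quick sketch of one toy batch before and after the conversion:

from trl.data_utils import _reformat_row_dpo_to_kto

batch = {
    "prompt": ["What is AI?"],
    "chosen": ["AI is artificial intelligence."],
    "rejected": ["AI is a computer."],
}

print(_reformat_row_dpo_to_kto(batch))
# {'prompt': ['What is AI?', 'What is AI?'],
#  'completion': ['AI is artificial intelligence.', 'AI is a computer.'],
#  'label': [True, False]}
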
11 changes: 10 additions & 1 deletion trl/trainer/kto_config.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional

 from transformers import TrainingArguments

@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
     command line.
     Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+            - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+            - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         max_length (`int`, *optional*, defaults to `None`):
             The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
         max_prompt_length (`int`, *optional*, defaults to `None`):
@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
             Number of processes to use for processing the datasets.
     """

+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
+    ] = "kto"
     max_length: Optional[int] = None
     """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
     max_prompt_length: Optional[int] = None
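
For intuition, here is a simplified per-example sketch of how the two loss_type options differ, following the KTO and APO papers; the actual KTOTrainer additionally handles the batch KL estimate, desirable/undesirable weighting, and reduction, so treat this as an illustration rather than the library implementation:

import torch


def unpaired_loss(policy_logps, ref_logps, labels, kl, beta=0.1, loss_type="kto"):
    # Log-ratio between the policy and the reference model for each completion.
    logratios = policy_logps - ref_logps
    if loss_type == "kto":
        # KTO: push desirable completions above the KL baseline and undesirable ones below it.
        desirable = 1 - torch.sigmoid(beta * (logratios - kl))
        undesirable = 1 - torch.sigmoid(beta * (kl - logratios))
    elif loss_type == "apo_zero_unpaired":
        # APO-zero (unpaired): increase the likelihood of desirable completions and
        # decrease the likelihood of undesirable ones, without the KL anchor.
        desirable = 1 - torch.sigmoid(beta * logratios)
        undesirable = torch.sigmoid(beta * logratios)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    # labels is True for desirable completions and False for undesirable ones.
    return torch.where(labels, desirable, undesirable)


# Example usage with made-up log-probabilities:
losses = unpaired_loss(
    policy_logps=torch.tensor([-10.0, -12.0]),
    ref_logps=torch.tensor([-11.0, -11.0]),
    labels=torch.tensor([True, False]),
    kl=torch.tensor(0.0),
    loss_type="apo_zero_unpaired",
)
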