generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: Add support for APO-zero in KTOTrainer (#1952)
* feat : add kto command
* feat : add support for apo loss in KTO Trainer
* feat : make kto script compatible with dpo-formatted datasets
* fix: lint data utils
* add loss_type in kto test
* fix: data utils docstrings
* fix: add dataset reformat test
* fix: lint tests
* fix: only reference kl_logps if needed

---------

Co-authored-by: Karel D'Oosterlinck <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
- Loading branch information
1 parent
6840380
commit 7acb9c2
Showing
8 changed files
with
340 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
|
||
from datasets import Dataset, DatasetDict | ||
|
||
from trl.data_utils import maybe_reformat_dpo_to_kto | ||
|
||
|
||
class MaybeReformatDPOToKTOTester(unittest.TestCase):
    """Checks for `maybe_reformat_dpo_to_kto` on DPO, KTO and unsupported datasets."""

    def setUp(self):
        # Shared building blocks: two prompts, each with a preferred and a dispreferred answer.
        prompts = ["What is AI?", "Define machine learning."]
        preferred = ["AI is artificial intelligence.", "Machine learning is a subset of AI."]
        dispreferred = ["AI is a computer.", "Machine learning is a program."]

        # DPO layout: one row per prompt, holding both answers side by side.
        self.dpo_data = {
            "prompt": list(prompts),
            "chosen": list(preferred),
            "rejected": list(dispreferred),
        }
        self.dpo_dataset = DatasetDict({"train": Dataset.from_dict(self.dpo_data)})

        # KTO layout: one row per (prompt, completion) pair, all chosen answers
        # first (label True) followed by all rejected answers (label False).
        self.kto_data = {
            "prompt": prompts + prompts,
            "completion": preferred + dispreferred,
            "label": [True, True, False, False],
        }
        self.kto_dataset = DatasetDict({"train": Dataset.from_dict(self.kto_data)})

    def test_dpo_to_kto_conversion(self):
        # A DPO-formatted dataset must come back in the expected KTO layout.
        converted = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        actual = converted["train"].to_dict()
        expected = self.kto_dataset["train"].to_dict()
        self.assertEqual(
            actual,
            expected,
            "The DPO-formatted dataset was not correctly reformatted to KTO format.",
        )

    def test_already_kto_format(self):
        # A dataset that is already KTO-formatted must pass through untouched.
        result = maybe_reformat_dpo_to_kto(self.kto_dataset)
        actual = result["train"].to_dict()
        expected = self.kto_dataset["train"].to_dict()
        self.assertEqual(
            actual,
            expected,
            "The KTO-formatted dataset should remain unchanged.",
        )

    def test_invalid_format(self):
        # Columns matching neither the KTO nor the DPO layout must raise a ValueError.
        unsupported_dataset = DatasetDict(
            {
                "train": Dataset.from_dict(
                    {
                        "input": ["What is AI?", "Define machine learning."],
                        "output": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
                    }
                )
            }
        )

        self.assertRaises(ValueError, maybe_reformat_dpo_to_kto, unsupported_dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2022 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from copy import deepcopy | ||
|
||
from datasets import DatasetDict | ||
|
||
|
||
def _reformat_row_dpo_to_kto(row: dict): | ||
"""Turn a DPO-formatted dataset row into two KTO-formatted rows.""" | ||
|
||
chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])} | ||
rejected_row = { | ||
"prompt": row["prompt"], | ||
"completion": row["rejected"], | ||
"label": [False] * len(row["chosen"]), | ||
} | ||
new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()} | ||
return new_rows | ||
|
||
|
||
def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
    """
    Convert a DPO-formatted dataset to the KTO format when it is not already KTO-formatted.

    The format is detected from the column names of the "train" split only:
        - "prompt", "completion" and "label" present: the dataset is treated as KTO and returned as is.
        - "prompt", "chosen" and "rejected" present: the dataset is treated as DPO. Every other column
          is dropped, then the chosen and rejected answers are turned into separate labelled rows.

    Args:
        dataset (DatasetDict): The dataset to potentially reformat. It must contain a "train" split.
        num_proc (int, optional): The number of processes forwarded to `map` for the transformation.
            Defaults to None.

    Returns:
        DatasetDict: The reformatted dataset, if conversion was needed; otherwise, the original dataset.

    Raises:
        ValueError: If the columns match neither the KTO nor the DPO layout.
    """
    kto_columns = ("prompt", "completion", "label")
    dpo_columns = ("prompt", "chosen", "rejected")
    columns = list(dataset["train"].features.keys())

    # Already KTO-formatted: nothing to do.
    if all(name in columns for name in kto_columns):
        return dataset

    # Neither KTO nor DPO: refuse to guess.
    if not all(name in columns for name in dpo_columns):
        raise ValueError("Dataset format not compatible with KTO.")

    # Drop everything that is not part of the DPO triple before converting.
    extra_columns = [name for name in columns if name not in dpo_columns]
    dpo_only = dataset.remove_columns(extra_columns)

    # Batched map: each batch of DPO rows yields twice as many KTO rows, and the
    # "chosen"/"rejected" columns are replaced by "completion"/"label".
    return dpo_only.map(
        _reformat_row_dpo_to_kto,
        num_proc=num_proc,
        batched=True,
        remove_columns=["chosen", "rejected"],
        desc="Reformatting Dataset from DPO format to KTO format.",
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.