👨‍🍳 Clarify DPO data preparation #2512

Merged · 1 commit · Dec 22, 2024 · Changes from all commits
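This PR consolidates the duplicated train/eval preparation logic in `DPOTrainer.__init__` into a single `_prepare_dataset` helper, adds support for `datasets.IterableDataset`, lets `eval_dataset` be a dict of named splits, and drops the raw text columns after tokenization. As context for the diff below, here is a minimal sketch of the preference-format data the trainer consumes; the `prompt`/`chosen`/`rejected` column names come straight from the `remove_columns` call in the new code, while the example values are made up:

```python
from datasets import Dataset

# One preference pair: a prompt plus a preferred ("chosen") and a
# dispreferred ("rejected") completion. These three columns are removed
# after tokenization by the new _prepare_dataset method.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["The sky is"],
        "chosen": [" blue."],
        "rejected": [" green."],
    }
)
```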
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 53 additions & 48 deletions trl/trainer/dpo_trainer.py
@@ -30,7 +30,7 @@
 import transformers
 from accelerate import PartialState
 from accelerate.utils import is_deepspeed_available, tqdm
-from datasets import Dataset
+from datasets import Dataset, IterableDataset
 from packaging import version
 from torch.utils.data import DataLoader
 from transformers import (
@@ -436,53 +436,16 @@ def make_inputs_require_grad(module, input, output):
         # that the warning has already been issued.
         model.warnings_issued["estimate_tokens"] = True
 
-        # Compute that only on the main process for faster data processing.
-        # see: https://github.com/huggingface/trl/pull/1255
-        with PartialState().local_main_process_first():
-            # Extract the prompt if needed, and apply the chat template if needed
-            train_dataset = train_dataset.map(
-                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
-            )
-            train_dataset = train_dataset.map(
-                maybe_apply_chat_template,
-                fn_kwargs={"tokenizer": processing_class},
-                num_proc=args.dataset_num_proc,
-                desc="Applying chat template to train dataset",
-            )
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(
-                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
-                )
-                eval_dataset = eval_dataset.map(
-                    maybe_apply_chat_template,
-                    fn_kwargs={"tokenizer": processing_class},
-                    num_proc=args.dataset_num_proc,
-                    desc="Applying chat template to eval dataset",
-                )
-
-            # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
-            fn_kwargs = {
-                "processing_class": processing_class,
-                "max_prompt_length": args.max_prompt_length,
-                "max_completion_length": args.max_completion_length,
-                # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
-                "add_special_tokens": self.is_encoder_decoder,
-            }
-            train_dataset = train_dataset.map(
-                self.tokenize_row if not self.is_vision_model else self.process_row,
-                fn_kwargs=fn_kwargs,
-                num_proc=self.dataset_num_proc,
-                writer_batch_size=10,
-                desc="Tokenizing train dataset",
-            )
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(
-                    self.tokenize_row if not self.is_vision_model else self.process_row,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=self.dataset_num_proc,
-                    writer_batch_size=10,
-                    desc="Tokenizing eval dataset",
-                )
+        # Dataset preparation
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
 
         super().__init__(
             model=model,
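The constructor change above also accepts a dict of eval datasets, preparing each entry separately under its own name. A usage sketch under assumptions — the model checkpoint and toy data are illustrative, not from this PR:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

def toy_split(prompt: str, good: str, bad: str) -> Dataset:
    # Build a one-row preference split with the expected columns.
    return Dataset.from_dict({"prompt": [prompt], "chosen": [good], "rejected": [bad]})

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="dpo-out"),
    train_dataset=toy_split("The sky is", " blue.", " green."),
    # Each key is passed to _prepare_dataset as dataset_name, so the
    # progress bars read e.g. "Tokenizing easy dataset".
    eval_dataset={
        "easy": toy_split("2 + 2 =", " 4", " 5"),
        "hard": toy_split("The capital of France is", " Paris.", " Rome."),
    },
    processing_class=tokenizer,
)
```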
@@ -540,6 +503,48 @@ def make_inputs_require_grad(module, input, output):
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)
 
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+        args: DPOConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Build the kwargs for the `map` function
+        map_kwargs = {"writer_batch_size": 10}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().local_main_process_first():
+            # Extract prompt if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+            # Apply the chat template if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)
+
+            # Tokenize the dataset
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+            dataset = dataset.map(
+                self.tokenize_row if not self.is_vision_model else self.process_row,
+                remove_columns=["prompt", "chosen", "rejected"],
+                fn_kwargs={
+                    "processing_class": processing_class,
+                    "max_prompt_length": args.max_prompt_length,
+                    "max_completion_length": args.max_completion_length,
+                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+                    "add_special_tokens": False,
+                },
+                **map_kwargs,
+            )
+
+        return dataset
+
     @staticmethod
     def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
         """