clarify dpo data prep
qgallouedec committed Dec 21, 2024
1 parent b668048 commit 3fa096e
Showing 1 changed file with 53 additions and 48 deletions.
101 changes: 53 additions & 48 deletions trl/trainer/dpo_trainer.py
@@ -30,7 +30,7 @@
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import (
@@ -436,53 +436,16 @@ def make_inputs_require_grad(module, input, output):
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().local_main_process_first():
            # Extract the prompt if needed, and apply the chat template if needed
            train_dataset = train_dataset.map(
                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
            )
            train_dataset = train_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to train dataset",
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
                )
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                    desc="Applying chat template to eval dataset",
                )

            # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
            fn_kwargs = {
                "processing_class": processing_class,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
                # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                "add_special_tokens": self.is_encoder_decoder,
            }
            train_dataset = train_dataset.map(
                self.tokenize_row if not self.is_vision_model else self.process_row,
                fn_kwargs=fn_kwargs,
                num_proc=self.dataset_num_proc,
                writer_batch_size=10,
                desc="Tokenizing train dataset",
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    self.tokenize_row if not self.is_vision_model else self.process_row,
                    fn_kwargs=fn_kwargs,
                    num_proc=self.dataset_num_proc,
                    writer_batch_size=10,
                    desc="Tokenizing eval dataset",
                )
        # Dataset preparation
        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
        if eval_dataset is not None:
            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    key: self._prepare_dataset(dataset, processing_class, args, key)
                    for key, dataset in eval_dataset.items()
                }
            else:
                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

        super().__init__(
            model=model,
@@ -540,6 +503,48 @@ def make_inputs_require_grad(module, input, output):
        if self.loss_type == "bco_pair":
            self.running = RunningMoments(self.accelerator)

    def _prepare_dataset(
        self,
        dataset: Union[Dataset, IterableDataset],
        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
        args: DPOConfig,
        dataset_name: str,
    ) -> Union[Dataset, IterableDataset]:
        # Build the kwargs for the `map` function
        map_kwargs = {"writer_batch_size": 10}
        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
            map_kwargs["num_proc"] = args.dataset_num_proc

        with PartialState().local_main_process_first():
            # Extract prompt if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

            # Apply the chat template if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
            dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

            dataset = dataset.map(
                self.tokenize_row if not self.is_vision_model else self.process_row,
                remove_columns=["prompt", "chosen", "rejected"],
                fn_kwargs={
                    "processing_class": processing_class,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                    "add_special_tokens": False,
                },
                **map_kwargs,
            )

        return dataset

    @staticmethod
    def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
        """
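For context on what the refactored `_prepare_dataset` does to each example: it runs prompt extraction, chat templating, and tokenization as successive `map` passes. Below is a minimal sketch of the first two passes on a toy preference pair, using TRL's public `maybe_extract_prompt` and `maybe_apply_chat_template` helpers; the toy records and the tokenizer name are illustrative assumptions, not part of this commit.

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template, maybe_extract_prompt

# Toy preference pair in conversational format (illustrative only).
toy = Dataset.from_dict(
    {
        "chosen": [
            [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        ],
        "rejected": [
            [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ]
        ],
    }
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat model

# Pass 1: pull the shared user turn out of "chosen"/"rejected" into a "prompt" column.
toy = toy.map(maybe_extract_prompt)
# Pass 2: render prompt/chosen/rejected messages to strings with the chat template.
toy = toy.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

print(toy[0]["prompt"])    # templated prompt
print(toy[0]["chosen"])    # templated preferred completion
print(toy[0]["rejected"])  # templated rejected completion
```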

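After templating, `tokenize_row` (or `process_row` for vision models) turns each example into token ids, truncating to `max_prompt_length` and `max_completion_length`. The following is a rough, standalone sketch of that kind of per-row tokenization; the field names, EOS handling, and truncation policy are illustrative assumptions rather than the exact implementation behind this diff.

```python
from transformers import AutoTokenizer


def tokenize_preference_row(features, tokenizer, max_prompt_length, max_completion_length):
    """Illustrative per-row tokenization of a templated prompt/chosen/rejected example."""
    prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
    chosen_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
    rejected_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]

    # End each completion with EOS so the model learns where the answer stops
    # (assumes the tokenizer defines an EOS token).
    chosen_ids = chosen_ids + [tokenizer.eos_token_id]
    rejected_ids = rejected_ids + [tokenizer.eos_token_id]

    # Truncate: keep the end of the prompt and the start of each completion.
    if max_prompt_length is not None:
        prompt_ids = prompt_ids[-max_prompt_length:]
    if max_completion_length is not None:
        chosen_ids = chosen_ids[:max_completion_length]
        rejected_ids = rejected_ids[:max_completion_length]

    return {
        "prompt_input_ids": prompt_ids,
        "chosen_input_ids": chosen_ids,
        "rejected_input_ids": rejected_ids,
    }


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
row = {"prompt": "What color is the sky?", "chosen": " It is blue.", "rejected": " It is green."}
print(tokenize_preference_row(row, tokenizer, max_prompt_length=128, max_completion_length=128))
```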
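The new call site also accepts a dict-valued `eval_dataset` and prepares each split separately under its own name. A minimal usage sketch, borrowing the model and dataset names from the TRL docs (they are assumptions here, not part of the commit):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized")

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="dpo-out", eval_strategy="steps"),
    train_dataset=dataset["train"],
    # Each entry of the dict is passed through _prepare_dataset under its own key.
    eval_dataset={"ultrafeedback_test": dataset["test"]},
    processing_class=tokenizer,
)
trainer.train()
```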