diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 3c4c7771b2..04b0583237 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -30,7 +30,7 @@
 import transformers
 from accelerate import PartialState
 from accelerate.utils import is_deepspeed_available, tqdm
-from datasets import Dataset
+from datasets import Dataset, IterableDataset
 from packaging import version
 from torch.utils.data import DataLoader
 from transformers import (
@@ -436,53 +436,16 @@ def make_inputs_require_grad(module, input, output):
             # that the warning has already been issued.
             model.warnings_issued["estimate_tokens"] = True
 
-        # Compute that only on the main process for faster data processing.
-        # see: https://github.com/huggingface/trl/pull/1255
-        with PartialState().local_main_process_first():
-            # Extract the prompt if needed, and apply the chat template if needed
-            train_dataset = train_dataset.map(
-                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
-            )
-            train_dataset = train_dataset.map(
-                maybe_apply_chat_template,
-                fn_kwargs={"tokenizer": processing_class},
-                num_proc=args.dataset_num_proc,
-                desc="Applying chat template to train dataset",
-            )
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(
-                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
-                )
-                eval_dataset = eval_dataset.map(
-                    maybe_apply_chat_template,
-                    fn_kwargs={"tokenizer": processing_class},
-                    num_proc=args.dataset_num_proc,
-                    desc="Applying chat template to eval dataset",
-                )
-
-            # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
-            fn_kwargs = {
-                "processing_class": processing_class,
-                "max_prompt_length": args.max_prompt_length,
-                "max_completion_length": args.max_completion_length,
-                # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
-                "add_special_tokens": self.is_encoder_decoder,
-            }
-            train_dataset = train_dataset.map(
-                self.tokenize_row if not self.is_vision_model else self.process_row,
-                fn_kwargs=fn_kwargs,
-                num_proc=self.dataset_num_proc,
-                writer_batch_size=10,
-                desc="Tokenizing train dataset",
-            )
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(
-                    self.tokenize_row if not self.is_vision_model else self.process_row,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=self.dataset_num_proc,
-                    writer_batch_size=10,
-                    desc="Tokenizing eval dataset",
-                )
+        # Dataset preparation
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
 
         super().__init__(
             model=model,
@@ -540,6 +503,48 @@ def make_inputs_require_grad(module, input, output):
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)
 
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+        args: DPOConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Build the kwargs for the `map` function
+        map_kwargs = {"writer_batch_size": 10}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().local_main_process_first():
+            # Extract prompt if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+            # Apply the chat template if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)
+
+            # Tokenize the dataset
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+            dataset = dataset.map(
+                self.tokenize_row if not self.is_vision_model else self.process_row,
+                remove_columns=["prompt", "chosen", "rejected"],
+                fn_kwargs={
+                    "processing_class": processing_class,
+                    "max_prompt_length": args.max_prompt_length,
+                    "max_completion_length": args.max_completion_length,
+                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+                    "add_special_tokens": False,
+                },
+                **map_kwargs,
+            )
+
+        return dataset
+
     @staticmethod
     def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
         """
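
What the refactor buys in practice: `_prepare_dataset` only passes `num_proc` and `desc` to `map` when the dataset is a regular `Dataset`, so a streaming `IterableDataset` can now flow through prompt extraction, chat templating, and tokenization without hitting arguments that `IterableDataset.map` does not accept. A minimal usage sketch under that assumption; the checkpoint and dataset names are illustrative, not part of this diff:

```python
# Sketch: streaming preference data through DPOTrainer after this change.
# Model and dataset names below are assumptions for illustration.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# streaming=True returns an IterableDataset; before this change, the trainer's
# eager `map(..., num_proc=..., desc=...)` calls would fail on such a dataset.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train", streaming=True)

# max_steps is set explicitly: an IterableDataset has no length, so the
# trainer cannot derive the number of steps from num_train_epochs.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", max_steps=100)

trainer = DPOTrainer(model=model, args=training_args, train_dataset=train_dataset, processing_class=tokenizer)
trainer.train()
```

Note also the design choice visible in the diff: instead of duplicating the extract/template/tokenize pipeline for train and eval splits, the single `_prepare_dataset` helper is applied per split (and per entry when `eval_dataset` is a dict), with `dataset_name` used only to label the progress bars.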