diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 1e6e8e67ad..ea9c916d76 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -350,6 +350,40 @@ def test_dpo_trainer_with_ref_model_is_model(self): train_dataset=dummy_dataset["train"], ) + def test_precompute_ref_batch_size(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + precompute_ref_log_probs=True, + precompute_ref_batch_size=4, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # verify that training actually updated the model weights + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # skip tensors that sum to zero (e.g. zero-initialised biases) + if param.sum() != 0: + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + @require_peft def test_dpo_trainer_without_providing_ref_model_with_lora(self): from peft import LoraConfig diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index ea4a176aa1..8a6e507dc1 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -94,6 +94,10 @@ class DPOConfig(TrainingArguments): precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): Whether to precompute reference model log probabilities for training and evaluation datasets. This is useful when training without the reference model to reduce the total GPU memory needed. 
+ precompute_ref_batch_size (`Optional[int]`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. model_init_kwargs (`Optional[dict[str, Any]]`, *optional*, defaults to `None`): @@ -173,6 +177,7 @@ class DPOConfig(TrainingArguments): disable_dropout: bool = True generate_during_eval: bool = False precompute_ref_log_probs: bool = False + precompute_ref_batch_size: Optional[int] = None dataset_num_proc: Optional[int] = None model_init_kwargs: Optional[dict[str, Any]] = None ref_model_init_kwargs: Optional[dict[str, Any]] = None diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index c1f2776511..ea3f24a39d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -684,8 +684,9 @@ def get_train_dataloader(self) -> DataLoader: """ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, + "batch_size": batch_size, "collate_fn": self.data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, @@ -737,8 +738,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size dataloader_params = { - "batch_size": self.args.per_device_eval_batch_size, + "batch_size": 
batch_size, "collate_fn": self.data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory,