From 891edd2f25438445c12f10156accda162d22c13b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Fri, 13 Dec 2024 20:41:58 +0000
Subject: [PATCH] compute loss instead of training step

---
 trl/trainer/online_dpo_trainer.py | 62 +++++++++++--------------------
 1 file changed, 21 insertions(+), 41 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index c014ce1e13..04e9ce566f 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -39,12 +39,10 @@
     ProcessorMixin,
     Trainer,
     TrainerCallback,
-    is_apex_available,
     is_wandb_available,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
-from transformers.training_args import OptimizerNames
-from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
+from transformers.utils import is_peft_available, logging
 from transformers.utils.deprecation import deprecate_kwarg
 
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
@@ -56,7 +54,6 @@
     SIMPLE_CHAT_TEMPLATE,
     DPODataCollatorWithPadding,
     disable_dropout_in_model,
-    empty_cache,
     generate_model_card,
     get_reward,
     prepare_deepspeed,
@@ -67,17 +64,6 @@
 if is_peft_available():
     from peft import PeftModel, get_peft_model
 
-if is_apex_available():
-    from apex import amp
-
-
-if is_sagemaker_mp_enabled():
-    from smdistributed.modelparallel import __version__ as SMP_VERSION
-
-    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
-
-else:
-    IS_SAGEMAKER_MP_POST_1_10 = False
 
 if is_wandb_available():
     import wandb
@@ -391,11 +377,13 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
 
         return self.accelerator.prepare(eval_dataloader)
 
-    def training_step(
-        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[int] = None,
     ) -> torch.Tensor:
-        model.train()
-
         # Apply chat template and tokenize the input.
         # We do this on-the-fly to enable the use of reward models and policies with different tokenizers / chat templates.
         batch_size = len(next(iter(inputs.values())))
@@ -579,28 +567,7 @@ def training_step(
         self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
         self.stats["beta"].append(self.beta)
 
-        if (
-            self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            empty_cache()
-
-        kwargs = {}
-
-        # For LOMO optimizers you need to explicitly use the learnign rate
-        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-            kwargs["learning_rate"] = self._get_learning_rate()
-
-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-        if self.use_apex:
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            self.accelerator.backward(loss, **kwargs)
-
-        return loss.detach() / self.args.gradient_accumulation_steps
+        return (loss, None) if return_outputs else loss
 
     # Same as Trainer._maybe_log_save_evaluate but log our metrics
     # start_time defaults to None to allow compatibility with transformers<=4.46
@@ -645,6 +612,19 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             self._save_checkpoint(model, trial)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        with torch.no_grad():
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs)
+            loss = loss.mean().detach()
+        return (loss, None, None)
+
     # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
     # This can be removed once the minimum transformers version is updated to 4.47.
     # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
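
Reviewer note (illustrative sketch, not part of the patch): the deleted apex/LOMO/empty-cache
plumbing can go because the base transformers Trainer.training_step already wraps whatever
compute_loss returns with that same machinery. A simplified, abridged sketch of the base
method, assuming a transformers ~4.46-style layout (not the verbatim transformers source):

    # Inside transformers.Trainer; the subclass in this patch now only overrides compute_loss.
    def training_step(self, model, inputs, num_items_in_batch=None):
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            # Dispatches to OnlineDPOTrainer.compute_loss as defined in this patch.
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # average across data-parallel replicas
        # The real method also handles apex scaling, LOMO learning rates, and
        # periodic cache emptying around this backward call.
        self.accelerator.backward(loss)
        return loss.detach() / self.args.gradient_accumulation_steps

Overriding prediction_step to route through compute_loss likewise lets Trainer.evaluate()
report the same online-DPO loss without a separate evaluation code path.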