
Eval dataset issue in DPOTrainer when precompute_ref_log_probs=True and ref_model=None #1107

Closed
Sanster opened this issue Dec 19, 2023 · 3 comments · Fixed by #1125
Labels
🏋 DPO Related to DPO
Sanster commented Dec 19, 2023

When precompute_ref_log_probs=True, reference_chosen_logps and reference_rejected_logps are not saved to self.eval_dataset. When ref_model=None, subsequent evaluations therefore use self.model to recalculate them, so eval/acc is always zero (because the policy and the reference are the same model).
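To see why eval/acc collapses to exactly zero, recall that DPO's implicit reward is beta * (policy_logp - ref_logp). A minimal sketch with made-up log-prob numbers (not TRL's actual code) shows that when the policy is re-scored as its own reference, every reward is zero and "chosen reward > rejected reward" is never true:

```python
# Sketch of the DPO accuracy computation with illustrative numbers.
# Reward for a completion = beta * (policy_logp - reference_logp).
beta = 0.1

def dpo_accuracy(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
    chosen_rewards = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]
    rejected_rewards = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]
    wins = [c > r for c, r in zip(chosen_rewards, rejected_rewards)]
    return sum(wins) / len(wins)

policy_chosen = [-10.0, -12.0]
policy_rejected = [-15.0, -11.0]
ref_chosen = [-11.0, -13.0]
ref_rejected = [-14.0, -12.0]

# With a distinct reference model the accuracy is meaningful:
print(dpo_accuracy(policy_chosen, policy_rejected, ref_chosen, ref_rejected))   # 0.5

# The bug: precomputed logps are dropped and ref_model is None, so the
# policy scores itself as the reference. All rewards become exactly 0.0,
# strict "chosen > rejected" never holds, and eval/acc == 0:
print(dpo_accuracy(policy_chosen, policy_rejected, policy_chosen, policy_rejected))  # 0.0
```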

def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:

Perhaps it should be modified like this:

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_eval_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

            reference_chosen_logps = []
            reference_rejected_logps = []
            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                    (reference_chosen_logp, reference_rejected_logp)
                )
                reference_chosen_logps.append(reference_chosen_logp.cpu())
                reference_rejected_logps.append(reference_rejected_logp.cpu())

            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

            eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
            eval_dataset = eval_dataset.add_column(
                name="reference_rejected_logps", column=all_reference_rejected_logps
            )
            #### Save calculated reference_chosen_logps and reference_rejected_logps #####
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)
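The proposed fix boils down to a compute-once, write-back, flag-guard pattern. A standalone sketch of just that pattern (the class and the cheap stand-in computation are hypothetical; attribute names mirror the trainer's):

```python
# Hypothetical minimal analogue of the fix: precompute once, persist the
# augmented dataset back onto self.eval_dataset, and set a flag so later
# evaluations reuse the cached columns instead of recomputing.
class MiniTrainer:
    def __init__(self, eval_dataset):
        self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = False
        self.compute_calls = 0  # counts how often the expensive pass runs

    def _compute_ref_log_probs(self, dataset):
        # Stand-in for the real reference-model forward pass.
        self.compute_calls += 1
        return [{**row, "reference_chosen_logps": -1.0} for row in dataset]

    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if not self._precomputed_eval_ref_log_probs:
            eval_dataset = self._compute_ref_log_probs(eval_dataset)
            # The key line of the fix: save to self, not just a local variable.
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True
        return eval_dataset

trainer = MiniTrainer([{"prompt": "a"}, {"prompt": "b"}])
trainer.get_eval_dataloader()
trainer.get_eval_dataloader()  # second evaluation hits the cache
print(trainer.compute_calls)   # 1 -- without the write-back it would be 1 per call
```

Without the `self.eval_dataset = eval_dataset` write-back, each evaluation would recompute from the original dataset, which is exactly the behavior the issue reports.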
@lvwerra lvwerra added the 🏋 DPO Related to DPO label Dec 21, 2023
lvwerra commented Dec 21, 2023

tagging @kashif here :)

@kashif kashif self-assigned this Dec 21, 2023
kashif commented Dec 21, 2023

@Sanster so I had assumed super().get_eval_dataloader(eval_dataset=eval_dataset) would then set the dataset... so you are saying that is not the case?

Ah no, I see: it's because the base trainer uses self.eval_dataset, right?

kashif commented Dec 21, 2023

Great catch @Sanster, thanks!
