Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Nov 9, 2023
1 parent 411cf79 commit 5ed4f10
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def make_inputs_require_grad(module, input, output):
self.max_target_length = max_target_length
self.tokenizer = tokenizer
self.precompute_ref_log_probs = precompute_ref_log_probs

# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
# keep track of first called to avoid computation of future calls
self._precomputed_train_ref_log_probs = False
Expand Down Expand Up @@ -390,7 +390,7 @@ def get_train_dataloader(self) -> DataLoader:

# tokenize the dataset and compute reference logps for training datasets
self.train_dataset = self.train_dataset.map(self.tokenize_batch_element)

if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
dataloader_params = {
"batch_size": self.args.per_device_train_batch_size,
Expand Down

0 comments on commit 5ed4f10

Please sign in to comment.