Skip to content

Commit

Permalink
support parallel reward function
Browse files Browse the repository at this point in the history
  • Loading branch information
Jingru committed Nov 7, 2023
1 parent 43ea9f1 commit 4c38538
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,22 +289,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
device = samples.device

prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)

metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}

if self.config.train.reward_only_in_main_process:
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict(metadata)
else:
gathered_samples = padded_samples
gathered_prompts = padded_prompts
gathered_samples = samples
gathered_prompts = prompt_tensors
gathered_prompt_sizes = prompt_sizes

if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
Expand Down

0 comments on commit 4c38538

Please sign in to comment.