diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 4f6277c6a..5dfa5b808 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -289,22 +289,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
             device = samples.device

             prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
-            padded_samples = self.accelerator.pad_across_processes(
-                samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
-            )
-            padded_prompts = self.accelerator.pad_across_processes(
-                prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
-            )
-            metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}
+            if self.config.train.reward_only_in_main_process:
+                padded_samples = self.accelerator.pad_across_processes(
+                    samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
+                )
+                padded_prompts = self.accelerator.pad_across_processes(
+                    prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
+                )
                 gathered_samples = self.accelerator.gather(padded_samples)
                 gathered_prompts = self.accelerator.gather(padded_prompts)
                 gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
                 metadata = gather_dict(metadata)
             else:
-                gathered_samples = padded_samples
-                gathered_prompts = padded_prompts
+                gathered_samples = samples
+                gathered_prompts = prompt_tensors
                 gathered_prompt_sizes = prompt_sizes

             if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
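
For reference, a minimal sketch of how the new behaviour might be toggled from a training script. The `reward_only_in_main_process` field on `TrainConfig` is assumed to be introduced elsewhere in this change, and `default_ppo_config` is trlx's stock PPO config helper; this is an illustrative sketch, not part of the patch.

# Hedged sketch, not part of the diff above: assumes this change also adds a
# `reward_only_in_main_process` field to trlx's TrainConfig.
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()

# When True, rollouts are padded and gathered across ranks so that only the
# main process runs the reward function; when False, the gather is skipped and
# every rank scores its own local samples (the `else` branch in the hunk above).
config.train.reward_only_in_main_process = True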