From 2cf8af54a4424f6494a9fe1f437d24364d6d697f Mon Sep 17 00:00:00 2001
From: Jingru
Date: Thu, 7 Dec 2023 07:54:07 +0000
Subject: [PATCH] support parallel reward function

---
 trlx/trainer/accelerate_base_trainer.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index bebd85d2..3a845642 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -421,9 +421,11 @@ def evaluate(self):  # noqa: C901
 
         if self.accelerator.is_main_process:
             columns = ["prompt", "output"]
-            columns_data = [str_prompts, str_outputs]
-            if not self.config.train.reward_only_in_main_process:
-                columns_data = self.accelerator.gather_for_metrics(columns_data)
+
+        # gather should be invoked in every process, not just the main process
+        columns_data = [str_prompts, str_outputs]
+        if not self.config.train.reward_only_in_main_process:
+            columns_data = self.accelerator.gather_for_metrics(columns_data)
 
         metadata, *xs = all_metadata
         for k in metadata:
@@ -447,9 +449,11 @@ def evaluate(self):  # noqa: C901
             else:
                 rewards = torch.tensor(rewards, dtype=float)
 
+        # gather should be invoked in every process, not just the main process
+        if not self.config.train.reward_only_in_main_process:
+            rewards = self.accelerator.gather(rewards)
+
         if self.accelerator.is_main_process:
-            if not self.config.train.reward_only_in_main_process:
-                rewards = self.accelerator.gather(rewards)
             mean_reward = rewards.mean().item()
 
             columns.append("reward")
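
Note on the fix: Accelerate's gather and gather_for_metrics are collective
operations, so every rank must reach the call. Guarding them with
is_main_process means the other ranks never enter the collective, and the
main process blocks forever waiting for them. Below is a minimal standalone
sketch of that pattern, not the trlx code itself; it assumes Hugging Face
Accelerate, and the script name demo.py and the variable fake_reward are
hypothetical.

    # demo.py -- run with 2+ processes, e.g.: accelerate launch demo.py
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Every rank computes rewards for its own shard of the eval set;
    # here each rank just fills a tensor with its process index.
    fake_reward = torch.full(
        (4,), float(accelerator.process_index), device=accelerator.device
    )

    # WRONG: gather() is a collective op. Guarding it with is_main_process
    # leaves the other ranks outside the call, so the main process hangs:
    #
    # if accelerator.is_main_process:
    #     rewards = accelerator.gather(fake_reward)

    # RIGHT (the pattern this patch applies): every process participates
    # in the gather; only the main process consumes the result.
    rewards = accelerator.gather(fake_reward)
    if accelerator.is_main_process:
        print("mean reward:", rewards.mean().item())

The patch applies the same reordering twice: once for the string
prompt/output columns (gather_for_metrics) and once for the reward tensor
(gather), moving each call out of the is_main_process block while keeping
the aggregation and logging on the main process only.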