
support parallel reward function
Jingru committed Dec 7, 2023
1 parent 6838580 · commit 2cf8af5
Showing 1 changed file with 9 additions and 5 deletions.
trlx/trainer/accelerate_base_trainer.py (14 changes: 9 additions & 5 deletions)
@@ -421,9 +421,11 @@ def evaluate(self): # noqa: C901
 
         if self.accelerator.is_main_process:
             columns = ["prompt", "output"]
-            columns_data = [str_prompts, str_outputs]
-            if not self.config.train.reward_only_in_main_process:
-                columns_data = self.accelerator.gather_for_metrics(columns_data)
 
+        # gather should be invoked in every process, not just the main process
+        columns_data = [str_prompts, str_outputs]
+        if not self.config.train.reward_only_in_main_process:
+            columns_data = self.accelerator.gather_for_metrics(columns_data)
+
         metadata, *xs = all_metadata
         for k in metadata:
@@ -447,9 +449,11 @@ def evaluate(self): # noqa: C901
             else:
                 rewards = torch.tensor(rewards, dtype=float)
 
+        # gather should be invoked in every process, not just the main process
+        if not self.config.train.reward_only_in_main_process:
+            rewards = self.accelerator.gather(rewards)
+
         if self.accelerator.is_main_process:
-            if not self.config.train.reward_only_in_main_process:
-                rewards = self.accelerator.gather(rewards)
             mean_reward = rewards.mean().item()
 
             columns.append("reward")
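The comment repeated in both hunks is the heart of the change: gather and gather_for_metrics are collective operations, so every process in the group must reach the call. If only the main process calls them while the other ranks skip ahead, the call blocks forever. The sketch below is not trlx code; it shows the same pattern in isolation, assuming a script launched with `accelerate launch` across several processes, with a random tensor standing in for the output of a parallel reward function.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Every rank scores its own shard of the eval set
# (a stand-in for a reward function running in parallel).
rewards = torch.rand(4, device=accelerator.device)

# Collective op: every rank must call gather(). Wrapping this call in
# `if accelerator.is_main_process:` would deadlock the other ranks,
# which is the placement bug this commit fixes.
rewards = accelerator.gather(rewards)

# Only the main process consumes the gathered result, e.g. for logging.
if accelerator.is_main_process:
    print("mean reward across", accelerator.num_processes, "ranks:", rewards.mean().item())

Note the diff's choice of primitives: plain gather for the reward tensor, but gather_for_metrics for the per-sample prompt/output columns, since the latter also drops the duplicate samples a distributed sampler pads onto the final batch, which matters when logging a per-sample table.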
