Skip to content

Commit

Permalink
support parallel reward function
Browse files Browse the repository at this point in the history
  • Loading branch information
Jingru committed Oct 31, 2023
1 parent bcbcdac commit 43ea9f1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
2 changes: 2 additions & 0 deletions trlx/data/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class TrainConfig:

minibatch_size: Optional[int] = None

reward_only_in_main_process: bool = True

@classmethod
def from_dict(cls, config: Dict[str, Any]):
return cls(**config)
Expand Down
24 changes: 14 additions & 10 deletions trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,23 @@ def evaluate(self): # noqa: C901
if self.config.model.model_arch_type == "seq2seq":
samples = samples[:, 1:].contiguous()

prompt_sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(len(prompts.input_ids))
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes.to(samples.device)],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
prompt_sizes = torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(
len(prompts.input_ids)
)
if self.config.train.reward_only_in_main_process:
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
)
metadata = gather_dict(metadata, self.accelerator.gradient_state)
else:
prompts = prompts.input_ids
all_samples.extend(samples.tolist())
all_prompts.extend(prompts.tolist())
all_prompt_sizes.extend(prompt_sizes.tolist())

metadata = gather_dict(metadata, self.accelerator.gradient_state)
all_metadata.append(metadata)

desc = [
Expand All @@ -410,7 +414,7 @@ def evaluate(self): # noqa: C901

stats["time/generate"] = time() - generate_time

if self.accelerator.is_main_process:
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)

columns = ["prompt", "output"]
Expand Down
39 changes: 25 additions & 14 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})

if self.accelerator.is_main_process:
metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}
if self.config.train.reward_only_in_main_process:
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict(metadata)
else:
gathered_samples = padded_samples
gathered_prompts = padded_prompts
gathered_prompt_sizes = prompt_sizes

if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
all_str_samples, all_str_prompts, all_str_outputs = self.decode(
gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
)
Expand All @@ -316,9 +323,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
**metadata,
)
all_scores = [
torch.tensor(score, dtype=torch.float, device=device).view(
-1,
)
score.view(-1)
if isinstance(score, torch.Tensor)
else torch.tensor(score, dtype=torch.float, device=device).view(-1)
for score in all_scores
]
# Pad 0 reward on the ends
Expand All @@ -327,17 +334,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

stats["time/rollout_score"] = time() - rollout_score_time

all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
if self.config.train.reward_only_in_main_process:
all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
else:
all_scores = None
max_len = torch.tensor(0, dtype=torch.long, device=device)

if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)
if self.config.train.reward_only_in_main_process:
if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)
else:
scores = all_scores[0].clone().detach()
else:
scores = all_scores[0].clone().detach()
scores = all_scores
scores_mask = scores != -np.inf

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
Expand Down

0 comments on commit 43ea9f1

Please sign in to comment.