diff --git a/trlx/models/modeling_nemo_ppo.py b/trlx/models/modeling_nemo_ppo.py
index 5ad044130..4291dcff5 100644
--- a/trlx/models/modeling_nemo_ppo.py
+++ b/trlx/models/modeling_nemo_ppo.py
@@ -52,7 +52,7 @@
 from trlx.data.ppo_types import PPORLBatch
 from trlx.models.modeling_ppo import PPOConfig
 from trlx.utils import to_device, tree_map
-from trlx.utils.modeling import logprobs_of_labels, whiten
+from trlx.utils.modeling import logprobs_of_next_labels, whiten
 
 # Track a per dp rank RNG to sample different rollouts
 # per dp rank
@@ -993,7 +993,7 @@ def loss_func(model_output):
             start = batch.query_tensors.shape[1]
             end = start + response_length
 
-            label_logprobs = logprobs_of_labels(logits[:, :-1, :], inputs[:, 1:])
+            label_logprobs = logprobs_of_next_labels(logits, inputs)
             label_logprobs = label_logprobs[:, start:end]
 
             advantages, returns = self.ppo_config.get_advantages_and_returns(
@@ -1079,11 +1079,11 @@ def ppo_postprocess(model_output):
 
             # to save memory
             if run_policy_model and compute_logprobs:
-                logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+                logprobs = logprobs_of_next_labels(logits, tokens)
                 return logprobs, dict(logprobs=logprobs, values=values)
 
             if run_reference_model and compute_logprobs:
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], tokens[:, 1:])
+                ref_logprobs = logprobs_of_next_labels(ref_logits, tokens)
                 return ref_logprobs, dict(ref_logprobs=ref_logprobs)
 
             return logits, {"logits": logits, "values": values, "ref_logits": ref_logits}
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 1a4801aaf..e4d09902e 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -27,7 +27,7 @@
 from trlx.trainer import register_trainer
 from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
 from trlx.utils import Clock, infinite_dataloader
-from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_labels
+from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_next_labels
 
 logger = logging.get_logger(__name__)
 
@@ -163,7 +163,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
 
             logits = outputs.logits
             values_pred = outputs.value
-            logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
+            logprobs = logprobs_of_next_labels(logits, decoder_input_ids)
             mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
             start = 0
             end = start + response_length
@@ -181,7 +181,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
             logits = outputs.logits
             values_pred = outputs.value
             values_pred = values_pred[:, :-1]
-            logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+            logprobs = logprobs_of_next_labels(logits, tokens)
 
             start = query_tensors.shape[1] - 1
             end = start + response_length
@@ -438,12 +438,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 ref_logits = ref_logits.to(device)
 
             if self.config.model.model_arch_type == "seq2seq":
-                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
+                logprobs = logprobs_of_next_labels(logits, sample_outputs)
+                ref_logprobs = logprobs_of_next_labels(ref_logits, sample_outputs)
             else:
                 # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
-                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
+                logprobs = logprobs_of_next_labels(logits, all_tokens)
+                ref_logprobs = logprobs_of_next_labels(ref_logits, all_tokens)
 
             n_samples: int = samples.shape[0]
 
diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index aa9bba525..dbeefbc33 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -219,6 +219,13 @@ def logprobs_of_labels(logits, labels):
     return logprobs_labels.squeeze(-1)
 
 
+def logprobs_of_next_labels(logits, labels):
+    """Log probabilities of the next labels, optimized for memory and speed"""
+    logprobs = F.log_softmax(logits, dim=-1)
+    logprobs_labels = torch.gather(logprobs, dim=-1, index=labels[..., 1:, None])
+    return logprobs_labels.squeeze(-1)
+
+
 def flatten_dict(
     d: Union[dict, MutableMapping],
     parent_key: str = "",
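
Quick equivalence check (not part of the patch): the new helper shifts the labels internally via the gather index, so logprobs_of_next_labels(logits, labels) should reproduce the old per-call-site pattern logprobs_of_labels(logits[:, :-1, :], labels[:, 1:]). A minimal sketch, assuming randomly generated tensors with hypothetical shapes:

    import torch

    from trlx.utils.modeling import logprobs_of_labels, logprobs_of_next_labels

    # Hypothetical shapes for illustration only:
    # logits [batch, seq_len, vocab], labels [batch, seq_len].
    logits = torch.randn(2, 8, 32)
    labels = torch.randint(0, 32, (2, 8))

    # Old call sites sliced logits and labels at every call site.
    old = logprobs_of_labels(logits[:, :-1, :], labels[:, 1:])
    # New helper: gather with index labels[..., 1:, None] pairs the logits at
    # position i with the label at position i + 1, dropping the last position.
    new = logprobs_of_next_labels(logits, labels)

    assert old.shape == new.shape == (2, 7)
    assert torch.allclose(old, new)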