Commit

Correct shape (#1170)
nikihowe authored Jan 4, 2024
1 parent 20428c4 commit dc53b8c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions trl/trainer/ppo_trainer.py
@@ -1045,15 +1045,15 @@ def train_minibatch(
 Args:
 logprobs (`torch.FloatTensor`):
-Log probabilities of the model, shape [batch_size, response_length]
+Log probabilities of the model, shape [mini_batch_size, response_length]
 values (`torch.FloatTensor`):
-Values of the value head, shape [batch_size, response_length]
+Values of the value head, shape [mini_batch_size, response_length]
 query (`torch.LongTensor`):
-Encoded queries, shape [batch_size, query_length]
+Encoded queries, shape [mini_batch_size, query_length]
 response (`torch.LongTensor`):
-Encoded responses, shape [batch_size, response_length]
+Encoded responses, shape [mini_batch_size, response_length]
 model_input (`torch.LongTensor`):
-Concatenated queries and responses, shape [batch_size, query_length+response_length]
+Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]
 Returns:
 train_stats (dict[str, `torch.Tensor`]):
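
The docstring change reflects that `train_minibatch` is handed one slice of the batch at a time, so the leading dimension of each argument is `mini_batch_size` rather than `batch_size`. The following is a minimal sketch of that shape relationship only, not TRL's actual implementation: the helpers `shape` and `iter_minibatches` and the concrete sizes are hypothetical, and plain nested lists stand in for tensors.

```python
# Illustrative only: nested lists stand in for torch tensors, and the
# helper names and sizes below are assumptions, not part of TRL.

def shape(rows):
    """Return (n_rows, n_cols) for a rectangular list of lists."""
    return (len(rows), len(rows[0]) if rows else 0)


def iter_minibatches(rows, mini_batch_size):
    """Yield consecutive slices of at most `mini_batch_size` rows."""
    for start in range(0, len(rows), mini_batch_size):
        yield rows[start:start + mini_batch_size]


batch_size, mini_batch_size = 8, 2
query_length, response_length = 3, 5

# Full-batch stand-ins: [batch_size, query_length] and [batch_size, response_length].
queries = [[0] * query_length for _ in range(batch_size)]
responses = [[0] * response_length for _ in range(batch_size)]

minibatch_shapes = []
for q_mb, r_mb in zip(
    iter_minibatches(queries, mini_batch_size),
    iter_minibatches(responses, mini_batch_size),
):
    # Concatenate each query with its response along the sequence dimension.
    model_input = [q + r for q, r in zip(q_mb, r_mb)]
    minibatch_shapes.append((shape(q_mb), shape(r_mb), shape(model_input)))

print(shape(queries), shape(responses))
print(minibatch_shapes[0])
```

Every entry in `minibatch_shapes` has `mini_batch_size` as its first dimension, which is the shape the corrected docstring describes for what a per-minibatch function receives.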