Adjust padding in batch generation (#2251)
* pad batch generation

* Use pad utility

Co-authored-by: Kashif Rasul <[email protected]>

* Update trl/trainer/utils.py

* reshaping

* fix test_utils.py

---------

Co-authored-by: Kashif Rasul <[email protected]>
gaetanlop and kashif authored Oct 22, 2024
1 parent d843b3d commit f2349d2
Showing 2 changed files with 73 additions and 3 deletions.
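In short, `batch_generation` runs generation in mini-batches and previously merged the per-mini-batch outputs directly with `torch.cat`, which breaks when the mini-batches stop generating at different lengths. The snippet below is an illustrative sketch of that failure mode and of the pad-then-concatenate idea this commit adopts; the toy tensors and the placeholder pad id are made up for the example and are not taken from the repository.

import torch

# Two mini-batches of generated token ids whose sequence lengths differ,
# e.g. because the second mini-batch hit EOS earlier.
chunk_a = torch.randint(0, 100, (3, 12))  # (mini_batch_size, seq_len=12)
chunk_b = torch.randint(0, 100, (3, 9))   # (mini_batch_size, seq_len=9)

try:
    torch.cat([chunk_a, chunk_b], dim=0)  # old behavior: fails, dim 1 sizes differ
except RuntimeError as err:
    print(f"direct concatenation fails: {err}")

# New behavior, conceptually: right-pad every chunk to the longest length first.
pad_token_id = 0  # placeholder id, only for this illustration
max_len = max(chunk.shape[1] for chunk in (chunk_a, chunk_b))
padded = [
    torch.nn.functional.pad(chunk, (0, max_len - chunk.shape[1]), value=pad_token_id)
    for chunk in (chunk_a, chunk_b)
]
merged = torch.cat(padded, dim=0)
print(merged.shape)  # torch.Size([6, 12])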
tests/test_utils.py (61 additions, 1 deletion)
@@ -16,13 +16,14 @@

 import torch
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available

 from trl.trainer.model_config import ModelConfig
 from trl.trainer.utils import (
     DataCollatorForChatML,
+    batch_generation,
     decode_and_strip_padding,
     generate_model_card,
     get_peft_config,
@@ -250,3 +251,62 @@ def test_data_collator_for_chatml(self):

         # Verify that EOS token is at the end of labels
         self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
+
+
+class TestBatchGeneration(unittest.TestCase):
+    def setUp(self):
+        # Initialize the tokenizer
+        self.model_id = "Qwen/Qwen2-0.5B-Instruct"
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+
+        self.generation_config = GenerationConfig(
+            max_new_tokens=128,
+            temperature=0.5,
+            do_sample=True,
+            top_k=0,
+            pad_token_id=self.tokenizer.pad_token_id,
+        )
+
+        # Example input
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
+        self.examples = dataset["messages"]
+        self.mini_batch_size = 3
+
+    def test_mini_batch_generation(self):
+        batch = [
+            self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False)
+            for example in self.examples
+        ]
+        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"]
+        bs, context_length = queries.shape
+
+        query_responses, logits = batch_generation(
+            self.model, queries, self.mini_batch_size, self.tokenizer.pad_token_id, self.generation_config
+        )
+
+        max_length_query = query_responses.shape[1]
+        max_length_logits = max_length_query - context_length
+
+        self.assertGreater(max_length_query, context_length)
+        self.assertEqual(query_responses.shape, (bs, max_length_query))
+        self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))
+
+    def test_single_batch_generation(self):
+        batch = [
+            self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False)
+            for example in self.examples
+        ]
+        queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"]
+        bs, context_length = queries.shape
+
+        query_responses, logits = batch_generation(
+            self.model, queries, bs, self.tokenizer.pad_token_id, self.generation_config
+        )
+
+        max_length_query = query_responses.shape[1]
+        max_length_logits = max_length_query - context_length
+
+        self.assertGreater(max_length_query, context_length)
+        self.assertEqual(query_responses.shape, (bs, max_length_query))
+        self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size))
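As a usage sketch outside the test harness, the padding also matters when the number of prompts is not a multiple of the mini-batch size, since the final chunk is smaller and its generations may be shorter. The example below reuses the checkpoint and the `batch_generation` call signature from the tests above; the prompts and generation settings are illustrative assumptions, not code from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from trl.trainer.utils import batch_generation

model_id = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Five prompts with a mini-batch size of 2: the last chunk only holds one prompt.
prompts = [
    "What is the capital of France?",
    "Write a haiku about autumn.",
    "Explain beam search in one sentence.",
    "Name three prime numbers.",
    "What does TRL stand for?",
]
queries = tokenizer(prompts, padding=True, return_tensors="pt")["input_ids"]

generation_config = GenerationConfig(
    max_new_tokens=32, do_sample=True, pad_token_id=tokenizer.pad_token_id
)

query_responses, logits = batch_generation(
    model, queries, 2, tokenizer.pad_token_id, generation_config
)

# Every mini-batch is padded to a common length before being merged,
# so the shapes follow the same pattern the tests assert.
print(query_responses.shape)  # (5, context_length + longest generation)
print(logits.shape)           # (5, longest generation, vocab_size)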
trl/trainer/utils.py (12 additions, 2 deletions)
@@ -1274,7 +1274,8 @@ def batch_generation(
 ):
     query_responses = []
     logitss = []
-    for i in range(0, queries.shape[0], local_rollout_forward_batch_size):
+    batch_size = queries.shape[0]
+    for i in range(0, batch_size, local_rollout_forward_batch_size):
         query = queries[i : i + local_rollout_forward_batch_size]
         query_response, logits = generate(
             model,
@@ -1284,7 +1285,16 @@
         )
         query_responses.append(query_response)
         logitss.append(logits)
-    return torch.cat(query_responses, 0), torch.cat(logitss, 0)
+
+    # padding tensors
+    padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right")
+    padded_logitss = pad(logitss, padding_value=0, padding_side="right")
+
+    # reshaping
+    padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size]
+    padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size]
+
+    return padded_query_responses, padded_logitss


 def add_bos_token_if_needed(
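The `pad` helper used above is not part of this diff. Assuming it stacks the listed tensors along a new leading dimension and pads every trailing dimension up to the largest size found in the list (which is what the later `.view(...)[:batch_size]` reshaping relies on), a minimal sketch of the shape bookkeeping looks like this; the helper's behavior is an assumption here, not taken from this commit.

import torch


def pad_sketch(tensors, padding_value=0):
    # Assumed behavior of the `pad` utility: stack along a new leading dimension,
    # padding every trailing dimension up to the largest size found in the list.
    max_shape = [max(t.shape[d] for t in tensors) for d in range(tensors[0].dim())]
    out = torch.full((len(tensors), *max_shape), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[(i,) + tuple(slice(0, s) for s in t.shape)] = t
    return out


# Three mini-batch outputs; the last chunk is smaller (2 rows) and shorter (9 tokens).
chunks = [
    torch.ones(4, 10, dtype=torch.long),
    torch.ones(4, 12, dtype=torch.long),
    torch.ones(2, 9, dtype=torch.long),
]
batch_size = 4 + 4 + 2  # 10 sequences in total

stacked = pad_sketch(chunks, padding_value=0)  # (num_chunks=3, max_mini_batch=4, max_len=12)
flat = stacked.view(-1, stacked.shape[-1])     # (12, 12), including 2 all-padding rows
query_responses = flat[:batch_size]            # (10, 12): the padding rows are dropped

The trailing slice is safe because only the final mini-batch can be smaller than local_rollout_forward_batch_size, so any all-padding rows sit at the end of the flattened tensor; the same reasoning applies to the 3-D logits tensor reshaped with `.view(-1, *padded_logitss.shape[2:])`.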