diff --git a/tests/test_utils.py b/tests/test_utils.py index 226861d96f..ba605ab086 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,13 +16,14 @@ import torch from datasets import load_dataset -from transformers import AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl.trainer.model_config import ModelConfig from trl.trainer.utils import ( DataCollatorForChatML, + batch_generation, decode_and_strip_padding, generate_model_card, get_peft_config, @@ -250,3 +251,62 @@ def test_data_collator_for_chatml(self): # Verify that EOS token is at the end of labels self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.") + + +class TestBatchGeneration(unittest.TestCase): + def setUp(self): + # Initialize the tokenizer + self.model_id = "Qwen/Qwen2-0.5B-Instruct" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + self.generation_config = GenerationConfig( + max_new_tokens=128, + temperature=0.5, + do_sample=True, + top_k=0, + pad_token_id=self.tokenizer.pad_token_id, + ) + + # Example input + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + self.examples = dataset["messages"] + self.mini_batch_size = 3 + + def test_mini_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"] + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, self.mini_batch_size, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + self.assertGreater(max_length_query, context_length) + self.assertEqual(query_responses.shape, (bs, max_length_query)) + self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + + def test_single_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"] + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, bs, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + self.assertGreater(max_length_query, context_length) + self.assertEqual(query_responses.shape, (bs, max_length_query)) + self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 81e826f971..96dc8fba24 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1274,7 +1274,8 @@ def batch_generation( ): query_responses = [] logitss = [] - for i in range(0, queries.shape[0], local_rollout_forward_batch_size): + batch_size = queries.shape[0] + for i in range(0, batch_size, local_rollout_forward_batch_size): query = queries[i : i + local_rollout_forward_batch_size] query_response, logits = generate( model, @@ -1284,7 +1285,16 @@ def batch_generation( ) query_responses.append(query_response) logitss.append(logits) - return torch.cat(query_responses, 0), torch.cat(logitss, 0) + + # padding tensors + padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right") + padded_logitss = pad(logitss, padding_value=0, padding_side="right") + + # reshaping + padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size] + padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size] + + return padded_query_responses, padded_logitss def add_bos_token_if_needed(