From 4e92e9ced98340f3b183453ba8e67a0438851e11 Mon Sep 17 00:00:00 2001
From: Gaurav
Date: Thu, 12 Dec 2024 02:30:25 +0000
Subject: [PATCH 1/4] Add unittests for Universal Assisted generation

---
 tests/test_universal_assisted_generation.py | 119 ++++++++++++++++++++
 1 file changed, 119 insertions(+)
 create mode 100644 tests/test_universal_assisted_generation.py

diff --git a/tests/test_universal_assisted_generation.py b/tests/test_universal_assisted_generation.py
new file mode 100644
index 00000000000000..8eae5a71de0749
--- /dev/null
+++ b/tests/test_universal_assisted_generation.py
@@ -0,0 +1,119 @@
+import unittest
+
+from zmq import device
+import torch
+import logging
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator
+
+logging.basicConfig(level=logging.DEBUG, format='%(message)s')
+
+if torch.cuda.is_available():
+    device = "cuda"
+
+class TestUniversalSpeculativeDecoding(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Setup main and assistant models
+        cls.main_model = AutoModelForCausalLM.from_pretrained(
+            "meta-llama/Llama-3.2-1B-Instruct").to(device)
+        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
+            "Qwen/Qwen2.5-0.5B-Instruct").to(device)
+        cls.main_tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Llama-3.2-1B-Instruct")
+        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
+            "Qwen/Qwen2.5-0.5B-Instruct")
+        cls.generation_config = GenerationConfig()
+
+        # Ensure required tokens exist
+        if cls.main_tokenizer.pad_token_id is None:
+            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
+        if cls.main_tokenizer.bos_token_id is None:
+            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id
+
+    def setUp(self):
+        self.input_ids = torch.tensor([[1, 2, 3]]).to(device)
+        self.model_kwargs = {
+            "attention_mask": torch.ones_like(self.input_ids).to(device),
+        }
+        self.generator = UniversalSpeculativeDecodingGenerator(
+            input_ids=self.input_ids,
+            assistant_model=self.assistant_model,
+            target_tokenizer=self.main_tokenizer,
+            assistant_tokenizer=self.assistant_tokenizer,
+            generation_config=self.generation_config,
+            model_kwargs=self.model_kwargs,
+            target_vocab_size=self.main_tokenizer.vocab_size,
+        )
+
+    def test_basic_generation(self):
+        """Test basic speculative decoding works"""
+        input_text = "The quick brown fox"
+        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt")
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+
+        self.assertIsNotNone(candidates)
+        self.assertIsNotNone(scores)
+        self.assertTrue(torch.is_tensor(candidates))
+        self.assertTrue(torch.is_tensor(scores))
+
+    def test_mismatched_vocabularies(self):
+        """Test handling of mismatched vocabularies between models"""
+        # Create input with tokens present in main but not assistant vocab
+        # Find a token that is not in the assistant tokenizer but in
+        # the main tokenizer.
+        missing_token = next(
+            token
+            for token in self.main_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab()
+        )
+
+        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+        self.assertIsNotNone(candidates)
+
+    def test_empty_input(self):
+        if False:
+            """Test handling of empty input"""
+            input_ids = torch.tensor([[]], dtype=torch.long)
+            self.generator.input_ids = input_ids
+            with self.assertRaises(ValueError):
+                self.generator.get_candidates(input_ids)
+
+    def test_long_sequence(self):
+        if False:
+            """Test handling of very long input sequences"""
+            long_input = torch.ones((1, 2048), dtype=torch.long)
+            self.generator.input_ids = long_input
+            candidates, scores = self.generator.get_candidates(long_input)
+            self.assertLessEqual(
+                candidates.shape[1],
+                self.main_model.config.max_position_embeddings,
+            )
+
+    def test_speculation_depth(self):
+        """Test different speculation depths"""
+        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")
+        self.generator.input_ids = input_ids
+
+        for depth in [1, 8, 17]:
+            self.generation_config.num_assistant_tokens = depth
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertLessEqual(
+                candidates.shape[1] - input_ids.shape[1], depth
+            )
+
+    def test_device_consistency(self):
+        """Test handling of inputs on different devices"""
+        if torch.cuda.is_available():
+            input_ids = torch.tensor([[1, 2, 3]]).to(
+                self.generator.assistant_model.device)
+            self.generator.input_ids = input_ids
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertEqual(candidates.device, input_ids.device)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file

From 011f5956d23ab308855d3d2cb7c3cf31c9571a28 Mon Sep 17 00:00:00 2001
From: Gaurav Jain
Date: Tue, 17 Dec 2024 22:30:19 +0000
Subject: [PATCH 2/4] Remove unused import and fix `test_speculation_depth` test

---
 tests/test_universal_assisted_generation.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/test_universal_assisted_generation.py b/tests/test_universal_assisted_generation.py
index 8eae5a71de0749..e6d2ea7ec30e22 100644
--- a/tests/test_universal_assisted_generation.py
+++ b/tests/test_universal_assisted_generation.py
@@ -1,6 +1,5 @@
 import unittest
 
-from zmq import device
 import torch
 import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
@@ -8,8 +7,7 @@
 
 logging.basicConfig(level=logging.DEBUG, format='%(message)s')
 
-if torch.cuda.is_available():
-    device = "cuda"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 class TestUniversalSpeculativeDecoding(unittest.TestCase):
     @classmethod
@@ -99,7 +97,7 @@ def test_speculation_depth(self):
         self.generator.input_ids = input_ids
 
         for depth in [1, 8, 17]:
-            self.generation_config.num_assistant_tokens = depth
+            self.generator.num_assistant_tokens = depth
             candidates, scores = self.generator.get_candidates(input_ids)
             self.assertLessEqual(
                 candidates.shape[1] - input_ids.shape[1], depth

From 26524900c4ac2295ed40de96606ffab5a8a4789c Mon Sep 17 00:00:00 2001
From: Gaurav Jain
Date: Wed, 18 Dec 2024 21:00:37 +0000
Subject: [PATCH 3/4] exclude special and reserved tokens from tokenizer for UAG

---
 tests/test_universal_assisted_generation.py | 35 +++++----------------
 1 file changed, 7 insertions(+), 28 deletions(-)

diff --git a/tests/test_universal_assisted_generation.py b/tests/test_universal_assisted_generation.py
index e6d2ea7ec30e22..8c45a0ba148622 100644
--- a/tests/test_universal_assisted_generation.py
+++ b/tests/test_universal_assisted_generation.py
@@ -1,11 +1,9 @@
 import unittest
 
 import torch
-import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator
 
-logging.basicConfig(level=logging.DEBUG, format='%(message)s')
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -16,11 +14,11 @@ def setUpClass(cls):
         cls.main_model = AutoModelForCausalLM.from_pretrained(
             "meta-llama/Llama-3.2-1B-Instruct").to(device)
         cls.assistant_model = AutoModelForCausalLM.from_pretrained(
-            "Qwen/Qwen2.5-0.5B-Instruct").to(device)
+            "hf-internal-testing/tiny-random-gpt2").to(device)
         cls.main_tokenizer = AutoTokenizer.from_pretrained(
             "meta-llama/Llama-3.2-1B-Instruct")
         cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
-            "Qwen/Qwen2.5-0.5B-Instruct")
+            "hf-internal-testing/tiny-random-gpt2")
         cls.generation_config = GenerationConfig()
 
         # Ensure required tokens exist
@@ -62,35 +60,16 @@ def test_mismatched_vocabularies(self):
         # Create input with tokens present in main but not assistant vocab
         # Find a token that is not in the assistant tokenizer but in
         # the main tokenizer.
         missing_token = next(
-            token
-            for token in self.main_tokenizer.get_vocab()
-            if token not in self.assistant_tokenizer.get_vocab()
+            token for token in self.main_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab() and
+            token not in self.main_tokenizer.all_special_tokens and
+            "reserved_" not in token
         )
-
-        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
+        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
         self.generator.input_ids = input_ids
         candidates, scores = self.generator.get_candidates(input_ids)
         self.assertIsNotNone(candidates)
 
-    def test_empty_input(self):
-        if False:
-            """Test handling of empty input"""
-            input_ids = torch.tensor([[]], dtype=torch.long)
-            self.generator.input_ids = input_ids
-            with self.assertRaises(ValueError):
-                self.generator.get_candidates(input_ids)
-
-    def test_long_sequence(self):
-        if False:
-            """Test handling of very long input sequences"""
-            long_input = torch.ones((1, 2048), dtype=torch.long)
-            self.generator.input_ids = long_input
-            candidates, scores = self.generator.get_candidates(long_input)
-            self.assertLessEqual(
-                candidates.shape[1],
-                self.main_model.config.max_position_embeddings,
-            )
-
     def test_speculation_depth(self):
         """Test different speculation depths"""
         input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")

From 701edbb522c62dffba22ebc2f84d8a7d3107a38f Mon Sep 17 00:00:00 2001
From: Gaurav Jain
Date: Thu, 19 Dec 2024 08:51:57 +0000
Subject: [PATCH 4/4] mv `test_universal_assisted_generation.py` to `generation/test_candidate_generator.py`

---
 tests/generation/test_candidate_generator.py | 90 ++++++++++++++++++
 tests/test_universal_assisted_generation.py  | 96 --------------------
 2 files changed, 90 insertions(+), 96 deletions(-)
 delete mode 100644 tests/test_universal_assisted_generation.py

diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index dd7e427a3bfda9..7d005f42536ab9 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -1,9 +1,12 @@
 import gc
+import logging
 import threading
 import unittest
 import weakref
 from unittest.mock import MagicMock
 
+from zmq import device
+
 import numpy as np
 import torch
 
@@ -11,8 +14,11 @@
     AssistantToTargetTranslator,
     AssistantVocabTranslatorCache,
     AssistedCandidateGeneratorDifferentTokenizers,
+    UniversalSpeculativeDecodingGenerator
 )
 
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+
 
 class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
     def test_no_intersection(self):
@@ -216,3 +222,87 @@ def get_translator():
         # All translators should be the same instance
         for translator in translators:
             self.assertIs(translators[0], translator, "All translators should be identical across threads")
+
+
+class TestUniversalSpeculativeDecoding(unittest.TestCase):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    @classmethod
+    def setUpClass(cls):
+        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2").to(cls.device)
+        cls.main_tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Llama-3.2-1B-Instruct")
+        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2")
+        cls.generation_config = GenerationConfig()
+
+        # Ensure required tokens exist
+        if cls.main_tokenizer.pad_token_id is None:
+            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
+        if cls.main_tokenizer.bos_token_id is None:
+            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id
+
+    def setUp(self):
+        self.input_ids = torch.tensor([[1, 2, 3]]).to(self.device)
+        self.model_kwargs = {
+            "attention_mask": torch.ones_like(self.input_ids).to(self.device),
+        }
+        self.generator = UniversalSpeculativeDecodingGenerator(
+            input_ids=self.input_ids,
+            assistant_model=self.assistant_model,
+            target_tokenizer=self.main_tokenizer,
+            assistant_tokenizer=self.assistant_tokenizer,
+            generation_config=self.generation_config,
+            model_kwargs=self.model_kwargs,
+            target_vocab_size=self.main_tokenizer.vocab_size,
+        )
+
+    def test_basic_generation(self):
+        """Test basic speculative decoding works"""
+        input_text = "The quick brown fox"
+        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt")
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+
+        self.assertIsNotNone(candidates)
+        self.assertIsNotNone(scores)
+        self.assertTrue(torch.is_tensor(candidates))
+        self.assertTrue(torch.is_tensor(scores))
+
+    def test_mismatched_vocabularies(self):
+        """Test handling of mismatched vocabularies between models"""
+        # Create input with tokens present in main but not assistant vocab
+        # Find a token that is not in the assistant tokenizer but in
+        # the main tokenizer.
+        missing_token = next(
+            token for token in self.main_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab() and
+            token not in self.main_tokenizer.all_special_tokens and
+            "reserved_" not in token
+        )
+        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+        self.assertIsNotNone(candidates)
+
+    def test_speculation_depth(self):
+        """Test different speculation depths"""
+        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")
+        self.generator.input_ids = input_ids
+
+        for depth in [1, 8, 17]:
+            self.generator.num_assistant_tokens = depth
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertLessEqual(
+                candidates.shape[1] - input_ids.shape[1], depth
+            )
+
+    def test_device_consistency(self):
+        """Test handling of inputs on different devices"""
+        if torch.cuda.is_available():
+            input_ids = torch.tensor([[1, 2, 3]]).to(
+                self.generator.assistant_model.device)
+            self.generator.input_ids = input_ids
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertEqual(candidates.device, input_ids.device)
diff --git a/tests/test_universal_assisted_generation.py b/tests/test_universal_assisted_generation.py
deleted file mode 100644
index 8c45a0ba148622..00000000000000
--- a/tests/test_universal_assisted_generation.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import unittest
-
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
-from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator
-
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-class TestUniversalSpeculativeDecoding(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        # Setup main and assistant models
-        cls.main_model = AutoModelForCausalLM.from_pretrained(
-            "meta-llama/Llama-3.2-1B-Instruct").to(device)
-        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-gpt2").to(device)
-        cls.main_tokenizer = AutoTokenizer.from_pretrained(
-            "meta-llama/Llama-3.2-1B-Instruct")
-        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
-            "hf-internal-testing/tiny-random-gpt2")
-        cls.generation_config = GenerationConfig()
-
-        # Ensure required tokens exist
-        if cls.main_tokenizer.pad_token_id is None:
-            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
-        if cls.main_tokenizer.bos_token_id is None:
-            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id
-
-    def setUp(self):
-        self.input_ids = torch.tensor([[1, 2, 3]]).to(device)
-        self.model_kwargs = {
-            "attention_mask": torch.ones_like(self.input_ids).to(device),
-        }
-        self.generator = UniversalSpeculativeDecodingGenerator(
-            input_ids=self.input_ids,
-            assistant_model=self.assistant_model,
-            target_tokenizer=self.main_tokenizer,
-            assistant_tokenizer=self.assistant_tokenizer,
-            generation_config=self.generation_config,
-            model_kwargs=self.model_kwargs,
-            target_vocab_size=self.main_tokenizer.vocab_size,
-        )
-
-    def test_basic_generation(self):
-        """Test basic speculative decoding works"""
-        input_text = "The quick brown fox"
-        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt")
-        self.generator.input_ids = input_ids
-        candidates, scores = self.generator.get_candidates(input_ids)
-
-        self.assertIsNotNone(candidates)
-        self.assertIsNotNone(scores)
-        self.assertTrue(torch.is_tensor(candidates))
-        self.assertTrue(torch.is_tensor(scores))
-
-    def test_mismatched_vocabularies(self):
-        """Test handling of mismatched vocabularies between models"""
-        # Create input with tokens present in main but not assistant vocab
-        # Find a token that is not in the assistant tokenizer but in
-        # the main tokenizer.
-        missing_token = next(
-            token for token in self.main_tokenizer.get_vocab()
-            if token not in self.assistant_tokenizer.get_vocab() and
-            token not in self.main_tokenizer.all_special_tokens and
-            "reserved_" not in token
-        )
-        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
-        self.generator.input_ids = input_ids
-        candidates, scores = self.generator.get_candidates(input_ids)
-        self.assertIsNotNone(candidates)
-
-    def test_speculation_depth(self):
-        """Test different speculation depths"""
-        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")
-        self.generator.input_ids = input_ids
-
-        for depth in [1, 8, 17]:
-            self.generator.num_assistant_tokens = depth
-            candidates, scores = self.generator.get_candidates(input_ids)
-            self.assertLessEqual(
-                candidates.shape[1] - input_ids.shape[1], depth
-            )
-
-    def test_device_consistency(self):
-        """Test handling of inputs on different devices"""
-        if torch.cuda.is_available():
-            input_ids = torch.tensor([[1, 2, 3]]).to(
-                self.generator.assistant_model.device)
-            self.generator.input_ids = input_ids
-            candidates, scores = self.generator.get_candidates(input_ids)
-            self.assertEqual(candidates.device, input_ids.device)
-
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
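
Note on usage: the tests in this series drive UniversalSpeculativeDecodingGenerator directly through its internal get_candidates() API. For orientation only, below is a minimal sketch (not part of the patches) of how universal assisted generation is normally reached from the public generate() API; it assumes a transformers release that supports an assistant model with a different tokenizer, and the checkpoints are illustrative rather than prescribed by this series.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Target model and its tokenizer (the vocabulary the final output uses).
    target = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").to(device)
    target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

    # Smaller assistant (draft) model with a different tokenizer/vocabulary.
    assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
    assistant_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    inputs = target_tok("The quick brown fox", return_tensors="pt").to(device)

    # Passing both tokenizers together with assistant_model selects the
    # universal (cross-vocabulary) assisted-decoding path exercised above.
    output_ids = target.generate(
        **inputs,
        assistant_model=assistant,
        tokenizer=target_tok,
        assistant_tokenizer=assistant_tok,
        max_new_tokens=20,
    )
    print(target_tok.decode(output_ids[0], skip_special_tokens=True))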