From e047adf6a414817d77843ce696f8ed11f9949ec4 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 18 Dec 2024 04:15:27 -0800 Subject: [PATCH] update tests --- tests/generation/test_candidate_generator.py | 72 +++++++++++++++----- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index dd7e427a3bfda9..9a2b2831e5d8e9 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -64,22 +64,27 @@ def setUp(self): self.target_tokenizer.get_vocab.return_value = self.target_vocab self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab + self.assistant_model_device = "cpu" + self.target_vocab_size = 6 # Instantiate the class under test self.translator = AssistantToTargetTranslator( - target_tokenizer=self.target_tokenizer, assistant_tokenizer=self.assistant_tokenizer + target_tokenizer=self.target_tokenizer, + assistant_tokenizer=self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, ) def test_get_assistant_to_target_input_ids(self): """Test the mapping from assistant tokens to target tokens.""" - expected_mapping = {0: 0, 1: 1, 2: 2} - actual_mapping = self.translator._assistant_to_target_input_ids + expected_mapping = [0, 1, 2, self.translator.suppress_tokens_id, self.translator.suppress_tokens_id] + actual_mapping = self.translator._assistant_to_target_input_ids.tolist() self.assertEqual(actual_mapping, expected_mapping) def test_get_suppress_input_ids(self): """Test the suppression of assistant input IDs not present in the target vocabulary.""" - expected_suppress_ids = [4] - actual_suppress_ids = self.translator._suppress_input_ids + expected_suppress_ids = [3, 4] + actual_suppress_ids = self.translator._get_suppress_input_ids().tolist() self.assertEqual(actual_suppress_ids, expected_suppress_ids) def test_get_target_ids(self): @@ -89,8 +94,8 @@ def test_get_target_ids(self): assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]) # 'hello world foo baz' in assistant tokenizer expected_target_ids = torch.LongTensor( - [[0, 1, 2, 4]] - ) # 'hello world foo baz' in target tokenizer (baz id remains 4) + [[0, 1, 2, self.translator.suppress_tokens_id]] + ) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab) actual_target_ids = self.translator.get_target_ids( assistant_input_ids, target_input_ids, assistant_candidate_ids @@ -100,10 +105,10 @@ def test_get_target_ids(self): def test_get_target_logits(self): """Test the conversion of assistant logits to target logits.""" # Assistant logits for IDs 0, 1, 2 - assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3]]]) # Shape (1, 1, 3) + assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.filter_value]]]) # Shape (1, 1, 5) # Expected target logits (target_vocab_size = 4) - expected_target_logits = torch.full((1, 1, 4), -float("inf")) + expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.filter_value) expected_target_logits[0, 0, 0] = 0.1 # 'hello' expected_target_logits[0, 0, 1] = 0.2 # 'world' expected_target_logits[0, 0, 2] = 0.3 # 'foo' @@ -132,18 +137,38 @@ def setUp(self): self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2}) self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3}) self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5}) + self.assistant_model_device = "cpu" + self.target_vocab_size = 6 def test_same_instance_for_same_tokenizers(self): """Test that the same translator is returned for the same tokenizers.""" - translator1 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer) - translator2 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer) + translator1 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) + translator2 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) self.assertIs(translator1, translator2, "Translators should be cached and identical") def test_different_instances_for_different_tokenizers(self): """Test that different tokenizers produce different translators.""" - translator1 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer) + translator1 = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) translator2 = AssistantVocabTranslatorCache.get_translator( - self.other_target_tokenizer, self.other_assistant_tokenizer + self.other_target_tokenizer, + self.other_assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, ) self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers") @@ -154,7 +179,12 @@ def test_cache_with_weakref_key(self): assistant_tokenizer = MockTokenizer({"hello": 0}) # Store translator in a local variable to avoid it being kept alive - translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer) + translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) # Delete all strong references @@ -177,7 +207,12 @@ def test_weakref_cache_cleanup(self): def create_translator(): target_tokenizer = MockTokenizer({"hello": 0}) assistant_tokenizer = MockTokenizer({"hello": 0}) - translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer) + translator = AssistantVocabTranslatorCache.get_translator( + target_tokenizer, + assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) # Create weak references before returning refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer)) # Remove strong references inside the function @@ -204,7 +239,12 @@ def test_thread_safety(self): translators = [] def get_translator(): - translator = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer) + translator = AssistantVocabTranslatorCache.get_translator( + self.target_tokenizer, + self.assistant_tokenizer, + assistant_model_device=self.assistant_model_device, + target_vocab_size=self.target_vocab_size, + ) translators.append(translator) threads = [threading.Thread(target=get_translator) for _ in range(10)]