Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jmamou committed Dec 18, 2024
1 parent a350b1c commit e047adf
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions tests/generation/test_candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,27 @@ def setUp(self):

self.target_tokenizer.get_vocab.return_value = self.target_vocab
self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
self.assistant_model_device = "cpu"
self.target_vocab_size = 6

# Instantiate the class under test
self.translator = AssistantToTargetTranslator(
target_tokenizer=self.target_tokenizer, assistant_tokenizer=self.assistant_tokenizer
target_tokenizer=self.target_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)

def test_get_assistant_to_target_input_ids(self):
"""Test the mapping from assistant tokens to target tokens."""
expected_mapping = {0: 0, 1: 1, 2: 2}
actual_mapping = self.translator._assistant_to_target_input_ids
expected_mapping = [0, 1, 2, self.translator.suppress_tokens_id, self.translator.suppress_tokens_id]
actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
self.assertEqual(actual_mapping, expected_mapping)

def test_get_suppress_input_ids(self):
"""Test the suppression of assistant input IDs not present in the target vocabulary."""
expected_suppress_ids = [4]
actual_suppress_ids = self.translator._suppress_input_ids
expected_suppress_ids = [3, 4]
actual_suppress_ids = self.translator._get_suppress_input_ids().tolist()
self.assertEqual(actual_suppress_ids, expected_suppress_ids)

def test_get_target_ids(self):
Expand All @@ -89,8 +94,8 @@ def test_get_target_ids(self):
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]) # 'hello world foo baz' in assistant tokenizer

expected_target_ids = torch.LongTensor(
[[0, 1, 2, 4]]
) # 'hello world foo baz' in target tokenizer (baz id remains 4)
[[0, 1, 2, self.translator.suppress_tokens_id]]
) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)

actual_target_ids = self.translator.get_target_ids(
assistant_input_ids, target_input_ids, assistant_candidate_ids
Expand All @@ -100,10 +105,10 @@ def test_get_target_ids(self):
def test_get_target_logits(self):
"""Test the conversion of assistant logits to target logits."""
# Assistant logits for IDs 0-4 (IDs 3 and 4 are absent from the target vocab)
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3]]]) # Shape (1, 1, 3)
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.filter_value]]]) # Shape (1, 1, 5)

# Expected target logits (target_vocab_size = 6, per self.target_vocab_size)
expected_target_logits = torch.full((1, 1, 4), -float("inf"))
expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.filter_value)
expected_target_logits[0, 0, 0] = 0.1 # 'hello'
expected_target_logits[0, 0, 1] = 0.2 # 'world'
expected_target_logits[0, 0, 2] = 0.3 # 'foo'
Expand Down Expand Up @@ -132,18 +137,38 @@ def setUp(self):
self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
self.assistant_model_device = "cpu"
self.target_vocab_size = 6

def test_same_instance_for_same_tokenizers(self):
"""Test that the same translator is returned for the same tokenizers."""
translator1 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer)
translator2 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer)
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertIs(translator1, translator2, "Translators should be cached and identical")

def test_different_instances_for_different_tokenizers(self):
"""Test that different tokenizers produce different translators."""
translator1 = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer)
translator1 = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
translator2 = AssistantVocabTranslatorCache.get_translator(
self.other_target_tokenizer, self.other_assistant_tokenizer
self.other_target_tokenizer,
self.other_assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")

Expand All @@ -154,7 +179,12 @@ def test_cache_with_weakref_key(self):
assistant_tokenizer = MockTokenizer({"hello": 0})

# Store translator in a local variable to avoid it being kept alive
translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

# Delete all strong references
Expand All @@ -177,7 +207,12 @@ def test_weakref_cache_cleanup(self):
def create_translator():
target_tokenizer = MockTokenizer({"hello": 0})
assistant_tokenizer = MockTokenizer({"hello": 0})
translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
# Create weak references before returning
refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
# Remove strong references inside the function
Expand All @@ -204,7 +239,12 @@ def test_thread_safety(self):
translators = []

def get_translator():
translator = AssistantVocabTranslatorCache.get_translator(self.target_tokenizer, self.assistant_tokenizer)
translator = AssistantVocabTranslatorCache.get_translator(
self.target_tokenizer,
self.assistant_tokenizer,
assistant_model_device=self.assistant_model_device,
target_vocab_size=self.target_vocab_size,
)
translators.append(translator)

threads = [threading.Thread(target=get_translator) for _ in range(10)]
Expand Down

0 comments on commit e047adf

Please sign in to comment.