Merge pull request #8 from keyboardAnt/unit_tests_usd
Add unit tests for Universal Assisted Generation
gauravjain14 authored Dec 19, 2024
2 parents e047adf + 701edbb commit 7088978
90 changes: 90 additions & 0 deletions tests/generation/test_candidate_generator.py
@@ -1,18 +1,24 @@
import gc
import logging
import threading
import unittest
import weakref
from unittest.mock import MagicMock

import numpy as np
import torch

from transformers.generation.candidate_generator import (
AssistantToTargetTranslator,
AssistantVocabTranslatorCache,
AssistedCandidateGeneratorDifferentTokenizers,
    UniversalSpeculativeDecodingGenerator,
)

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
@@ -256,3 +262,87 @@ def get_translator():
# All translators should be the same instance
for translator in translators:
self.assertIs(translators[0], translator, "All translators should be identical across threads")


class TestUniversalSpeculativeDecoding(unittest.TestCase):
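    """Tests for UniversalSpeculativeDecodingGenerator, which lets an assistant
    model with a different tokenizer draft tokens for a target model."""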
device = "cuda" if torch.cuda.is_available() else "cpu"

@classmethod
def setUpClass(cls):
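        # Pair a tiny GPT-2 assistant with a Llama target tokenizer so the two
        # vocabularies genuinely differ.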
cls.assistant_model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-gpt2").to(cls.device)
cls.main_tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct")
cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-gpt2")
cls.generation_config = GenerationConfig()

# Ensure required tokens exist
if cls.main_tokenizer.pad_token_id is None:
cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
if cls.main_tokenizer.bos_token_id is None:
cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id

def setUp(self):
self.input_ids = torch.tensor([[1, 2, 3]]).to(self.device)
self.model_kwargs = {
"attention_mask": torch.ones_like(self.input_ids).to(self.device),
}
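        # Fresh generator per test: it drafts with the assistant model and
        # translates the drafted ids into the target vocabulary.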
self.generator = UniversalSpeculativeDecodingGenerator(
input_ids=self.input_ids,
assistant_model=self.assistant_model,
target_tokenizer=self.main_tokenizer,
assistant_tokenizer=self.assistant_tokenizer,
generation_config=self.generation_config,
model_kwargs=self.model_kwargs,
target_vocab_size=self.main_tokenizer.vocab_size,
)

def test_basic_generation(self):
"""Test basic speculative decoding works"""
input_text = "The quick brown fox"
        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
self.generator.input_ids = input_ids
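        # get_candidates should return the drafted candidate ids together with their scores.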
candidates, scores = self.generator.get_candidates(input_ids)

self.assertIsNotNone(candidates)
self.assertIsNotNone(scores)
self.assertTrue(torch.is_tensor(candidates))
self.assertTrue(torch.is_tensor(scores))

def test_mismatched_vocabularies(self):
"""Test handling of mismatched vocabularies between models"""
        # Build an input containing a token that exists in the target
        # vocabulary but not in the assistant's.
missing_token = next(
token for token in self.main_tokenizer.get_vocab()
if token not in self.assistant_tokenizer.get_vocab() and
token not in self.main_tokenizer.all_special_tokens and
"reserved_" not in token
)
        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]).to(self.device)
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
self.assertIsNotNone(candidates)

def test_speculation_depth(self):
"""Test different speculation depths"""
        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt").to(self.device)
self.generator.input_ids = input_ids

for depth in [1, 8, 17]:
self.generator.num_assistant_tokens = depth
candidates, scores = self.generator.get_candidates(input_ids)
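            # Drafting can stop early (e.g. on EOS), so only an upper bound
            # on the number of drafted tokens is guaranteed.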
self.assertLessEqual(
candidates.shape[1] - input_ids.shape[1], depth
)

def test_device_consistency(self):
"""Test handling of inputs on different devices"""
if torch.cuda.is_available():
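            # Only meaningful on GPU: candidates must come back on the same
            # device as the inputs.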
input_ids = torch.tensor([[1, 2, 3]]).to(
self.generator.assistant_model.device)
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
self.assertEqual(candidates.device, input_ids.device)
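
As a usage sketch (not part of this commit; it simply mirrors the test fixtures above), the generator under test can be driven directly like so:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"
assistant_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
target_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
assistant_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = target_tokenizer.encode("The quick brown fox", return_tensors="pt").to(device)
generator = UniversalSpeculativeDecodingGenerator(
    input_ids=input_ids,
    assistant_model=assistant_model,
    target_tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    generation_config=GenerationConfig(),
    model_kwargs={"attention_mask": torch.ones_like(input_ids)},
    target_vocab_size=target_tokenizer.vocab_size,
)
# Draft candidate ids with the assistant, translated into the target vocabulary.
candidates, scores = generator.get_candidates(input_ids)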
