diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index e74785e29e9008..fbe73248d00b7c 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -16,9 +16,11 @@ import unittest +from packaging import version from parameterized import parameterized from transformers import OlmoConfig, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast from transformers.testing_utils import ( @@ -449,3 +451,65 @@ def test_simple_encode_decode(self): self.assertEqual(rust_tokenizer.encode(" "), [50276]) self.assertEqual(rust_tokenizer.encode(" Hello"), [24387]) + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + olmo_model = "allenai/OLMo-1B-hf" + + tokenizer = AutoTokenizer.from_pretrained(olmo_model, pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that \nthe speed of light is the same in all reference frames.\n\nThe speed of light", + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = OlmoForCausalLM.from_pretrained( + olmo_model, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + eager + eager_generated_ids = model.generate( + **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation + ) + eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)