Olmo is ExecuTorch Compatible (huggingface#34181)
Co-authored-by: Guang Yang <[email protected]>
2 people authored and BernardZach committed Dec 6, 2024
1 parent 897a3bc commit ecdda01
Showing 1 changed file with 64 additions and 0 deletions.
tests/models/olmo/test_modeling_olmo.py (64 additions, 0 deletions)
@@ -16,9 +16,11 @@

import unittest

from packaging import version
from parameterized import parameterized

from transformers import OlmoConfig, is_torch_available, set_seed
from transformers.generation.configuration_utils import GenerationConfig
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
from transformers.testing_utils import (
@@ -449,3 +451,65 @@ def test_simple_encode_decode(self):
self.assertEqual(rust_tokenizer.encode(" "), [50276])

self.assertEqual(rust_tokenizer.encode(" Hello"), [24387])

@slow
def test_export_static_cache(self):
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

olmo_model = "allenai/OLMo-1B-hf"

tokenizer = AutoTokenizer.from_pretrained(olmo_model, pad_token="</s>", padding_side="right")
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that \nthe speed of light is the same in all reference frames.\n\nThe speed of light",
]
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
"input_ids"
].shape[-1]

# Load model
device = "cpu"
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"
batch_size = 1
model = OlmoForCausalLM.from_pretrained(
olmo_model,
device_map=device,
torch_dtype=dtype,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_generation_length,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
},
),
)

prompts = ["Simply put, the theory of relativity states that "]
prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
prompt_token_ids = prompt_tokens["input_ids"]
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + eager
eager_generated_ids = model.generate(
**prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation
)
eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
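
For context, below is a minimal standalone sketch of the export flow the new test exercises, built from the same APIs that appear in the diff (convert_and_export_with_cache and TorchExportableModuleWithStaticCache.generate). The model id and prompt mirror the test; max_length is an illustrative stand-in for the test's computed max_generation_length, and running it requires torch >= 2.4.

    import torch
    from transformers import OlmoForCausalLM
    from transformers.generation.configuration_utils import GenerationConfig
    from transformers.integrations.executorch import (
        TorchExportableModuleWithStaticCache,
        convert_and_export_with_cache,
    )
    from transformers.models.auto.tokenization_auto import AutoTokenizer

    model_id = "allenai/OLMo-1B-hf"
    max_length = 32  # illustrative total sequence length (prompt + new tokens)

    tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")

    # Load the model with a static KV cache sized up front, as the test does.
    model = OlmoForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_length,
            cache_config={"batch_size": 1, "max_cache_len": max_length},
        ),
    )

    prompt_ids = tokenizer(
        ["Simply put, the theory of relativity states that "], return_tensors="pt"
    )["input_ids"]

    # torch.export the model together with its static cache, then generate
    # token-by-token from the exported program.
    exported_program = convert_and_export_with_cache(model)
    generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program,
        prompt_token_ids=prompt_ids,
        max_new_tokens=max_length - prompt_ids.shape[-1],
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))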
