Llama 3.2 1B inference and tests (#499)
* llama_inference

* refactoring model

* fix for newer transformers

* test for llama32

* uncommented mark_fail

* updated load_model
mstojkovicTT authored Oct 24, 2024
1 parent 9cae62d commit ce98259
Showing 9 changed files with 99 additions and 48 deletions.
28 changes: 19 additions & 9 deletions forge/test/mlir/llama/test_llama_inference.py
@@ -10,10 +10,13 @@
 from test.mlir.llama.utils.utils import load_model
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_inference():
-    # Load Llama 3B model and tokenizer
-    model_path = "openlm-research/open_llama_3b"
+def test_llama_inference(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Model and Tokenizer
     framework_model, tokenizer = load_model(model_path)
 
     prompt = "Q: What is the largest animal?\nA:"
@@ -27,17 +30,20 @@ def test_llama_inference():
     compiled_model = forge.compile(framework_model, input_ids)
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.skip(reason="No need to run in CI, this is PoC that should be mapped to work on device.")
-def test_llama_inference_no_cache_cpu():
+def test_llama_inference_no_cache_cpu(model_path):
     """
     This function tests the inference of the Llama 3B model without using a past-cache (KV cache).
     It generates text token by token, which can slow down over time as the model has to compute
     all key-value (KV) pairs for each new token. The function demonstrates how to load the model
     and tokenizer, prepare an input prompt, and generate a sequence of tokens until a specified
     maximum number of new tokens is reached or an end-of-sequence token is encountered.
     """
-    # Load Llama 3B model and tokenizer
-    model_path = "openlm-research/open_llama_3b"
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
     framework_model, tokenizer = load_model(model_path)
 
     # Prepare input sentence
@@ -61,8 +67,9 @@ def test_llama_inference_no_cache_cpu():
     print(generated_text)
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.skip(reason="No need to run in CI, this is PoC that should be mapped to work on device.")
-def test_llama_inference_cache_cpu():
+def test_llama_inference_cache_cpu(model_path):
     """
     This function tests the inference of the Llama 3B model using a past-cache (KV cache).
     By utilizing cached key-value (KV) pairs, the model can generate text more efficiently
@@ -79,9 +86,12 @@ def test_llama_inference_cache_cpu():
     5. Generate tokens iteratively, updating the past key-values and input IDs.
     6. Decode the generated tokens into text and print the result.
     """
-    # Load Llama 3B model and tokenizer
-    model_path = "openlm-research/open_llama_3b"
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
     framework_model, tokenizer = load_model(model_path)
 
     # Prepare input sentence
     prompt = "Q: What is the largest animal?\nA:"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
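For orientation, the two CPU proof-of-concept tests above describe token-by-token generation without and with a KV cache. Below is a minimal sketch of both decode loops, assuming a model loaded with load_model(model_path, return_dict=True) so outputs expose .logits; it is illustrative only, not the tests' exact code.

import torch

def generate_no_cache(model, tokenizer, prompt, max_new_tokens=32):
    # No KV cache: the full sequence is re-fed every step, so each step recomputes
    # attention over all previous tokens and generation slows down as the text grows.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

def generate_with_cache(model, tokenizer, prompt, max_new_tokens=32):
    # KV cache: the prompt is prefilled once, then only the newest token is fed while
    # past key-values are carried forward, so each step does a constant amount of work.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model(input_ids, use_cache=True)
    generated = input_ids
    for _ in range(max_new_tokens):
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        outputs = model(next_token, past_key_values=outputs.past_key_values, use_cache=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)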
11 changes: 8 additions & 3 deletions forge/test/mlir/llama/tests/test_llama_embedding.py
@@ -9,10 +9,15 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_embedding():
-    # Load Llama 3B model and tokenizer
-    framework_model, _ = load_model()
+def test_llama_embedding(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
+    framework_model, _ = load_model(model_path)
 
     vocab_size = framework_model.config.vocab_size
     framework_model = framework_model.model.embed_tokens
 
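The change above reads vocab_size from the config before framework_model is rebound to the embedding sub-module, since embeddings consume integer token IDs rather than float hidden states. A small illustrative usage, with the model path and shapes chosen arbitrarily:

import torch
from test.mlir.llama.utils.utils import load_model

framework_model, _ = load_model("openlm-research/open_llama_3b")
vocab_size = framework_model.config.vocab_size        # captured before rebinding framework_model
framework_model = framework_model.model.embed_tokens  # now just the embedding layer
token_ids = torch.randint(0, vocab_size, (1, 12))     # integer token IDs in [0, vocab_size)
hidden_states = framework_model(token_ids)            # -> (1, 12, hidden_size)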
14 changes: 10 additions & 4 deletions forge/test/mlir/llama/tests/test_llama_lm_head.py
@@ -9,15 +9,21 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_lm_head():
-    # Load Llama 3B model and tokenizer
-    framework_model, _ = load_model()
+def test_llama_lm_head(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
+    framework_model, _ = load_model(model_path)
 
     framework_model = framework_model.lm_head
+    input_features = framework_model.in_features
 
     # Input samples
     inputs = [
-        torch.rand((1, 12, 3200)),  # Hidden states
+        torch.rand((1, 12, input_features)),  # Hidden states
     ]
 
     # Sanity run
14 changes: 10 additions & 4 deletions forge/test/mlir/llama/tests/test_llama_mlp.py
@@ -9,15 +9,21 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_mlp():
-    # Load Llama 3B model and tokenizer
-    framework_model, _ = load_model()
+def test_llama_mlp(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
+    framework_model, _ = load_model(model_path)
 
     framework_model = framework_model.model.layers[0].mlp
+    hidden_dim = framework_model.hidden_size
 
     # Input samples
     inputs = [
-        torch.rand((1, 12, 3200)),  # Hidden states
+        torch.rand((1, 12, hidden_dim)),  # Hidden states
     ]
 
     # Sanity run
13 changes: 8 additions & 5 deletions forge/test/mlir/llama/tests/test_llama_prefil.py
@@ -48,16 +48,19 @@ def decode_on_cpu(model, tokenizer, input_ids, hidden_states, max_new_tokens):
     return input_ids, output_logits
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_prefil_on_device_decode_on_cpu():
+def test_llama_prefil_on_device_decode_on_cpu(model_path):
     """
-    This function tests the inference of the Llama 3B model split into two parts:
+    This function tests the inference of the Llama models split into two parts:
     - The first part is the prefilling of the model on the device.
     - The second part is the decoding of the model on the CPU without KV cache.
     """
-    # Load Llama 3B model and tokenizer
-    model_path = "openlm-research/open_llama_3b"
-    model, tokenizer = load_model(model_path, use_cache=False, return_dict=True)
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
+    model, tokenizer = load_model(model_path, return_dict=True)
 
     # Prepare input sentence
     prompt = "Q: What is the largest animal?\nA:"
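The docstring above splits generation into device prefill and CPU decode. A rough outline of how such a split can be wired up, using the decode_on_cpu helper defined earlier in this file (its signature appears in the hunk header); the compiled-model call and the hidden_states handoff are assumptions for illustration, not the test's exact code:

import forge
from test.mlir.llama.utils.utils import load_model

model, tokenizer = load_model("openlm-research/open_llama_3b", return_dict=True)
prompt = "Q: What is the largest animal?\nA:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Prefill on device: compile the model and run the whole prompt once.
compiled_model = forge.compile(model, input_ids)
prefill_output = compiled_model(input_ids)  # assumed to return the prompt's hidden states/logits

# Decode on CPU: extend the sequence token by token, without a KV cache.
output_ids, output_logits = decode_on_cpu(model, tokenizer, input_ids, prefill_output[0], max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))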
14 changes: 10 additions & 4 deletions forge/test/mlir/llama/tests/test_llama_rms_norm.py
@@ -9,14 +9,20 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
-def test_llama_lm_head():
-    # Load Llama 3B model and tokenizer
-    framework_model, _ = load_model()
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
+def test_llama_lm_head(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama model and tokenizer
+    framework_model, _ = load_model(model_path)
 
     framework_model = framework_model.model.norm
+    input_features = framework_model.weight.shape[0]
 
     # Input samples
     inputs = [
-        torch.rand((1, 12, 3200)),  # Hidden states
+        torch.rand((1, 12, input_features)),  # Hidden states
     ]
 
     # Sanity run
24 changes: 17 additions & 7 deletions forge/test/mlir/llama/tests/test_llama_rotary_emb.py
@@ -10,8 +10,9 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_rotary_emb():
+def test_llama_rotary_emb(model_path):
     class Llama_Rotary_Embedding(torch.nn.Module):
         def __init__(self, model):
             super().__init__()
@@ -20,24 +21,33 @@ def __init__(self, model):
             self.seq_length = 12
 
         def forward(self, query_states, key_states):
-            kv_seq_len = key_states.shape[-2]
-            cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
             position_ids = torch.arange(
                 self.past_key_values_length,
                 self.seq_length + self.past_key_values_length,
                 dtype=torch.long,
             )
             position_ids = position_ids.unsqueeze(0)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            cos, sin = self.rotary_emb(key_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
             return query_states, key_states
 
-    # Load the model
-    llama_model, _ = load_model()
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
+    # Load Llama Model
+    llama_model, _ = load_model(model_path)
 
     framework_model = Llama_Rotary_Embedding(llama_model)
     framework_model.eval()
 
     # Input samples
-    batch_size, q_heads, kv_heads, query_seq_len, kv_seq_len, head_dim = 1, 32, 32, 12, 12, 100
+    config = llama_model.config
+    batch_size = 1
+    q_heads = config.num_attention_heads
+    kv_heads = config.num_key_value_heads
+    query_seq_len = framework_model.seq_length
+    kv_seq_len = framework_model.seq_length
+    head_dim = config.hidden_size // config.num_attention_heads
     inputs = [
         torch.rand((batch_size, q_heads, query_seq_len, head_dim)),  # Query states
         torch.rand((batch_size, kv_heads, kv_seq_len, head_dim)),  # Key states
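The rewritten forward reflects the newer transformers rotary-embedding API: the rotary module now takes position_ids directly, and apply_rotary_pos_emb no longer needs them. The rotation itself is unchanged; as a reference, the sketch below shows roughly the math apply_rotary_pos_emb performs (a standalone approximation, not the transformers implementation):

import torch

def rotate_half(x):
    # Split the head dimension in half and swap with a sign flip: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin are (batch, seq_len, head_dim); unsqueeze to broadcast over the heads axis
    # of q/k, which are (batch, num_heads, seq_len, head_dim).
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot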
17 changes: 12 additions & 5 deletions forge/test/mlir/llama/tests/test_llama_self_attn.py
@@ -9,8 +9,12 @@
 from forge.op.eval.common import compare_with_golden_pcc
 
 
+@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
 @pytest.mark.xfail()
-def test_llama_self_attn():
+def test_llama_self_attn(model_path):
+    if model_path == "meta-llama/Llama-3.2-1B":
+        pytest.skip("Skipping test for Llama-3.2-1B model, waiting for new transformers version.")
+
     # Define wrapper function
     class SelfAttention(torch.nn.Module):
         def __init__(self, model):
@@ -22,15 +26,18 @@ def forward(self, *inputs):
 
             return hidden_states
 
-    # Load Llama 3B model and tokenizer
-    framework_model, _ = load_model()
+    # Load Llama model and tokenizer
+    framework_model, _ = load_model(model_path)
     framework_model = SelfAttention(framework_model.model.layers[0].self_attn)
 
+    # Get hidden dimension
+    hidden_size = framework_model.model.config.hidden_size
+
     # Input samples
     inputs = [
-        torch.rand((1, 12, 3200)),  # Hidden states
+        torch.rand((1, 12, hidden_size)),  # Hidden states
         torch.ones((1, 1, 12, 12)),  # Attention mask
-        torch.arange(12).unsqueeze(0),  # Position IDs
+        torch.arange(12).unsqueeze(0).float(),  # Position IDs
     ]
 
     # Sanity run
12 changes: 5 additions & 7 deletions forge/test/mlir/llama/utils/utils.py
@@ -2,18 +2,14 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer
 
 import forge
 
 
 def load_model(model_path="openlm-research/open_llama_3b", **kwargs):
     # Default config values
-    config = LlamaConfig()
-    config.hidden_size = 3200
-    config.intermediate_size = 8640
-    config.num_hidden_layers = 26
-    config.pad_token_id = 0
+    config = LlamaConfig.from_pretrained(model_path)
 
     # Use defaults or values from kwargs
     config.return_dict = kwargs.get("return_dict", False)
@@ -24,6 +20,8 @@ def load_model(model_path="openlm-research/open_llama_3b", **kwargs):
     # Load the model
     framework_model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", config=config)
     framework_model.eval()
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
+
+    # Using AutoTokenizer for default tokenizers for both openllama and llama 3.2
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
 
     return framework_model, tokenizer
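With the config now taken from LlamaConfig.from_pretrained, the same helper serves both checkpoints. An illustrative parametrized call, using only kwargs that tests in this commit already pass (the smoke test itself is hypothetical, not part of the commit):

import pytest
from test.mlir.llama.utils.utils import load_model

@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b", "meta-llama/Llama-3.2-1B"])
def test_load_model_smoke(model_path):
    # Load the model and run a single forward pass on the CPU.
    framework_model, tokenizer = load_model(model_path, return_dict=True, use_cache=False)
    input_ids = tokenizer("Q: What is the largest animal?\nA:", return_tensors="pt").input_ids
    logits = framework_model(input_ids).logits  # return_dict=True exposes .logits
    assert logits.shape[-1] == framework_model.config.vocab_size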
