Skip to content

Commit

Permalink
Add CPU test for llama_3.1 and llama_3.2 model (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdeviTT authored Dec 20, 2024
1 parent 0da336f commit 032f964
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions forge/test/models/pytorch/text/llama/test_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
import forge
from transformers.models.llama.modeling_llama import LlamaModel, Cache, StaticCache, AttentionMaskConverter

variants = ["meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct"]
# Hugging Face checkpoint IDs exercised by the Llama-3 tests: the base and
# the "-Instruct" variant of each model family (Llama-3 8B, 3.1 8B, 3.2 1B).
variants = [
    f"meta-llama/{base}{suffix}"
    for base in ("Meta-Llama-3-8B", "Llama-3.1-8B", "Llama-3.2-1B")
    for suffix in ("", "-Instruct")
]


# Monkey Patching Causal Mask Update
Expand Down Expand Up @@ -148,7 +155,7 @@ def test_llama3_causal_lm(variant, test_device):
compiled_model = forge.compile(
framework_model,
sample_inputs=inputs,
module_name="pt_" + str(variant.split("/")[-1].replace("-", "_")) + "_causal_lm",
module_name="pt_" + (str(variant.split("/")[-1].replace("-", "_"))).replace(".", "_") + "_causal_lm",
)


Expand Down Expand Up @@ -179,5 +186,5 @@ def test_llama3_sequence_classification(variant, test_device):
compiled_model = forge.compile(
framework_model,
sample_inputs=inputs,
module_name="pt_" + str(variant.split("/")[-1].replace("-", "_")) + "_seq_cls",
module_name="pt_" + (str(variant.split("/")[-1].replace("-", "_"))).replace(".", "_") + "_seq_cls",
)

0 comments on commit 032f964

Please sign in to comment.