Commit

update model
l-bat committed Nov 18, 2024
1 parent e597949 commit 6c96bd0
Showing 2 changed files with 7 additions and 5 deletions.
tests/python_tests/test_cache_optimizations.py (4 changes: 2 additions & 2 deletions)
@@ -150,7 +150,7 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t
 
 @pytest.fixture(scope='module')
 def phi3_converted_model(tmp_path_factory):
-    model_id = "microsoft/Phi-3-mini-4k-instruct"
+    model_id = "meta-llama/Llama-3.2-3B-Instruct"
     model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id
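
For background, exporting a Hugging Face checkpoint to OpenVINO IR and saving it next to its tokenizer, as this fixture does, typically looks like the following minimal standalone sketch (the output path here is hypothetical; the optimum-intel and transformers calls are standard):

    from optimum.intel import OVModelForCausalLM
    from transformers import AutoTokenizer

    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    # Export the PyTorch checkpoint to OpenVINO IR on the fly.
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    models_path = "cacheopt_test_models/" + model_id  # hypothetical location
    model.save_pretrained(models_path)
    tokenizer.save_pretrained(models_path)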
@@ -167,7 +167,7 @@ def phi3_converted_model(tmp_path_factory):
 @pytest.mark.precommit
 @pytest.mark.parametrize("subset", ["samsum", "qmsum", "trec", "qasper", "hotpotqa", "repobench-p"])
 def test_unoptimized_generation_longbench(phi3_converted_model, subset):
-    seqs_per_request = 2
+    seqs_per_request = 32
     num_kv_blocks = 1000
     scheduler_config = get_scheduler_config(num_kv_blocks)
     models_path = phi3_converted_model.models_path
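
Raising seqs_per_request from 2 to 32 lets each generation call carry far more LongBench prompts at once. A generic sketch of that chunking (iter_batches is an illustrative helper, not the test's actual code):

    def iter_batches(prompts, seqs_per_request):
        # Yield prompts in groups of seqs_per_request; the last batch may be smaller.
        for i in range(0, len(prompts), seqs_per_request):
            yield prompts[i:i + seqs_per_request]

    # With seqs_per_request = 32, 100 prompts produce batches of 32, 32, 32, and 4
    # instead of the 50 batches of 2 that the old setting required.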
tests/python_tests/utils_longbench.py (8 changes: 5 additions & 3 deletions)
@@ -146,10 +146,12 @@ def qa_f1_score(prediction, ground_truth, **kwargs):
 
 # Max length for NVIDIA GeForce RTX 3090 (24 GB)
 model2maxlen = {
-    "meta-llama/Llama-2-7b-chat-hf": 3500,
+    "meta-llama/Llama-2-7b-chat-hf": 4096,
     "meta-llama/Meta-Llama-3-8B-Instruct": 5000,
-    "meta-llama/Llama-3.1-8B-Instruct": 5000,
+    "meta-llama/Llama-3.1-8B-Instruct": 10000,
     "microsoft/Phi-3-mini-4k-instruct": 4096,
+    'meta-llama/Llama-3.2-1B-Instruct': 10000,
+    'meta-llama/Llama-3.2-3B-Instruct': 10000,
 }
 
 dataset2maxlen = {
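
The per-model budgets above cap prompt length so the KV cache fits in the stated 24 GB. A sketch of how such a budget is commonly applied, assuming LongBench-style middle truncation (truncate_middle is an illustrative helper, not code from this file):

    def truncate_middle(tokenizer, prompt, model_name):
        max_length = model2maxlen[model_name]  # e.g. 10000 for Llama-3.2-3B-Instruct
        ids = tokenizer.encode(prompt, add_special_tokens=False)
        if len(ids) <= max_length:
            return prompt
        half = max_length // 2
        # Keep the head and tail of the context and drop the middle,
        # preserving the task instructions and the final question.
        return tokenizer.decode(ids[:half]) + tokenizer.decode(ids[-half:])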
@@ -250,7 +252,7 @@ def preprocess_prompt(tokenizer, data_sample, subset, model_name):
 
 
 def post_process_pred(pred, subset, model_name):
-    if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3-8B" in model_name:
+    if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3" in model_name:
         pred = pred[:pred.find("assistant")]
     elif subset == "samsum":
         pred = pred[:pred.find("\nDialogue")]
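
Widening the match from "Llama-3-8B" to "Llama-3" extends the cleanup to Llama-3.1 and Llama-3.2, whose shared chat template can likewise leak a trailing "assistant" turn into the decoded output. A toy illustration with a hypothetical prediction string:

    pred = "A concise summary of the dialogue.assistant"
    cleaned = pred[:pred.find("assistant")]
    assert cleaned == "A concise summary of the dialogue."
    # Caveat: str.find returns -1 when the marker is absent, so the slice
    # would silently drop the last character of an already-clean prediction.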
