From 94ed8ebc1d492c1e7c28cd026799b69c2ecafbe4 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Fri, 15 Nov 2024 16:28:27 +0000 Subject: [PATCH 1/7] Add LongBench validation --- .../python_tests/test_cache_optimizations.py | 84 ++++++ tests/python_tests/utils_longbench.py | 259 ++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 tests/python_tests/utils_longbench.py diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 45704f9dc6..d42aea57e1 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -16,6 +16,7 @@ from transformers import AutoTokenizer from common import TESTS_ROOT +from utils_longbench import dataset2maxlen, evaluate, preprocess_prompt, post_process_pred def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]: @@ -68,6 +69,7 @@ class CacheOptTestStruct: SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=96, aggregation_mode=AggregationMode.NORM_SUM) +LONGBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=64, max_cache_size=256, aggregation_mode=AggregationMode.NORM_SUM) @pytest.mark.precommit @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="doesn't work on win due to optimum-intel export bug, segfault on mac") @@ -145,3 +147,85 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t del model_cb_noopt +@pytest.fixture(scope='module') +def phi3_converted_model(tmp_path_factory): + model_id = "microsoft/Phi-3-mini-4k-instruct" + model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, load_in_8bit=False) + tokenizer = AutoTokenizer.from_pretrained(model_id) + models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id + model.save_pretrained(models_path) + ov_tokenizer, ov_detokenizer = convert_tokenizer(tokenizer, with_detokenizer=True, skip_special_tokens=True) + serialize(ov_tokenizer, models_path / "openvino_tokenizer.xml") + serialize(ov_detokenizer, models_path / "openvino_detokenizer.xml") + phi3_converted_model = ConvertedModel(model, tokenizer, models_path) + yield phi3_converted_model + del phi3_converted_model + del model + + +@pytest.mark.precommit +@pytest.mark.parametrize("subset", ["samsum", "qmsum", "trec", "qasper", "hotpotqa", "repobench-p"]) +def test_unoptimized_generation_longbench(phi3_converted_model, subset): + seqs_per_request = 2 + num_kv_blocks = 1000 + scheduler_config = get_scheduler_config(num_kv_blocks) + models_path = phi3_converted_model.models_path + model_name = "/".join(models_path.parts[-2:]) + max_new_tokens = dataset2maxlen[subset] + tokenizer = phi3_converted_model.tokenizer + + generation_config = GenerationConfig() # expecting default greedy sampling + generation_config.num_return_sequences = 1 + generation_config.max_new_tokens = max_new_tokens + generation_config.eos_token_id = tokenizer.eos_token_id + + data = datasets.load_dataset('THUDM/LongBench', subset, split='test') + + # model_id = "microsoft/Phi-3-mini-4k-instruct" + # model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, load_in_8bit=False) + + model_cb_noopt = ContinuousBatchingPipeline(models_path.absolute().as_posix(), scheduler_config, "CPU", {}) + with tqdm(total=len(data)) as progress_bar: + batch = [] + answers = [] + for p_idx, data_sample in enumerate(data): + prompt, context_len = preprocess_prompt(tokenizer, 
data_sample, subset, model_name) + progress_bar.update(1) + batch.append(prompt) + answers.append({"context_len": context_len, "answers": data_sample["answers"], "all_classes": data_sample["all_classes"]}) + + # input = tokenizer(prompt, truncation=False, return_tensors="pt") + # output = model.generate( + # **input, + # max_new_tokens=128, + # num_beams=1, + # do_sample=False, + # temperature=1.0, + # min_length=context_len+1, + # pad_token_id=tokenizer.eos_token_id, + # eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], + # )[0] + # pred = tokenizer.decode(output[context_len:], skip_special_tokens=True) + # pred = post_process_pred(output.m_generation_ids, subset, model_name) + # answers[-1]["pred"] = pred + + if ( + len(batch) == seqs_per_request + or p_idx == len(data) - 1 + ): + ans_batch = model_cb_noopt.generate( + batch, [generation_config] * len(batch) + ) + for i, output in enumerate(ans_batch, start=p_idx-len(batch)+1): + context_len = answers[i]["context_len"] + pred = post_process_pred(output.m_generation_ids, subset, model_name) + answers[i]["pred"] = pred + + batch.clear() + + score = evaluate(answers, subset) + print(f"Score: {score}") + + pipeline_noopt_metrics = model_cb_noopt.get_metrics() + print(f"No-opt cache usage: max {pipeline_noopt_metrics.max_cache_usage:.3f}, avg {pipeline_noopt_metrics.avg_cache_usage:.3f}") + del model_cb_noopt diff --git a/tests/python_tests/utils_longbench.py b/tests/python_tests/utils_longbench.py new file mode 100644 index 0000000000..2f0bcb4010 --- /dev/null +++ b/tests/python_tests/utils_longbench.py @@ -0,0 +1,259 @@ +import re +import string + +from collections import Counter +from rouge import Rouge + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + +def normalize_zh_answer(s): + """Lower text and remove punctuation, extra whitespace.""" + + def white_space_fix(text): + return "".join(text.split()) + + def remove_punc(text): + cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
+ all_punctuation = set(string.punctuation + cn_punctuation) + return "".join(ch for ch in text if ch not in all_punctuation) + + def lower(text): + return text.lower() + + return white_space_fix(remove_punc(lower(s))) + +def count_score(prediction, ground_truth, **kwargs): + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_score(prediction, ground_truth, **kwargs): + pattern = r'Paragraph (\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_zh_score(prediction, ground_truth, **kwargs): + pattern = r'段落(\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def code_sim_score(prediction, ground_truth, **kwargs): + from fuzzywuzzy import fuzz + all_lines = prediction.lstrip('\n').split('\n') + prediction = "" + for line in all_lines: + if ('`' not in line) and ('#' not in line) and ('//' not in line): + prediction = line + break + return (fuzz.ratio(prediction, ground_truth) / 100) + +def classification_score(prediction, ground_truth, **kwargs): + em_match_list = [] + all_classes = kwargs["all_classes"] + for class_name in all_classes: + if class_name in prediction: + em_match_list.append(class_name) + for match_term in em_match_list: + if match_term in ground_truth and match_term != ground_truth: + em_match_list.remove(match_term) + if ground_truth in em_match_list: + score = (1.0 / len(em_match_list)) + else: + score = 0.0 + return score + +def rouge_score(prediction, ground_truth, **kwargs): + rouge = Rouge() + try: + scores = rouge.get_scores([prediction], [ground_truth], avg=True) + except: + return 0.0 + return scores["rouge-l"]["f"] + +def f1_score(prediction, ground_truth, **kwargs): + common = Counter(prediction) & Counter(ground_truth) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction) + recall = 1.0 * num_same / len(ground_truth) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + +def qa_f1_score(prediction, ground_truth, **kwargs): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return f1_score(prediction_tokens, ground_truth_tokens) + + +dataset2metric = { + "narrativeqa": qa_f1_score, + "qasper": qa_f1_score, + "multifieldqa_en": qa_f1_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "gov_report": rouge_score, + "qmsum": rouge_score, + "multi_news": rouge_score, + "trec": classification_score, + "triviaqa": qa_f1_score, + "samsum": rouge_score, + "lsht": classification_score, + "passage_retrieval_en": retrieval_score, + "passage_count": count_score, + "passage_retrieval_zh": retrieval_zh_score, + "lcc": code_sim_score, + 
"repobench-p": code_sim_score, +} + +# Max length for NVIDIA GeForce RTX 3090 (24 GB) +model2maxlen = { + "meta-llama/Llama-2-7b-chat-hf": 3500, + "meta-llama/Meta-Llama-3-8B-Instruct": 5000, + "meta-llama/Llama-3.1-8B-Instruct": 5000, + "microsoft/Phi-3-mini-4k-instruct": 5000, +} + +dataset2maxlen = { + "narrativeqa": 128, + "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, + "gov_report": 512, + "qmsum": 512, + "multi_news": 512, + "vcsum": 512, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, + "lcc": 64, + "repobench-p": 64 +} + +dataset2prompt = { + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", + "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n" +} + + +def scorer(dataset, predictions, answers, all_classes): + total_score = 0. + for (prediction, ground_truths) in zip(predictions, answers): + score = 0. 
+ if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + total_score += score + return round(100 * total_score / len(predictions), 2) + + +def evaluate(model_output, task): + predictions, answers = [], [] + for data in model_output: + predictions.append(data["pred"]) + answers.append(data["answers"]) + all_classes = data["all_classes"] + score = scorer(task, predictions, answers, all_classes) + return score + + +def build_chat(prompt, model_name): + if "Llama-2" in model_name: + prompt = f"[INST]{prompt}[/INST]" + elif "Llama" in model_name: + prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>" + elif "Phi-3" in model_name: + prompt = f"<|user|>\n{prompt} <|end|>\n<|assistant|>" + return prompt + + +def preprocess_prompt(tokenizer, data_sample, subset, model_name): + prompt_format = dataset2prompt[subset] + max_length = model2maxlen[model_name] + + prompt = prompt_format.format(**data_sample) + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + context_len = tokenized_prompt.shape[-1] + if len(tokenized_prompt) > max_length: + context_len = max_length + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + if subset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: + prompt = build_chat(prompt, model_name) + return prompt, context_len + + +def post_process_pred(subset, model_name): + if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3-8B" in model_name: + pred = pred[:pred.find("assistant")] + elif subset == "samsum": + pred = pred[:pred.find("\nDialogue")] + elif "Phi-3" in model_name and subset == "hotpotqa": + pred = pred.lstrip('\n').split('\n')[0] + return pred From 4df81fbf83e20b413ba217e65008cf2c4e953343 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Fri, 15 Nov 2024 18:14:56 +0000 Subject: [PATCH 2/7] fix --- tests/python_tests/test_cache_optimizations.py | 2 +- tests/python_tests/utils_longbench.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index d42aea57e1..8bf0c02c76 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -218,7 +218,7 @@ def test_unoptimized_generation_longbench(phi3_converted_model, subset): ) for i, output in enumerate(ans_batch, start=p_idx-len(batch)+1): context_len = answers[i]["context_len"] - pred = post_process_pred(output.m_generation_ids, subset, model_name) + pred = post_process_pred(output.m_generation_ids[0], subset, model_name) answers[i]["pred"] = pred batch.clear() diff --git a/tests/python_tests/utils_longbench.py b/tests/python_tests/utils_longbench.py index 2f0bcb4010..32f063bfaa 100644 --- a/tests/python_tests/utils_longbench.py +++ b/tests/python_tests/utils_longbench.py @@ -249,7 +249,7 @@ def preprocess_prompt(tokenizer, data_sample, subset, model_name): return prompt, context_len -def post_process_pred(subset, model_name): +def post_process_pred(pred, subset, model_name): if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3-8B" in model_name: pred = 
pred[:pred.find("assistant")] elif subset == "samsum": From 4d687430f67ca1dcf1dd5b0081a568b8ecef6b4a Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Mon, 18 Nov 2024 11:00:00 +0000 Subject: [PATCH 3/7] set context len as max len to phi3 --- tests/python_tests/test_cache_optimizations.py | 2 +- tests/python_tests/utils_longbench.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 8bf0c02c76..368489e42d 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -150,7 +150,7 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t @pytest.fixture(scope='module') def phi3_converted_model(tmp_path_factory): model_id = "microsoft/Phi-3-mini-4k-instruct" - model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, load_in_8bit=False) + model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id) models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id model.save_pretrained(models_path) diff --git a/tests/python_tests/utils_longbench.py b/tests/python_tests/utils_longbench.py index 32f063bfaa..77cf2efc4f 100644 --- a/tests/python_tests/utils_longbench.py +++ b/tests/python_tests/utils_longbench.py @@ -149,7 +149,7 @@ def qa_f1_score(prediction, ground_truth, **kwargs): "meta-llama/Llama-2-7b-chat-hf": 3500, "meta-llama/Meta-Llama-3-8B-Instruct": 5000, "meta-llama/Llama-3.1-8B-Instruct": 5000, - "microsoft/Phi-3-mini-4k-instruct": 5000, + "microsoft/Phi-3-mini-4k-instruct": 4096, } dataset2maxlen = { From 8831223dcc1ff15f8c074bda8d35411ef55049f6 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Mon, 18 Nov 2024 18:53:04 +0000 Subject: [PATCH 4/7] update model --- tests/python_tests/test_cache_optimizations.py | 4 ++-- tests/python_tests/utils_longbench.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 368489e42d..4e7aa4a486 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -149,7 +149,7 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t @pytest.fixture(scope='module') def phi3_converted_model(tmp_path_factory): - model_id = "microsoft/Phi-3-mini-4k-instruct" + model_id = "meta-llama/Llama-3.2-3B-Instruct" model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id) models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id @@ -166,7 +166,7 @@ def phi3_converted_model(tmp_path_factory): @pytest.mark.precommit @pytest.mark.parametrize("subset", ["samsum", "qmsum", "trec", "qasper", "hotpotqa", "repobench-p"]) def test_unoptimized_generation_longbench(phi3_converted_model, subset): - seqs_per_request = 2 + seqs_per_request = 32 num_kv_blocks = 1000 scheduler_config = get_scheduler_config(num_kv_blocks) models_path = phi3_converted_model.models_path diff --git a/tests/python_tests/utils_longbench.py b/tests/python_tests/utils_longbench.py index 77cf2efc4f..152095e1db 100644 --- a/tests/python_tests/utils_longbench.py +++ b/tests/python_tests/utils_longbench.py @@ -146,10 +146,12 @@ def qa_f1_score(prediction, ground_truth, **kwargs): # Max 
length for NVIDIA GeForce RTX 3090 (24 GB) model2maxlen = { - "meta-llama/Llama-2-7b-chat-hf": 3500, + "meta-llama/Llama-2-7b-chat-hf": 4096, "meta-llama/Meta-Llama-3-8B-Instruct": 5000, - "meta-llama/Llama-3.1-8B-Instruct": 5000, + "meta-llama/Llama-3.1-8B-Instruct": 10000, "microsoft/Phi-3-mini-4k-instruct": 4096, + 'meta-llama/Llama-3.2-1B-Instruct': 10000, + 'meta-llama/Llama-3.2-3B-Instruct': 10000, } dataset2maxlen = { @@ -250,7 +252,7 @@ def preprocess_prompt(tokenizer, data_sample, subset, model_name): def post_process_pred(pred, subset, model_name): - if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3-8B" in model_name: + if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3" in model_name: pred = pred[:pred.find("assistant")] elif subset == "samsum": pred = pred[:pred.find("\nDialogue")] From eefe4f2c79d0aa333c0aa28a6bf4d3fd1f8cca70 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Wed, 18 Dec 2024 20:31:35 +0000 Subject: [PATCH 5/7] change model to Qwen2 --- .../python_tests/test_cache_optimizations.py | 83 +++++++++---------- tests/python_tests/utils_longbench.py | 24 +----- 2 files changed, 44 insertions(+), 63 deletions(-) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 4e7aa4a486..14ed864d41 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -69,7 +69,7 @@ class CacheOptTestStruct: SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=96, aggregation_mode=AggregationMode.NORM_SUM) -LONGBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=64, max_cache_size=256, aggregation_mode=AggregationMode.NORM_SUM) +LONGBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=128, max_cache_size=672, aggregation_mode=AggregationMode.NORM_SUM) @pytest.mark.precommit @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="doesn't work on win due to optimum-intel export bug, segfault on mac") @@ -148,8 +148,8 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t @pytest.fixture(scope='module') -def phi3_converted_model(tmp_path_factory): - model_id = "meta-llama/Llama-3.2-3B-Instruct" +def qwen2_converted_model(tmp_path_factory): + model_id = "Qwen/Qwen2-0.5B-Instruct" model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id) models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id @@ -157,75 +157,72 @@ def phi3_converted_model(tmp_path_factory): ov_tokenizer, ov_detokenizer = convert_tokenizer(tokenizer, with_detokenizer=True, skip_special_tokens=True) serialize(ov_tokenizer, models_path / "openvino_tokenizer.xml") serialize(ov_detokenizer, models_path / "openvino_detokenizer.xml") - phi3_converted_model = ConvertedModel(model, tokenizer, models_path) - yield phi3_converted_model - del phi3_converted_model + qwen2_converted_model = ConvertedModel(model, tokenizer, models_path) + yield qwen2_converted_model + del qwen2_converted_model del model -@pytest.mark.precommit -@pytest.mark.parametrize("subset", ["samsum", "qmsum", "trec", "qasper", "hotpotqa", "repobench-p"]) -def test_unoptimized_generation_longbench(phi3_converted_model, subset): +@dataclass +class LongBenchTestData: + subset: str + ref_score: float + max_cache_usage: float + avg_cache_usage: float + + +@pytest.mark.parametrize("test_struct", [ 
+ LongBenchTestData("samsum", 34.96, 16.2, 8.145), + LongBenchTestData("trec", 35, 14, 7.284), + LongBenchTestData("qasper", 14.67, 22.8, 13.182), +]) +def test_unoptimized_generation_longbench(qwen2_converted_model, test_struct): seqs_per_request = 32 num_kv_blocks = 1000 scheduler_config = get_scheduler_config(num_kv_blocks) - models_path = phi3_converted_model.models_path + models_path = qwen2_converted_model.models_path model_name = "/".join(models_path.parts[-2:]) + subset = test_struct.subset max_new_tokens = dataset2maxlen[subset] - tokenizer = phi3_converted_model.tokenizer + tokenizer = qwen2_converted_model.tokenizer generation_config = GenerationConfig() # expecting default greedy sampling generation_config.num_return_sequences = 1 generation_config.max_new_tokens = max_new_tokens generation_config.eos_token_id = tokenizer.eos_token_id - data = datasets.load_dataset('THUDM/LongBench', subset, split='test') + scheduler_config.use_cache_eviction = True + if scheduler_config.use_cache_eviction: + scheduler_config.cache_eviction_config = LONGBENCH_CACHE_EVICTION_CONFIG - # model_id = "microsoft/Phi-3-mini-4k-instruct" - # model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, load_in_8bit=False) + model_cb_opt = ContinuousBatchingPipeline(models_path.absolute().as_posix(), scheduler_config, "CPU", {}) + data = datasets.load_dataset('THUDM/LongBench', subset, split='test') - model_cb_noopt = ContinuousBatchingPipeline(models_path.absolute().as_posix(), scheduler_config, "CPU", {}) with tqdm(total=len(data)) as progress_bar: batch = [] answers = [] for p_idx, data_sample in enumerate(data): - prompt, context_len = preprocess_prompt(tokenizer, data_sample, subset, model_name) + prompt = preprocess_prompt(data_sample, subset, model_name) progress_bar.update(1) batch.append(prompt) - answers.append({"context_len": context_len, "answers": data_sample["answers"], "all_classes": data_sample["all_classes"]}) - - # input = tokenizer(prompt, truncation=False, return_tensors="pt") - # output = model.generate( - # **input, - # max_new_tokens=128, - # num_beams=1, - # do_sample=False, - # temperature=1.0, - # min_length=context_len+1, - # pad_token_id=tokenizer.eos_token_id, - # eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], - # )[0] - # pred = tokenizer.decode(output[context_len:], skip_special_tokens=True) - # pred = post_process_pred(output.m_generation_ids, subset, model_name) - # answers[-1]["pred"] = pred - - if ( - len(batch) == seqs_per_request - or p_idx == len(data) - 1 - ): - ans_batch = model_cb_noopt.generate( + answers.append({"answers": data_sample["answers"], "all_classes": data_sample["all_classes"]}) + + if len(batch) == seqs_per_request or p_idx == len(data) - 1: + ans_batch = model_cb_opt.generate( batch, [generation_config] * len(batch) ) for i, output in enumerate(ans_batch, start=p_idx-len(batch)+1): - context_len = answers[i]["context_len"] pred = post_process_pred(output.m_generation_ids[0], subset, model_name) answers[i]["pred"] = pred - batch.clear() score = evaluate(answers, subset) print(f"Score: {score}") - pipeline_noopt_metrics = model_cb_noopt.get_metrics() - print(f"No-opt cache usage: max {pipeline_noopt_metrics.max_cache_usage:.3f}, avg {pipeline_noopt_metrics.avg_cache_usage:.3f}") - del model_cb_noopt + pipeline_noopt_metrics = model_cb_opt.get_metrics() + print(f"Opt cache usage: max {pipeline_noopt_metrics.max_cache_usage:.3f}, avg {pipeline_noopt_metrics.avg_cache_usage:.3f}") 
+ + assert abs(test_struct.ref_score - score) < 1 + assert abs(test_struct.max_cache_usage - pipeline_noopt_metrics.max_cache_usage) < 1 + assert abs(test_struct.avg_cache_usage - pipeline_noopt_metrics.avg_cache_usage) < 1 + del model_cb_opt diff --git a/tests/python_tests/utils_longbench.py b/tests/python_tests/utils_longbench.py index 152095e1db..be349f6602 100644 --- a/tests/python_tests/utils_longbench.py +++ b/tests/python_tests/utils_longbench.py @@ -144,16 +144,6 @@ def qa_f1_score(prediction, ground_truth, **kwargs): "repobench-p": code_sim_score, } -# Max length for NVIDIA GeForce RTX 3090 (24 GB) -model2maxlen = { - "meta-llama/Llama-2-7b-chat-hf": 4096, - "meta-llama/Meta-Llama-3-8B-Instruct": 5000, - "meta-llama/Llama-3.1-8B-Instruct": 10000, - "microsoft/Phi-3-mini-4k-instruct": 4096, - 'meta-llama/Llama-3.2-1B-Instruct': 10000, - 'meta-llama/Llama-3.2-3B-Instruct': 10000, -} - dataset2maxlen = { "narrativeqa": 128, "qasper": 128, @@ -235,20 +225,12 @@ def build_chat(prompt, model_name): return prompt -def preprocess_prompt(tokenizer, data_sample, subset, model_name): +def preprocess_prompt(data_sample, subset, model_name): prompt_format = dataset2prompt[subset] - max_length = model2maxlen[model_name] - prompt = prompt_format.format(**data_sample) - tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] - context_len = tokenized_prompt.shape[-1] - if len(tokenized_prompt) > max_length: - context_len = max_length - half = int(max_length/2) - prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) if subset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: prompt = build_chat(prompt, model_name) - return prompt, context_len + return prompt def post_process_pred(pred, subset, model_name): @@ -258,4 +240,6 @@ def post_process_pred(pred, subset, model_name): pred = pred[:pred.find("\nDialogue")] elif "Phi-3" in model_name and subset == "hotpotqa": pred = pred.lstrip('\n').split('\n')[0] + elif "Qwen" in model_name and subset == "qasper": + pred = pred.lstrip('\n').split('\n')[0] return pred From a589ac8392591c098aa123b6d0e66677cb6c00b2 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Thu, 19 Dec 2024 15:11:00 +0000 Subject: [PATCH 6/7] add test to precommit --- tests/python_tests/test_cache_optimizations.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index 14ed864d41..59a69de4f2 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -5,7 +5,9 @@ import sys from typing import Dict, List, Optional +import datasets import pytest +from tqdm import tqdm from optimum.intel.openvino import OVModelForCausalLM @@ -171,10 +173,11 @@ class LongBenchTestData: avg_cache_usage: float +@pytest.mark.precommit @pytest.mark.parametrize("test_struct", [ - LongBenchTestData("samsum", 34.96, 16.2, 8.145), - LongBenchTestData("trec", 35, 14, 7.284), - LongBenchTestData("qasper", 14.67, 22.8, 13.182), + LongBenchTestData("samsum", 36.78, 14, 9.596), + LongBenchTestData("trec", 28.12, 11.8, 7.721), + LongBenchTestData("qasper", 21.68, 18.4, 12.706), ]) def test_unoptimized_generation_longbench(qwen2_converted_model, test_struct): seqs_per_request = 32 @@ -196,7 +199,7 @@ def test_unoptimized_generation_longbench(qwen2_converted_model, test_struct): 
scheduler_config.cache_eviction_config = LONGBENCH_CACHE_EVICTION_CONFIG model_cb_opt = ContinuousBatchingPipeline(models_path.absolute().as_posix(), scheduler_config, "CPU", {}) - data = datasets.load_dataset('THUDM/LongBench', subset, split='test') + data = datasets.load_dataset('THUDM/LongBench', subset, split=f'test[:{seqs_per_request}]') with tqdm(total=len(data)) as progress_bar: batch = [] From 000d2a49ee382f75d60c9fb62f6a8fcfcc028e98 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Thu, 19 Dec 2024 17:21:10 +0000 Subject: [PATCH 7/7] add missing requirement --- tests/python_tests/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt index 3dac3f8b00..2f082063e3 100644 --- a/tests/python_tests/requirements.txt +++ b/tests/python_tests/requirements.txt @@ -31,4 +31,5 @@ sacremoses # - openai/whisper-base librosa soundfile -datasets \ No newline at end of file +datasets +rouge \ No newline at end of file
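
The scoring helpers introduced in utils_longbench.py can also be exercised on their own, which is useful when establishing a reference score for a new subset before wiring it into test_unoptimized_generation_longbench. The sketch below is illustrative only: it assumes it is run from tests/python_tests (so that utils_longbench is importable and the rouge dependency added to requirements.txt in this series is installed), and the sample records are hypothetical, not real model output.

# Hypothetical standalone check of the LongBench scoring path; the records
# below are made-up examples that only mirror the structure the test
# accumulates per prompt (decoded prediction, reference answers, and the
# candidate class list used by classification subsets such as "trec").
from utils_longbench import dataset2metric, evaluate

model_output = [
    {"pred": "Location", "answers": ["Location"], "all_classes": ["Location", "Person", "Number"]},
    {"pred": "Person", "answers": ["Number"], "all_classes": ["Location", "Person", "Number"]},
]

# "trec" maps to classification_score in dataset2metric; samsum/qmsum map to
# rouge_score and qasper/hotpotqa to qa_f1_score.
assert dataset2metric["trec"].__name__ == "classification_score"

# evaluate() collects predictions/answers and delegates to scorer(), which
# returns a percentage rounded to two decimals.
print(evaluate(model_output, "trec"))  # 50.0: the first record scores 1.0, the second 0.0

In the test itself, the value returned by evaluate() is compared against LongBenchTestData.ref_score with a tolerance of 1 point, alongside the max/avg cache-usage metrics reported by the pipeline.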