From ee33b9346418b9eb716ad33172fccfb101ab60fc Mon Sep 17 00:00:00 2001 From: "Xu, Shuo" <100334393+ATMxsp01@users.noreply.github.com> Date: Wed, 18 Sep 2024 15:55:14 +0800 Subject: [PATCH] Longbench: NV code to ipex-llm (#11662) * add nv longbench * LongBench: NV code to ipex-llm * ammend * add more models support * ammend * optimize LongBench's user experience * ammend * ammend * fix typo * ammend * remove cuda related information & add a readme * add license to python scripts & polish the readme * ammend * ammend --------- Co-authored-by: cyita Co-authored-by: ATMxsp01 Co-authored-by: leonardozcm --- python/llm/dev/benchmark/LongBench/README.md | 124 ++++++++ .../llm/dev/benchmark/LongBench/config.yaml | 29 ++ .../config/ablation_c1024_w32_k7_maxpool.json | 7 + .../config/ablation_c2048_w32_k7_maxpool.json | 7 + .../config/ablation_c4096_w32_k7_maxpool.json | 7 + .../config/ablation_c512_w32_k7_maxpool.json | 7 + .../LongBench/config/dataset2maxlen.json | 23 ++ .../LongBench/config/dataset2prompt.json | 23 ++ .../LongBench/config/model2maxlen.json | 18 ++ .../LongBench/config/model2path.json | 18 ++ python/llm/dev/benchmark/LongBench/eval.py | 150 +++++++++ python/llm/dev/benchmark/LongBench/metrics.py | 165 ++++++++++ python/llm/dev/benchmark/LongBench/pred.py | 287 ++++++++++++++++++ .../dev/benchmark/LongBench/test_and_eval.sh | 7 + 14 files changed, 872 insertions(+) create mode 100644 python/llm/dev/benchmark/LongBench/README.md create mode 100644 python/llm/dev/benchmark/LongBench/config.yaml create mode 100644 python/llm/dev/benchmark/LongBench/config/ablation_c1024_w32_k7_maxpool.json create mode 100644 python/llm/dev/benchmark/LongBench/config/ablation_c2048_w32_k7_maxpool.json create mode 100644 python/llm/dev/benchmark/LongBench/config/ablation_c4096_w32_k7_maxpool.json create mode 100644 python/llm/dev/benchmark/LongBench/config/ablation_c512_w32_k7_maxpool.json create mode 100644 python/llm/dev/benchmark/LongBench/config/dataset2maxlen.json create mode 100644 python/llm/dev/benchmark/LongBench/config/dataset2prompt.json create mode 100644 python/llm/dev/benchmark/LongBench/config/model2maxlen.json create mode 100644 python/llm/dev/benchmark/LongBench/config/model2path.json create mode 100644 python/llm/dev/benchmark/LongBench/eval.py create mode 100644 python/llm/dev/benchmark/LongBench/metrics.py create mode 100644 python/llm/dev/benchmark/LongBench/pred.py create mode 100755 python/llm/dev/benchmark/LongBench/test_and_eval.sh diff --git a/python/llm/dev/benchmark/LongBench/README.md b/python/llm/dev/benchmark/LongBench/README.md new file mode 100644 index 00000000000..31356da6235 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/README.md @@ -0,0 +1,124 @@ +# LongBench Benchmark Test + +LongBench is the first benchmark for bilingual, multitask, and comprehensive assessment of long context understanding capabilities of large language models. This benchmark implementation is adapted from [THUDM/LongBench](https://github.com/THUDM/LongBench) and [SnapKV/experiments/LongBench](https://github.com/FasterDecoding/SnapKV/tree/main/experiments/LongBench). + + +## Environment Preparation + +Before running, make sure to have [ipex-llm](../../../../../README.md) installed. 
+
+Then install the additional Python dependencies required by this benchmark:
+
+```bash
+pip install omegaconf
+pip install datasets
+pip install jieba
+pip install fuzzywuzzy
+pip install rouge
+```
+
+### Load Data
+
+You can download and load the LongBench data through Hugging Face datasets ([🤗 HF Repo](https://huggingface.co/datasets/THUDM/LongBench)):
+
+```python
+from datasets import load_dataset
+
+datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
+            "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
+            "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
+
+for dataset in datasets:
+    data = load_dataset('THUDM/LongBench', dataset, split='test')
+    data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
+```
+
+## Config
+
+### `config.yaml`
+
+`config.yaml` has the following format:
+
+```yaml
+# the names of the models you want to test
+model_name:
+  # - "mistral-7B-instruct-v0.2"
+  - "llama2-7b-chat-4k"
+  # - "chatglm4-9b"
+  # - "qwen2-7b-instruct"
+
+# whether to test with the full KV cache (i.e. without KV cache compression)
+full_kv: True
+# whether to apply ipex-llm model optimizations
+optimize_model: True
+# dtype of the model
+dtype: 'fp16'
+# low-bit precision of the model
+low_bit: 'sym_int4'
+# whether to use the 'e' (LongBench-E) version of the datasets
+e: False
+
+# the compress-kv configs you want to test (JSON files under config/)
+compress_kv:
+  - "ablation_c512_w32_k7_maxpool"
+  - "ablation_c1024_w32_k7_maxpool"
+
+# the datasets you want to test
+datasets:
+  - "multi_news"
+  - "qasper"
+  - "hotpotqa"
+  - "trec"
+  - "passage_count"
+  - "lcc"
+  # - "multifieldqa_zh"
+  # - "dureader"
+  # - "vcsum"
+  # - "lsht"
+  # - "passage_retrieval_zh"
+```
+
+### The `config` dir
+
+Several JSON files are stored in the `config` dir. They fall into three kinds: about models, about datasets, and about compress-kv.
+
+#### About models
+
+- `model2path.json`: the path (or Hugging Face repo id) of each model.
+
+- `model2maxlen.json`: the maximum prompt length for each model.
+
+#### About datasets
+
+- `dataset2maxlen.json`: the maximum output length (max new tokens) for each dataset.
+
+- `dataset2prompt.json`: the prompt template for each dataset.
+
+#### About compress-kv
+
+The remaining JSON files are compress-kv test configurations.
+
+## Run
+
+There are two Python scripts to run.
+
+1. Configure `config.yaml` and run `pred.py`; the model outputs are saved under the `pred_<max_length>/` folder corresponding to the model name (e.g. `pred_4096/llama2-7b-chat-4k/`).
+
+2. Run the evaluation code `eval.py` to get the evaluation results on all datasets; the scores are written to `result.json` under each prediction folder.
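+
+For example, a typical run of the two steps from this directory looks like the following (output locations follow the defaults in `pred.py` and `eval.py`):
+
+```bash
+# generate predictions for every model / dataset / compress-kv setting in config.yaml
+python pred.py
+# score the generated predictions and write result.json into each prediction folder
+python eval.py
+```
+
+Both scripts read `config.yaml` from this directory, so no extra command-line arguments are needed for a default run.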
+ +> [!Note] +> +> To test the models and get the score in a row, please run `test_and_eval.sh` + +## Citation + +```bibtex +@article{bai2023longbench, + title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding}, + author={Bai, Yushi and Lv, Xin and Zhang, Jiajie and Lyu, Hongchang and Tang, Jiankai and Huang, Zhidian and Du, Zhengxiao and Liu, Xiao and Zeng, Aohan and Hou, Lei and Dong, Yuxiao and Tang, Jie and Li, Juanzi}, + journal={arXiv preprint arXiv:2308.14508}, + year={2023} +} + +``` \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config.yaml b/python/llm/dev/benchmark/LongBench/config.yaml new file mode 100644 index 00000000000..0d99fd780b9 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config.yaml @@ -0,0 +1,29 @@ +model_name: + # - "mistral-7B-instruct-v0.2" + - "llama2-7b-chat-4k" + # - "chatglm4-9b" + # - "qwen2-7b-instruct" + +full_kv: True +optimize_model: True +dtype: 'fp16' +low_bit: 'sym_int4' + +e: False + +compress_kv: + - "ablation_c512_w32_k7_maxpool" + - "ablation_c1024_w32_k7_maxpool" + +datasets: + - "multi_news" + - "qasper" + - "hotpotqa" + - "trec" + - "passage_count" + - "lcc" + # - "multifieldqa_zh" + # - "dureader" + # - "vcsum" + # - "lsht" + # - "passage_retrieval_zh" diff --git a/python/llm/dev/benchmark/LongBench/config/ablation_c1024_w32_k7_maxpool.json b/python/llm/dev/benchmark/LongBench/config/ablation_c1024_w32_k7_maxpool.json new file mode 100644 index 00000000000..1441e9091a6 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/ablation_c1024_w32_k7_maxpool.json @@ -0,0 +1,7 @@ +{ + "window_sizes": 32, + "default_max_capacity_prompts": 1024, + "specific_max_capcity_prompts": {}, + "kernel_sizes": 7, + "pooling": "maxpool" +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/ablation_c2048_w32_k7_maxpool.json b/python/llm/dev/benchmark/LongBench/config/ablation_c2048_w32_k7_maxpool.json new file mode 100644 index 00000000000..8ec8ce60a95 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/ablation_c2048_w32_k7_maxpool.json @@ -0,0 +1,7 @@ +{ + "window_sizes": 32, + "default_max_capacity_prompts": 2048, + "specific_max_capcity_prompts": {}, + "kernel_sizes": 7, + "pooling": "maxpool" +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/ablation_c4096_w32_k7_maxpool.json b/python/llm/dev/benchmark/LongBench/config/ablation_c4096_w32_k7_maxpool.json new file mode 100644 index 00000000000..692a3b25ef8 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/ablation_c4096_w32_k7_maxpool.json @@ -0,0 +1,7 @@ +{ + "window_sizes": 32, + "default_max_capacity_prompts": 4096, + "specific_max_capcity_prompts": {}, + "kernel_sizes": 7, + "pooling": "maxpool" +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/ablation_c512_w32_k7_maxpool.json b/python/llm/dev/benchmark/LongBench/config/ablation_c512_w32_k7_maxpool.json new file mode 100644 index 00000000000..e342be84d43 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/ablation_c512_w32_k7_maxpool.json @@ -0,0 +1,7 @@ +{ + "window_sizes": 32, + "default_max_capacity_prompts": 512, + "specific_max_capcity_prompts": {}, + "kernel_sizes": 7, + "pooling": "maxpool" +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/dataset2maxlen.json b/python/llm/dev/benchmark/LongBench/config/dataset2maxlen.json new file mode 100644 index 00000000000..79d0d9990e5 --- /dev/null +++ 
b/python/llm/dev/benchmark/LongBench/config/dataset2maxlen.json @@ -0,0 +1,23 @@ +{ + "narrativeqa": 128, + "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, + "gov_report": 512, + "qmsum": 512, + "multi_news": 512, + "vcsum": 512, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, + "lcc": 64, + "repobench-p": 64 +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/dataset2prompt.json b/python/llm/dev/benchmark/LongBench/config/dataset2prompt.json new file mode 100644 index 00000000000..1c85f6bc0f0 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/dataset2prompt.json @@ -0,0 +1,23 @@ +{ + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", + "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" +} \ No newline at end of file diff --git a/python/llm/dev/benchmark/LongBench/config/model2maxlen.json b/python/llm/dev/benchmark/LongBench/config/model2maxlen.json new file mode 100644 index 00000000000..1238ed1f4f6 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/model2maxlen.json @@ -0,0 +1,18 @@ +{ + "llama2-7b-chat-4k": 4096, + "longchat-v1.5-7b-32k": 4096, + "xgen-7b-8k": 4096, + "internlm-7b-8k": 4096, + "chatglm2-6b": 4096, + "chatglm2-6b-32k": 4096, + "chatglm3-6b-32k": 4096, + "chatglm4-9b": 4096, + "vicuna-v1.5-7b-16k": 4096, + "mistral-7B-instruct-v0.2": 4096, + "mistral-7B-instruct-v0.1": 4096, + "mixtral-8x7B-instruct-v0.1": 4096, + "llama-2-7B-32k-instruct": 4096, + "lwm-text-chat-1m": 4096, + "lwm-text-1m": 4096, + "qwen2-7b-instruct": 4096 +} diff --git a/python/llm/dev/benchmark/LongBench/config/model2path.json b/python/llm/dev/benchmark/LongBench/config/model2path.json new file mode 100644 index 00000000000..1b7520ded22 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/config/model2path.json @@ -0,0 +1,18 @@ +{ + "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf", + "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", + "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", + "internlm-7b-8k": "internlm/internlm-chat-7b-8k", + "chatglm2-6b": "THUDM/chatglm2-6b", + "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", + "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", + "chatglm4-9b": "THUDM/glm-4-9b-chat", + "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k", + "mistral-7B-instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2", + "mistral-7B-instruct-v0.1": "mistralai/Mistral-7B-Instruct-v0.1", + "mixtral-8x7B-instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "llama-2-7B-32k-instruct": "togethercomputer/Llama-2-7B-32K-Instruct", + "lwm-text-chat-1m": "LargeWorldModel/LWM-Text-Chat-1M", + "lwm-text-1m": "LargeWorldModel/LWM-Text-1M", + "qwen2-7b-instruct": "Qwen/Qwen2-7B-Instruct" +} diff --git a/python/llm/dev/benchmark/LongBench/eval.py b/python/llm/dev/benchmark/LongBench/eval.py new file mode 100644 index 00000000000..2df84651bc1 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/eval.py @@ -0,0 +1,150 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This file is adapted from +# https://github.com/THUDM/LongBench/blob/main/eval.py +# and +# https://github.com/FasterDecoding/SnapKV/blob/main/experiments/LongBench/eval.py + +import os +import json +import argparse +import numpy as np + +current_dir = os.path.dirname(os.path.realpath(__file__)) + +from metrics import ( + qa_f1_score, + rouge_zh_score, + qa_f1_zh_score, + rouge_score, + classification_score, + retrieval_score, + retrieval_zh_score, + count_score, + code_sim_score, +) + +dataset2metric = { + "narrativeqa": qa_f1_score, + "qasper": qa_f1_score, + "multifieldqa_en": qa_f1_score, + "multifieldqa_zh": qa_f1_zh_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "dureader": rouge_zh_score, + "gov_report": rouge_score, + "qmsum": rouge_score, + "multi_news": rouge_score, + "vcsum": rouge_zh_score, + "trec": classification_score, + "triviaqa": qa_f1_score, + "samsum": rouge_score, + "lsht": classification_score, + "passage_retrieval_en": retrieval_score, + "passage_count": count_score, + "passage_retrieval_zh": retrieval_zh_score, + "lcc": code_sim_score, + "repobench-p": code_sim_score, +} + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default=None) + parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") + return parser.parse_args(args) + +def scorer_e(dataset, predictions, answers, lengths, all_classes): + scores = {"0-4k": [], "4-8k": [], "8k+": []} + for (prediction, ground_truths, length) in zip(predictions, answers, lengths): + score = 0. + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + if length < 4000: + scores["0-4k"].append(score) + elif length < 8000: + scores["4-8k"].append(score) + else: + scores["8k+"].append(score) + for key in scores.keys(): + scores[key] = round(100 * np.mean(scores[key]), 2) + return scores + +def scorer(dataset, predictions, answers, all_classes): + total_score = 0. + for (prediction, ground_truths) in zip(predictions, answers): + score = 0. 
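+        # For the few-shot tasks (trec/triviaqa/samsum/lsht), keep only the first generated line before scoring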
+ if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + total_score += score + return round(100 * total_score / len(predictions), 2) + + +def result_path_range(full_kv: bool, configs: list[str], model_name: str, fa_name: str): + if full_kv: + yield f"{fa_name}/{model_name}" + + for config in configs: + yield f"{fa_name}/{model_name}_{config}" + + +if __name__ == '__main__': + from omegaconf import OmegaConf + conf = OmegaConf.load(f'{current_dir}/config.yaml') + + model_names = conf['model_name'] if OmegaConf.is_list(conf['model_name']) else [conf['model_name']] + full_kv = conf['full_kv'] + ees = conf['e'] if OmegaConf.is_list(conf['e']) else [conf['e']] + compresskv_configs = conf['compress_kv'] if OmegaConf.is_list(conf['compress_kv']) else [conf['compress_kv']] + + model2maxlen = json.load(open(f"{current_dir}/config/model2maxlen.json", "r")) + + for model_name in model_names: + max_length = model2maxlen[model_name] + for e in ees: + fa_dir_name = f"pred_{'e_' if e else ''}{max_length}" + for path in result_path_range(full_kv, compresskv_configs, model_name, fa_dir_name): + scores = dict() + all_files = os.listdir(path) + print("Evaluating on:", all_files) + for filename in all_files: + if not filename.endswith("jsonl"): + continue + predictions, answers, lengths = [], [], [] + dataset = filename.split('.')[0] + with open(f"{path}/{filename}", "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + predictions.append(data["pred"]) + answers.append(data["answers"]) + all_classes = data["all_classes"] + if "length" in data: + lengths.append(data["length"]) + if e: + score = scorer_e(dataset, predictions, answers, lengths, all_classes) + else: + score = scorer(dataset, predictions, answers, all_classes) + if dataset == 'qasper': + score_e = scorer_e(dataset, predictions, answers, lengths, all_classes) + scores[dataset] = score + + out_path = f"{path}/result.json" + with open(out_path, "w") as f: + json.dump(scores, f, ensure_ascii=False, indent=4) diff --git a/python/llm/dev/benchmark/LongBench/metrics.py b/python/llm/dev/benchmark/LongBench/metrics.py new file mode 100644 index 00000000000..366101a0944 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/metrics.py @@ -0,0 +1,165 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This file is adapted from +# https://github.com/THUDM/LongBench/blob/main/metrics.py +# and +# https://github.com/FasterDecoding/SnapKV/blob/main/experiments/LongBench/metrics.py + + +import re +import string + +import jieba +from fuzzywuzzy import fuzz +import difflib + +from typing import List +from collections import Counter +from rouge import Rouge + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def normalize_zh_answer(s): + """Lower text and remove punctuation, extra whitespace.""" + + def white_space_fix(text): + return "".join(text.split()) + + def remove_punc(text): + cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." + all_punctuation = set(string.punctuation + cn_punctuation) + return "".join(ch for ch in text if ch not in all_punctuation) + + def lower(text): + return text.lower() + + return white_space_fix(remove_punc(lower(s))) + +def count_score(prediction, ground_truth, **kwargs): + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_score(prediction, ground_truth, **kwargs): + pattern = r'Paragraph (\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_zh_score(prediction, ground_truth, **kwargs): + pattern = r'段落(\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def code_sim_score(prediction, ground_truth, **kwargs): + all_lines = prediction.lstrip('\n').split('\n') + prediction = "" + for line in all_lines: + if ('`' not in line) and ('#' not in line) and ('//' not in line): + prediction = line + break + return (fuzz.ratio(prediction, ground_truth) / 100) + +def classification_score(prediction, ground_truth, **kwargs): + em_match_list = [] + all_classes = kwargs["all_classes"] + for class_name in all_classes: + if class_name in prediction: + em_match_list.append(class_name) + for match_term in em_match_list: + if match_term in ground_truth and match_term != ground_truth: + em_match_list.remove(match_term) + if ground_truth in em_match_list: + score = (1.0 / len(em_match_list)) + else: + score = 0.0 + return score + +def rouge_score(prediction, ground_truth, **kwargs): + rouge = Rouge() + try: + scores = rouge.get_scores([prediction], [ground_truth], avg=True) + except: + return 0.0 + return scores["rouge-l"]["f"] + +def rouge_zh_score(prediction, ground_truth, **kwargs): + prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) + 
ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) + score = rouge_score(prediction, ground_truth) + return score + +def f1_score(prediction, ground_truth, **kwargs): + common = Counter(prediction) & Counter(ground_truth) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction) + recall = 1.0 * num_same / len(ground_truth) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + +def qa_f1_score(prediction, ground_truth, **kwargs): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return f1_score(prediction_tokens, ground_truth_tokens) + + +def qa_f1_zh_score(prediction, ground_truth, **kwargs): + prediction_tokens = list(jieba.cut(prediction, cut_all=False)) + ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) + prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] + ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] + prediction_tokens = [token for token in prediction_tokens if len(token) > 0] + ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] + return f1_score(prediction_tokens, ground_truth_tokens) diff --git a/python/llm/dev/benchmark/LongBench/pred.py b/python/llm/dev/benchmark/LongBench/pred.py new file mode 100644 index 00000000000..c95ccdbd3eb --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/pred.py @@ -0,0 +1,287 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# This file is adapted from +# https://github.com/THUDM/LongBench/blob/main/pred.py +# and +# https://github.com/FasterDecoding/SnapKV/blob/main/experiments/LongBench/pred_snap.py + + +import os +from transformers import AutoTokenizer +from ipex_llm.transformers import AutoModelForCausalLM +from datasets import load_dataset +import json +from tqdm import tqdm +import numpy as np +import random +import argparse +import torch + +current_dir = os.path.dirname(os.path.realpath(__file__)) + +valid_model_names = [ + "llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", + "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "chatglm3-6b-32k", "vicuna-v1.5-7b-16k", + "mistral-7B-instruct-v0.2", "mistral-7B-instruct-v0.1", "llama-2-7B-32k-instruct", "mixtral-8x7B-instruct-v0.1","lwm-text-chat-1m", "lwm-text-1m", + "qwen2-7b-instruct", "chatglm4-9b"] + +valid_datasets_e = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ + "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] + +valid_datasets = ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique", \ + "gov_report", "qmsum", "multi_news", "trec", "triviaqa", "samsum", \ + "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] + \ + ["multifieldqa_zh", "dureader", "vcsum", "lsht", "passage_retrieval_zh"] + +valid_dtypes = ['fp16', 'fp32'] + + +# This is the customized building prompt for chat models +def build_chat(tokenizer, prompt, model_name): + if "chatglm3" in model_name: + print('chatglm3') + prompt = tokenizer.build_chat_input(prompt) + elif "chatglm2" in model_name: + print('chatglm2') + prompt = tokenizer.build_prompt(prompt) + elif "longchat" in model_name or "vicuna" in model_name: + print('longchat') + from fastchat.model import get_conversation_template + conv = get_conversation_template("vicuna") + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + elif "llama2" in model_name or "llama-2" in model_name or "lwm" in model_name: + print('llama2', model_name) + prompt = f"[INST]{prompt}[/INST]" + elif "xgen" in model_name: + print('xgen') + header = ( + "A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" + ) + prompt = header + f" ### Human: {prompt}\n###" + elif "internlm" in model_name: + print('internlm') + prompt = f"<|User|>:{prompt}\n<|Bot|>:" + elif "mistral" in model_name or "mixtral" in model_name: + print('mistral') + prompt = prompt + return prompt + +def post_process(response, model_name): + if "xgen" in model_name: + response = response.strip().replace("Assistant:", "") + elif "internlm" in model_name: + response = response.split("")[0] + return response + +@torch.inference_mode() +def get_pred_single_gpu(data, max_length, max_gen, + prompt_format, dataset, model_name, + model2path, out_path, low_bit, dtype, optimize_model, + compress=False, + window_sizes = None, + default_max_capacity_prompts = None, + specific_max_capcity_prompts = None, + kernel_sizes = None, + pooling = None): + + model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device = "xpu", dtype_=dtype, low_bit=low_bit, optimize_model=optimize_model) + device = model.device + print(f"model_device: {model.device}") + printed = False + print(out_path) + count_prompt_under_maxlen = 0 + for json_obj in tqdm(data): + ############################################################################################################ + # load compress args + count_prompt_under_maxlen += 1 + if compress: + inner_model = model.model if hasattr(model, "model") else model.base_model.encoder + layers = len(inner_model.layers) + # check if window_sizes is a list + if not isinstance(window_sizes, list): + window_sizes = [window_sizes] * layers + max_capacity_prompts = [default_max_capacity_prompts] * layers + if specific_max_capcity_prompts is not None: + for key, value in specific_max_capcity_prompts.items(): + max_capacity_prompts[key] = value + if not isinstance(kernel_sizes, list): + kernel_sizes = [kernel_sizes] * layers + from transformers.configuration_utils import PretrainedConfig + for i in range(layers): + cur_layer = inner_model.layers[i] + cur_layer_attn = cur_layer.self_attn if hasattr(cur_layer, "self_attn") else cur_layer.self_attention + cur_layer_attn.config = cur_layer_attn.config if hasattr(cur_layer_attn, "config") else PretrainedConfig() + + cur_layer_attn.config.window_size = window_sizes[i] + cur_layer_attn.config.max_capacity_prompt = max_capacity_prompts[i] + cur_layer_attn.config.kernel_size = kernel_sizes[i] + cur_layer_attn.config.pooling = pooling + ############################################################################################################ + + prompt = prompt_format.format(**json_obj) + # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + if "chatglm3" in model_name: + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0] + #print(f'initial len = {tokenized_prompt.shape}') + if len(tokenized_prompt) > max_length: + count_prompt_under_maxlen -= 1 + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks + prompt = build_chat(tokenizer, prompt, model_name) + if "chatglm3" in 
model_name: + input = prompt.to(device) + else: + input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) + context_length = input.input_ids.shape[-1] + print(f'context_length = {context_length}') + if not printed: + print(prompt) + printed = True + if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + min_length=context_length+1, + eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], + )[0] + else: + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + min_length=context_length+1, + )[0] + pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) + pred = post_process(pred, model_name) + with open(out_path, "a", encoding="utf-8") as f: + json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}, f, ensure_ascii=False) + f.write('\n') + + count_out_path = os.path.join(os.path.split(out_path)[0], "uncut_prompt_count.json") + prompt_count_result = {} + if os.path.isfile(count_out_path): + with open(count_out_path, "r", encoding = "utf-8") as f: + prompt_count_result = json.load(f) + prompt_count_result[dataset] = count_prompt_under_maxlen + with open(count_out_path, "w", encoding = "utf-8") as f: + json.dump(prompt_count_result, f, ensure_ascii=False, indent=4) + + + +def seed_everything(seed): + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def load_model_and_tokenizer(path, model_name, device, dtype_, low_bit, optimize_model): + if (dtype_ == 'fp32'): + dtype = torch.float32 + elif (dtype_ == 'fp16'): + dtype = torch.float16 + else: + raise ValueError(f"dtype {dtype_} is not supported") + model = AutoModelForCausalLM.from_pretrained( + path, + optimize_model=optimize_model, + load_in_low_bit=low_bit, + use_cache=True, + trust_remote_code=True, + torch_dtype = dtype + ).to(device) + tokenizer = AutoTokenizer.from_pretrained( + path, + padding_side="left", + use_fast=False, + trust_remote_code=True, + ) + model = model.half().to(device) + return model, tokenizer + +def compresskv_config_range(full_kv: bool, configs: list[str], model_name: str): + if full_kv: + os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "0" + yield False, {}, model_name + + os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1" + for config in configs: + yield True, json.load(open(os.path.join(f'{current_dir}/config', f"{config}.json"), "r")), f"{model_name}_{config}" + + +if __name__ == '__main__': + seed_everything(42) + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + from omegaconf import OmegaConf + conf = OmegaConf.load(f'{current_dir}/config.yaml') + + model_names = conf['model_name'] if OmegaConf.is_list(conf['model_name']) else [conf['model_name']] + full_kv = conf['full_kv'] + e = conf['e'] + compresskv_configs = conf['compress_kv'] if OmegaConf.is_list(conf['compress_kv']) else [conf['compress_kv']] + datasets = conf['datasets'] if OmegaConf.is_list(conf['datasets']) else [conf['datasets']] + dtype = conf['dtype'] + low_bit = conf['low_bit'] + optimize_model = conf['optimize_model'] + + model2path = json.load(open(f"{current_dir}/config/model2path.json", "r")) + model2maxlen = json.load(open(f"{current_dir}/config/model2maxlen.json", "r")) + + dataset2prompt = 
json.load(open(f"{current_dir}/config/dataset2prompt.json", "r")) + dataset2maxlen = json.load(open(f"{current_dir}/config/dataset2maxlen.json", "r")) + + ## check + for model_name in model_names: + if model_name not in valid_model_names: + raise ValueError(f"model {model_name} is not supported") + if e not in [True, False]: + raise ValueError("e should be True or False") + for dataset in datasets: + if e: + valid_dataset_check = valid_datasets_e + else: + valid_dataset_check = valid_datasets + # check if args dataset in datasets + if dataset not in valid_dataset_check: + raise ValueError(f"Dataset {dataset} not found in datasets") + if dtype not in valid_dtypes: + raise ValueError(f"dtype {dtype} is not supported") + + for model_name in model_names: + max_length = model2maxlen[model_name] + for compress, compress_args, write_model_name in compresskv_config_range(full_kv, compresskv_configs, model_name): + for dataset in datasets: + e_string = "_e" if e else "" + data = load_dataset('THUDM/LongBench', f"{dataset}{e_string}", split='test') + + if not os.path.exists(f"{current_dir}/pred{e_string}_{max_length}"): + os.makedirs(f"{current_dir}/pred{e_string}_{max_length}") + if not os.path.exists(f"{current_dir}/pred{e_string}_{max_length}/{write_model_name}"): + os.makedirs(f"{current_dir}/pred{e_string}_{max_length}/{write_model_name}") + out_path = f"{current_dir}/pred{e_string}_{max_length}/{write_model_name}/{dataset}.jsonl" + + prompt_format = dataset2prompt[dataset] + max_gen = dataset2maxlen[dataset] + data_all = [data_sample for data_sample in data] + get_pred_single_gpu(data_all, max_length, max_gen, prompt_format, dataset, model_name, model2path, out_path, low_bit, dtype, compress, optimize_model, **compress_args) diff --git a/python/llm/dev/benchmark/LongBench/test_and_eval.sh b/python/llm/dev/benchmark/LongBench/test_and_eval.sh new file mode 100755 index 00000000000..026852fa334 --- /dev/null +++ b/python/llm/dev/benchmark/LongBench/test_and_eval.sh @@ -0,0 +1,7 @@ +#! /bin/sh + +export HF_ENDPOINT=https://hf-mirror.com + +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +python ${SHELL_FOLDER}/pred.py +python ${SHELL_FOLDER}/eval.py \ No newline at end of file