From eecf70fb61553e430ef34e1067536a3e645d56eb Mon Sep 17 00:00:00 2001 From: Alexander Kozlov Date: Sat, 21 Sep 2024 09:47:04 +0400 Subject: [PATCH] Added support of HF and GenAI models into CLI (#887) --- .github/workflows/llm_bench-python.yml | 10 ++- llm_bench/python/who_what_benchmark/README.md | 25 ++++-- .../who_what_benchmark/requirements.txt | 6 +- .../who_what_benchmark/tests/test_cli.py | 46 +++++++++-- .../whowhatbench/evaluator.py | 5 +- .../who_what_benchmark/whowhatbench/wwb.py | 78 +++++++++++++++---- 6 files changed, 135 insertions(+), 35 deletions(-) diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/llm_bench-python.yml index 083847ce68..45e6dc2941 100644 --- a/.github/workflows/llm_bench-python.yml +++ b/.github/workflows/llm_bench-python.yml @@ -40,7 +40,8 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest black GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.LLM_BENCH_PYPATH }}/requirements.txt - pip install openvino-nightly + python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url +https://storage.openvinotoolkit.org/simple/wheels/nightly GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.WWB_PATH }}/requirements.txt GIT_CLONE_PROTECTION_ACTIVE=false pip install ${{ env.WWB_PATH }} @@ -73,7 +74,7 @@ jobs: python ./llm_bench/python/benchmark.py -m ./ov_models/tiny-sd/pytorch/dldt/FP16/ -pf ./llm_bench/python/prompts/stable-diffusion.jsonl -d cpu -n 1 - name: WWB Tests run: | - python -m pytest ./llm_bench/python/who_what_benchmark/tests + python -m pytest llm_bench/python/who_what_benchmark/tests stateful: runs-on: ubuntu-20.04 steps: @@ -85,7 +86,8 @@ jobs: run: | GIT_CLONE_PROTECTION_ACTIVE=false python -m pip install -r llm_bench/python/requirements.txt python -m pip uninstall --yes openvino - python -m pip install openvino-nightly + python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url +https://storage.openvinotoolkit.org/simple/wheels/nightly python llm_bench/python/convert.py --model_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir . --stateful grep beam_idx pytorch/dldt/FP32/openvino_model.xml - name: WWB Tests @@ -93,4 +95,4 @@ jobs: GIT_CLONE_PROTECTION_ACTIVE=false pip install -r llm_bench/python/who_what_benchmark/requirements.txt GIT_CLONE_PROTECTION_ACTIVE=false pip install llm_bench/python/who_what_benchmark/ pip install pytest - python -m pytest llm_bench/python/who_what_benchmark/tests + python -m pytest llm_bench/python/who_what_benchmark/tests diff --git a/llm_bench/python/who_what_benchmark/README.md b/llm_bench/python/who_what_benchmark/README.md index 008a5b92f2..d140b0af75 100644 --- a/llm_bench/python/who_what_benchmark/README.md +++ b/llm_bench/python/who_what_benchmark/README.md @@ -55,27 +55,36 @@ metrics_per_prompt, metrics = evaluator.score(optimized_model, test_data=prompts ```sh wwb --help -# run ground truth generation for uncompressed model on the first 32 samples from squad dataset -# ground truth will be saved in llama_2_7b_squad_gt.csv file +# Run ground truth generation for uncompressed model on the first 32 samples from squad dataset +# Ground truth will be saved in llama_2_7b_squad_gt.csv file wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_squad_gt.csv --dataset squad --split validation[:32] --dataset-field question -# run comparison with compressed model on the first 32 samples from squad dataset +# Run comparison with compressed model on the first 32 samples from squad dataset wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_squad_gt.csv --dataset squad --split validation[:32] --dataset-field question -# output will be like this +# Output will be like this # similarity FDT SDT FDT norm SDT norm # 0 0.972823 67.296296 20.592593 0.735127 0.151505 -# run ground truth generation for uncompressed model on internal set of questions -# ground truth will be saved in llama_2_7b_squad_gt.csv file +# Run ground truth generation for uncompressed model on internal set of questions +# Ground truth will be saved in llama_2_7b_squad_gt.csv file wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv -# run comparison with compressed model on internal set of questions +# Run comparison with compressed model on internal set of questions wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_wwb_gt.csv -## Control the number of samples and use verbose mode to see the difference in the results +# Use --num-samples to control the number of samples wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --num-samples 10 + +# Use -v for verbose mode to see the difference in the results wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_wwb_gt.csv --num-samples 10 -v + +# Use --hf AutoModelForCausalLM to instantiate the model from model_id/folder +wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --hf + +# Use --language parameter to control the language of promts +# Autodetection works for basic Chinese models +wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --hf ``` ### Supported metrics diff --git a/llm_bench/python/who_what_benchmark/requirements.txt b/llm_bench/python/who_what_benchmark/requirements.txt index 9d413a897f..7fa229f370 100644 --- a/llm_bench/python/who_what_benchmark/requirements.txt +++ b/llm_bench/python/who_what_benchmark/requirements.txt @@ -1,8 +1,10 @@ transformers>=4.35.2 sentence-transformers>=2.2.2 -openvino>=2023.3.0 -openvino-telemetry>=2023.2.1 +openvino>=2024.3.0 +openvino-telemetry>=2024.3.0 optimum-intel>=1.14 +openvino-tokenizers>=2024.3.0 +openvino-genai>=2024.3.0 pandas>=2.0.3 numpy>=1.23.5 tqdm>=4.66.1 diff --git a/llm_bench/python/who_what_benchmark/tests/test_cli.py b/llm_bench/python/who_what_benchmark/tests/test_cli.py index ecd23fbaa3..8110e98335 100644 --- a/llm_bench/python/who_what_benchmark/tests/test_cli.py +++ b/llm_bench/python/who_what_benchmark/tests/test_cli.py @@ -32,17 +32,21 @@ def run_wwb(args): def setup_module(): + from optimum.exporters.openvino.convert import export_tokenizer + logger.info("Create models") tokenizer = AutoTokenizer.from_pretrained(model_id) base_model = OVModelForCausalLM.from_pretrained(model_id) base_model.save_pretrained(base_model_path) tokenizer.save_pretrained(base_model_path) + export_tokenizer(tokenizer, base_model_path) target_model = OVModelForCausalLM.from_pretrained( model_id, quantization_config=OVWeightQuantizationConfig(bits=8) ) target_model.save_pretrained(target_model_path) tokenizer.save_pretrained(target_model_path) + export_tokenizer(tokenizer, target_model_path) def teardown_module(): @@ -57,9 +61,10 @@ def test_target_model(): "--num-samples", "2", "--device", "CPU" ]) + assert result.returncode == 0 - assert "Metrics for model" in result.stdout - assert "## Reference text" not in result.stdout + assert "Metrics for model" in result.stderr + assert "## Reference text" not in result.stderr @pytest.fixture @@ -76,8 +81,6 @@ def test_gt_data(): "--num-samples", "2", "--device", "CPU" ]) - import time - time.sleep(1) data = pd.read_csv(temp_file_name) os.remove(temp_file_name) @@ -95,7 +98,7 @@ def test_output_directory(): "--output", temp_dir ]) assert result.returncode == 0 - assert "Metrics for model" in result.stdout + assert "Metrics for model" in result.stderr assert os.path.exists(os.path.join(temp_dir, "metrics_per_qustion.csv")) assert os.path.exists(os.path.join(temp_dir, "metrics.csv")) @@ -109,7 +112,7 @@ def test_verbose(): "--verbose" ]) assert result.returncode == 0 - assert "## Reference text" in result.stdout + assert "## Diff " in result.stderr def test_language_autodetect(): @@ -127,3 +130,34 @@ def test_language_autodetect(): assert result.returncode == 0 assert "马克" in data["questions"].values[0] + + +def test_hf_model(): + with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile: + temp_file_name = tmpfile.name + + result = run_wwb([ + "--base-model", model_id, + "--gt-data", temp_file_name, + "--num-samples", "2", + "--device", "CPU", + "--hf" + ]) + data = pd.read_csv(temp_file_name) + os.remove(temp_file_name) + + assert result.returncode == 0 + assert len(data["questions"].values) == 2 + + +def test_genai_model(): + result = run_wwb([ + "--base-model", base_model_path, + "--target-model", target_model_path, + "--num-samples", "2", + "--device", "CPU", + "--genai" + ]) + assert result.returncode == 0 + assert "Metrics for model" in result.stderr + assert "## Reference text" not in result.stderr diff --git a/llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py b/llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py index 47936f16da..ac64167727 100644 --- a/llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py +++ b/llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py @@ -96,7 +96,8 @@ def __init__( max_new_tokens=128, crop_question=True, num_samples=None, - language=None + language=None, + gen_answer_fn=None, ) -> None: assert ( base_model is not None or gt_data is not None @@ -116,7 +117,7 @@ def __init__( self.language = autodetect_language(base_model) if base_model: - self.gt_data = self._generate_data(base_model) + self.gt_data = self._generate_data(base_model, gen_answer_fn) else: self.gt_data = pd.read_csv(gt_data, keep_default_na=False) diff --git a/llm_bench/python/who_what_benchmark/whowhatbench/wwb.py b/llm_bench/python/who_what_benchmark/whowhatbench/wwb.py index 7f5c3c26fc..8efca22059 100644 --- a/llm_bench/python/who_what_benchmark/whowhatbench/wwb.py +++ b/llm_bench/python/who_what_benchmark/whowhatbench/wwb.py @@ -1,17 +1,21 @@ import argparse import difflib import os - import json import pandas as pd +import logging from datasets import load_dataset from optimum.exporters import TasksManager from optimum.intel.openvino import OVModelForCausalLM from optimum.utils import NormalizedConfigManager, NormalizedTextConfig -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM from . import Evaluator +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + TasksManager._SUPPORTED_MODEL_TYPE["stablelm-epoch"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"] NormalizedConfigManager._conf["stablelm-epoch"] = NormalizedTextConfig.with_args( num_layers="num_hidden_layers", @@ -19,7 +23,39 @@ ) -def load_model(model_id, device="CPU", ov_config=None): +class GenAIModelWrapper(): + """ + A helper class to store additional attributes for GenAI models + """ + def __init__(self, model, model_dir): + self.model = model + self.config = AutoConfig.from_pretrained(model_dir) + + def __getattr__(self, attr): + if attr in self.__dict__: + return getattr(self, attr) + else: + return getattr(self.model, attr) + + +def load_genai_pipeline(model_dir, device="CPU"): + try: + import openvino_genai + except ImportError: + logger.error("Failed to import openvino_genai package. Please install it.") + exit(-1) + logger.info("Using OpenVINO GenAI API") + return GenAIModelWrapper(openvino_genai.LLMPipeline(model_dir, device), model_dir) + + +def load_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False): + if use_hf: + logger.info("Using HF Transformers API") + return AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map=device.lower()) + + if use_genai: + return load_genai_pipeline(model_id, device) + if ov_config: with open(ov_config) as f: ov_options = json.load(f) @@ -157,6 +193,16 @@ def parse_args(): default=None, help="Used to select default prompts based on the primary model language, e.g. 'en', 'ch'.", ) + parser.add_argument( + "--hf", + action="store_true", + help="Use AutoModelForCausalLM from transformers library to instantiate the model.", + ) + parser.add_argument( + "--genai", + action="store_true", + help="Use LLMPipeline from transformers library to instantiate the model.", + ) return parser.parse_args() @@ -211,6 +257,11 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str: return "".join(output) +def genai_gen_answer(model, tokenizer, question, max_new_tokens, skip_question): + out = model.generate(question, max_new_tokens=max_new_tokens) + return out + + def main(): args = parse_args() check_args(args) @@ -228,7 +279,7 @@ def main(): language=args.language, ) else: - base_model = load_model(args.base_model, args.device, args.ov_config) + base_model = load_model(args.base_model, args.device, args.ov_config, args.hf, args.genai) evaluator = Evaluator( base_model=base_model, test_data=prompts, @@ -236,16 +287,17 @@ def main(): similarity_model_id=args.text_encoder, num_samples=args.num_samples, language=args.language, + gen_answer_fn=genai_gen_answer if args.genai else None ) if args.gt_data: evaluator.dump_gt(args.gt_data) del base_model if args.target_model: - target_model = load_model(args.target_model, args.device, args.ov_config) - all_metrics_per_question, all_metrics = evaluator.score(target_model) - print("Metrics for model: ", args.target_model) - print(all_metrics) + target_model = load_model(args.target_model, args.device, args.ov_config, args.hf, args.genai) + all_metrics_per_question, all_metrics = evaluator.score(target_model, genai_gen_answer if args.genai else None) + logger.info("Metrics for model: %s", args.target_model) + logger.info(all_metrics) if args.output: if not os.path.exists(args.output): @@ -269,11 +321,11 @@ def main(): actual_text += l2 + "\n" diff += diff_strings(l1, l2) + "\n" - print("--------------------------------------------------------------------------------------") - print("## Reference text {}:\n".format(i + 1), ref_text) - print("## Actual text {}:\n".format(i + 1), actual_text) - print("## Diff {}: ".format(i + 1)) - print(diff) + logger.info("--------------------------------------------------------------------------------------") + logger.info("## Reference text %d:\n%s", i + 1, ref_text) + logger.info("## Actual text %d:\n%s", i + 1, actual_text) + logger.info("## Diff %d: ", i + 1) + logger.info(diff) if __name__ == "__main__":