From 7380823f3f0a38d42e9a979341315391ab1dacd0 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:49:01 +0800 Subject: [PATCH] Update Llama2 multi-processes example (#11852) * update llama2 multi-processes examples * update * update readme * update --- .../HF-Transformers-AutoModels/LLM/README.md | 18 ++++++++++--- .../HF-Transformers-AutoModels/LLM/llama2.py | 27 ++++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md index 068c2b2d4c9..31e055b5bea 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md @@ -124,17 +124,27 @@ python  llama2.py Arguments info: - `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`. +- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It is default to be `What is AI?`. - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`. +- `--max-output-len MAX_OUTPUT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`. +- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `768`. + #### Sample Output #### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) ```log Inference time: xxxx s --------------------- Prompt -------------------- -Once upon a time, there existed a little girl who liked to have adventures. 
She wanted to go to places and meet new people, and have fun +-------------------- Input -------------------- + [INST] <<SYS>> + +<</SYS>> + +What is AI? [/INST] -------------------- Output -------------------- + [INST] <<SYS>> + +<</SYS>> -One day, she decided to go on a journey to find a magical land that was said to be full of wonders +What is AI? [/INST] AI (Artificial Intelligence) is a field of computer science and engineering that focuses on the development of intelligent machines that can perform tasks ``` diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py index a5945aa4501..d23a6405677 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py @@ -26,6 +26,18 @@ logger = logging.get_logger(__name__) +def get_prompt(message: str, chat_history: list[tuple[str, str]], + system_prompt: str) -> str: + texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n'] + # The first user input is _not_ stripped + do_strip = False + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -38,9 +50,11 @@ help="The huggingface repo id for the Llama2 model to be downloaded" ", or the path to the huggingface checkpoint folder", ) + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--max-output-len", type=int, 
default=1024) - parser.add_argument("--max-prompt-len", type=int, default=128) + parser.add_argument("--max-prompt-len", type=int, default=768) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--intra-pp", type=int, default=2) parser.add_argument("--inter-pp", type=int, default=2) @@ -64,20 +78,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - prompts = [ - "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", - "Once upon a time, there existed", - "Once upon a time, there existed a little girl who liked to have adventures.", - ] + DEFAULT_SYSTEM_PROMPT = """\ + """ print("-" * 80) print("done") with torch.inference_mode(): print("finish to load") for i in range(5): - import random - idx = random.randint(0, 2) - prompt = prompts[idx] + prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) _input_ids = tokenizer.encode(prompt, return_tensors="pt") print("input length:", len(_input_ids[0])) st = time.time()