-
Notifications
You must be signed in to change notification settings - Fork 62
/
example.py
129 lines (106 loc) · 5.91 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# transformers version 4.38.2
# this example is tested with 4 RTX3090s, 24GB memory each
"""Passkey-retrieval demo for SelfExtend.

For each model id in ``model_lists`` the script prompts the model with the
examples in ``passkey_examples.jsonl`` and prints the generated passkey,
first with the unmodified model and then after ``SelfExtend.apply`` with the
"flash_attn", "triton" and pure-Torch attention paths.
"""
import warnings
# Set before the heavy imports below so warnings raised while importing
# torch/transformers are suppressed as well.
warnings.filterwarnings("ignore")
import torch
import json
import time
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import SelfExtend

# Passed to SelfExtend.apply(model, group_size, window_size, ...).
window_size = 1024
group_size = 32
# Load models with use_flash_attention_2 and enable SelfExtend's flash path.
use_flash = True
# Alternative model list (uncomment to evaluate more models):
# model_lists = ['google/gemma-7b-it', 'meta-llama/Llama-2-7b-chat-hf', 'mistralai/Mistral-7B-Instruct-v0.1', ]
model_lists = ['meta-llama/Llama-2-7b-chat-hf']
# Passkey prompts, one JSON object per line with "input" and "target" keys.
file_name = "passkey_examples.jsonl"
# Appended to every prompt so the model's continuation is the passkey itself.
prompt_postfix = "What is the pass key? The pass key is "


def load_passkey_examples(path):
    """Read the passkey examples from a JSON Lines file.

    Each non-blank line is parsed with ``json.loads``; the evaluation reads
    the ``"input"`` and ``"target"`` fields of every record.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        A list of the decoded records, in file order.
    """
    # JSON text is UTF-8 by spec, so don't depend on the platform locale;
    # the context manager guarantees the handle is closed.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def load_model(model_name, use_flash_attention):
    """Load ``model_name`` in bfloat16 with ``device_map="auto"``.

    Args:
        model_name: Hugging Face model id. Ids containing "Mistral" get their
            config's ``sliding_window`` set to ``None`` before loading.
        use_flash_attention: If true, pass ``use_flash_attention_2=True``.

    Returns:
        The loaded model, in eval mode.
    """
    kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
    if use_flash_attention:
        kwargs["use_flash_attention_2"] = True
    if 'Mistral' in model_name:
        # Disable Mistral's sliding window
        config = AutoConfig.from_pretrained(model_name)
        config.sliding_window = None
        kwargs["config"] = config
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    model.eval()
    return model


def free_model_memory():
    """Run the garbage collector and release cached CUDA memory.

    Call right after dropping the last reference to a model (``del model``)
    so its GPU memory is returned before another model is loaded.
    """
    import gc

    gc.collect()
    torch.cuda.empty_cache()


def print_section(title):
    """Print the banner that introduces one evaluation pass."""
    print("=========="*2 + f"**{title}**" + "=========="*2)


def run_passkey_eval(model, tokenizer, examples, display_name):
    """Generate an answer for every passkey example and print it with timing.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer matching ``model``.
        examples: Records with ``"input"`` (context) and ``"target"``
            (expected passkey) fields.
        display_name: Label printed in front of each answer.
    """
    for example in examples:
        prompt = example["input"] + prompt_postfix
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        prompt_len = input_ids.shape[1]
        print("-----------------------------------")
        print("#Tokens of Prompt:", prompt_len, end=" ")
        print("Passkey target:", example["target"])
        start_time = time.perf_counter()
        # Generation budget: one new token per character of the target.
        tokens = model.generate(input_ids, max_new_tokens=len(example["target"]))
        elapsed = time.perf_counter() - start_time
        # Decode only the newly generated tokens, not the echoed prompt.
        generated = tokenizer.decode(tokens[0, prompt_len:].tolist(), skip_special_tokens=True)
        answer = (prompt_postfix + generated).replace("\n", "\\n")
        print(f"{display_name}:\n [ {answer} ]")
        print(f"Runing Time: {elapsed:.2f} sec")
        print("-----------------------------------\n")


for model_name in model_lists:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    examples = load_passkey_examples(file_name)
    model = load_model(model_name, use_flash)

    print_section("Original")
    run_passkey_eval(model, tokenizer, examples, model_name)

    print_section("SelfExtend using flash_attn")
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn")  ## flash_attention_impl="triton" or "flash_attn"
    run_passkey_eval(model, tokenizer, examples, f"SelfExtended-{model_name}")

    print_section("SelfExtend using triton")
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="triton")  ## flash_attention_impl="triton" or "flash_attn"
    run_passkey_eval(model, tokenizer, examples, f"SelfExtended-{model_name}")

    print_section("SelfExtend using Torch")
    print("------Need more GPU memory!!-----------------------------")
    # Reload without use_flash_attention_2 for the pure-Torch SelfExtend path.
    # Drop the current model first so two copies never share the GPUs.
    del model
    free_model_memory()
    model = load_model(model_name, use_flash_attention=False)
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
    run_passkey_eval(model, tokenizer, examples, f"SelfExtended-{model_name}")

    # Release this model before the next entry in model_lists is loaded.
    del model
    free_model_memory()