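"""llamassp.py - Benchmark LLaMA models with regular and speculative sampling.

Subcommands (see create_argument_parser below):
  * latency - measure a model's latency in ms per token, optionally with a draft model
  * compare - interactively compare regular sampling vs. speculative sampling
  * eval    - evaluate a model on a simple multiplication task
"""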
import argparse
import logging
import os
import sys
import time

import torch
from termcolor import colored
from transformers import LlamaTokenizer

from lssp import evals
from lssp.base import create_model, sample_model
from lssp.ssp import ssp

torch.manual_seed(1339)

MAX_NEW_TOKENS = 64
llama7b_name = 'decapoda-research/llama-7b-hf'
llama13b_name = 'decapoda-research/llama-13b-hf'
llama30b_name = 'decapoda-research/llama-30b-hf'
llama65b_name = 'decapoda-research/llama-65b-hf'
batch_size = 1
texts = [
    'In which country is Hamburg?\n',
    'How are you doing today?\n',
    'It was a dark and stormy night.',
    'The sun rose slowly over the horizon, casting a warm glow on the world below.',
    'I never believed in ghosts until the day I met one.',
    'The sound of the train whistle echoed through the valley as I stood at the station, waiting.',
    'She walked into the room and everything changed.',
    'The smell of freshly baked bread filled the air as I entered the bakery.',
    'The first time I saw her, I knew she was trouble.',
    'The world was ending, and I was the only one who knew.',
    'It was the best of times, it was the worst of times.',
    'The forest was alive with the sound of animals as I walked deeper into the woods.',
    'As I looked out over the city, I knew that anything was possible.',
    'The sound of gunfire echoed through the streets as I ran for cover.',
    'The waves crashed against the shore, a never-ending cycle of destruction and creation.',
    'I woke up to find myself in a strange place, with no memory of how I got there.',
    'The clock struck midnight, and I knew that my life would never be the same.',
]
tokenizer = LlamaTokenizer.from_pretrained(llama7b_name)

# Per-GPU memory budget: free memory on GPU 0 minus ~2GB of headroom
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
max_mem = f'{free_in_GB - 2}GB'
n_gpus = torch.cuda.device_count()


def max_memory(gpus, starting_gpu=0):
    """Return a {gpu_index: max_mem} dict covering `gpus` GPUs starting at
    `starting_gpu` (capped at the number of available GPUs)."""
    return {i: max_mem
            for i in range(starting_gpu, min(starting_gpu + gpus, n_gpus))}
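# Illustrative example: on a machine with at least 2 GPUs where max_mem
# evaluates to '22GB', max_memory(2) returns {0: '22GB', 1: '22GB'}.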


def time_model(model):
    """Sample MAX_NEW_TOKENS from `model` for each prompt and return the
    generated ids of the last completion plus the average ms per token."""
    # Warmup: run the first prompt separately so it does not skew the timing
    input_ids = tokenizer(texts[0], return_tensors="pt").input_ids
    input_ids = torch.stack([input_ids[0]] * batch_size).to(model.device)
    generated_ids = sample_model(model, input_ids, MAX_NEW_TOKENS)

    start_time = time.time()
    nb_tokens = 0
    for text in texts[1:]:
        print("Completing text:", text)
        intermediate_time = time.time()
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        input_ids = torch.stack([input_ids[0]] * batch_size).to(model.device)
        generated_ids = sample_model(model, input_ids, MAX_NEW_TOKENS)
        nb_tokens += generated_ids.shape[1] - input_ids.shape[1]
        print("Completion: ", tokenizer.decode(
            generated_ids[0], skip_special_tokens=True))
        print("Time: {:.2f}s".format(time.time() - intermediate_time))
        print("========\n")
    ms_per_token = (time.time() - start_time) * 1000 / nb_tokens
    return generated_ids, ms_per_token


def print_results(ms_per_token, outputs, name='Noname'):
    print("Results for ", name)
    print(f"Ms per token: {ms_per_token:.2f}ms")
    print("========\n")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    print("========\n")


models_params = {
    '7B_8bit': {'model_name': llama7b_name,
                'max_memory': max_memory(1),
                'load_in_8bit': True},
    '7B_8bit_4': {'model_name': llama7b_name,
                  'max_memory': max_memory(4),
                  'load_in_8bit': True},
    '7B': {'model_name': llama7b_name,
           'max_memory': max_memory(1),
           'load_in_8bit': False},
    '7B_8': {'model_name': llama7b_name,
             'max_memory': max_memory(8),
             'load_in_8bit': False},
    '13B_8bit': {'model_name': llama13b_name,
                 'max_memory': max_memory(1),
                 'load_in_8bit': True},
    '13B': {'model_name': llama13b_name,
            'max_memory': max_memory(2),
            'load_in_8bit': False},
    '30B_8bit': {'model_name': llama30b_name,
                 'max_memory': max_memory(2),
                 'load_in_8bit': True},
    '30B': {'model_name': llama30b_name,
            'max_memory': max_memory(4),
            'load_in_8bit': False},
    '65B_8bit': {'model_name': llama65b_name,
                 'max_memory': max_memory(4),
                 'load_in_8bit': True},
    '65B': {'model_name': llama65b_name,
            'max_memory': max_memory(8),
            'load_in_8bit': False},
    '65B_v2': {'model_name': f"{os.getenv('HOME')}/data/hf-weights/65B",
               'max_memory': max_memory(8),
               'load_in_8bit': False},
}
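# Each entry above is a set of keyword arguments for lssp.base.create_model;
# e.g. create_model(**models_params['7B_8bit']) loads the 7B model in 8-bit
# (illustrative; actual device placement depends on the GPUs available).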


def time_ssp(target_name, draft_name, K=4):
    """Like time_model, but generates with speculative sampling: the draft
    model proposes K tokens at a time and the target model accepts or
    rejects them."""
    draft_model = create_model(**models_params[draft_name])
    target_model = create_model(**models_params[target_name])
    nb_tokens = 0

    # Warmup
    input_ids = tokenizer(texts[0], return_tensors="pt").input_ids
    input_ids = torch.stack(
        [input_ids[0]] * batch_size).to(draft_model.device)
    generated_ids = ssp(target_model,
                        draft_model,
                        MAX_NEW_TOKENS,
                        input_ids, K=K)

    start_time = time.time()
    for text in texts[1:]:
        print("Completing text:", text)
        intermediate_time = time.time()
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        input_ids = torch.stack(
            [input_ids[0]] * batch_size).to(draft_model.device)
        generated_ids = ssp(target_model,
                            draft_model,
                            MAX_NEW_TOKENS,
                            input_ids, K=K)
        nb_tokens += generated_ids.shape[1] - input_ids.shape[1]
        print("Completion: ", tokenizer.decode(
            generated_ids[0], skip_special_tokens=True))
        print("Time: {:.2f}s".format(time.time() - intermediate_time))
        print("========\n")
    ms_per_token = (time.time() - start_time) * 1000 / nb_tokens
    return generated_ids, ms_per_token


def print_speeds(speeds):
    print("Speeds:")
    for model_name, ms_per_token in speeds.items():
        print('-' * 20)
        print(f"{model_name} | {ms_per_token:.2f}ms")
        print('-' * 20)


def models_raw_speed():
    """Measure and print the raw sampling speed of each configured model,
    then the speed of speculative sampling with a 7B draft / 65B target."""
    speeds = {}
    # Skip these configurations in the benchmark
    del models_params['7B'], models_params['13B'], models_params['30B']
    for model_name, params in sorted(models_params.items()):
        print(f"Testing {model_name}")
        print('-' * 20)
        model = create_model(**params)
        outputs, ms_per_token = time_model(model)
        speeds[model_name] = ms_per_token
        print_results(ms_per_token, outputs, model_name)
        del model
        torch.cuda.empty_cache()
    print_speeds(speeds)

    draft_name = '7B_8bit'
    target_name = '65B_8bit'
    print(f"Testing SSP {draft_name} / {target_name}")
    outputs, ms_per_token = time_ssp(target_name, draft_name)
    speeds[f"{draft_name} / {target_name}"] = ms_per_token
    print(speeds)


def show_comparative_speeds(text, model, draft_model):
    """Complete `text` twice (regular sampling, then speculative sampling)
    and print the wall-clock time of each."""
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    print(colored("=> Regular sampling with target model",
                  attrs=['bold']))
    sys.stdout.write(text)
    start_time = time.time()
    sample_model(model, input_ids, MAX_NEW_TOKENS, display=True)
    print("\nTime: "
          + colored(f"{time.time() - start_time:.2f}s", 'red', attrs=['bold']))

    print(colored(
        "=> Speculative sampling with target model helped by draft model",
        attrs=['bold']))
    sys.stdout.write(text)
    start_time = time.time()
    ssp(model, draft_model, MAX_NEW_TOKENS,
        input_ids, K=4, display=True)
    print("\nTime: "
          + colored(f"{time.time() - start_time:.2f}s", 'green', attrs=['bold']))


def create_argument_parser():
    """
    Create a parser for the command-line arguments, with 'compare', 'latency'
    and 'eval' subcommands.
    """
    parser = argparse.ArgumentParser(
        description='Test speeds of Llama models with regular sampling and '
                    'speculative sampling: measure their latency, compare '
                    'their speed, and evaluate their performance on a simple task.')
    # Argument to set the log level
    parser.add_argument(
        '-v', '--verbose', action='store_true', help='verbose output')

    subparsers = parser.add_subparsers(dest='subcommand')
    compare_parser = subparsers.add_parser(
        'compare',
        help='Compare the speed of a given model (the target) with regular '
             'sampling vs. speculative sampling helped by another model (the draft)')
    compare_parser.add_argument('model', help='Name of target model')
    compare_parser.add_argument('draft', help='Name of draft model')

    latency_parser = subparsers.add_parser(
        'latency', help='Measure model latency in ms per token')
    latency_parser.add_argument('model', help='Name of model')
    latency_parser.add_argument(
        '--draft',
        help='Draft model; if specified, measure the latency of speculative '
             'sampling with this draft model instead of the regular latency')

    eval_parser = subparsers.add_parser(
        'eval', help='Evaluate a model on a simple multiplication task')
    eval_parser.add_argument('model', help='Name of model to evaluate')
    eval_parser.add_argument(
        '--draft',
        help='Draft model; if specified, evaluate the model with speculative '
             'sampling using this draft model instead of the regular model')
    eval_parser.add_argument('--seed', type=int, default=1338,
                             help='Seed for randomly creating the eval prompts')
    eval_parser.add_argument('--nb-prompts', type=int, default=1000,
                             help='Number of eval prompts to create')
    return parser
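# Illustrative invocations (model names are keys of models_params above):
#   python llamassp.py latency 13B_8bit
#   python llamassp.py latency 65B_8bit --draft 7B_8bit
#   python llamassp.py compare 65B_8bit 7B_8bit
#   python llamassp.py eval 30B_8bit --draft 7B_8bit --nb-prompts 100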


if __name__ == "__main__":
    parser = create_argument_parser()
    args = parser.parse_args()
    if args.verbose:
        # Set log level to debug
        logging.basicConfig(level=logging.DEBUG)

    if args.subcommand == 'compare':
        model = create_model(**models_params[args.model])
        draft_model = create_model(**models_params[args.draft])
        print("Warming up")
        ssp(model, draft_model, MAX_NEW_TOKENS,
            tokenizer(texts[0], return_tensors="pt").input_ids, K=4)
        print(
            f"Comparing {args.model} model regular sampling and {args.model} SSp with {args.draft} draft model\n====\n")
        # Read prompts from stdin until EOF
        while True:
            try:
                sys.stdout.write("> ")
                sys.stdout.flush()
                text = input()
            except EOFError:
                break
            show_comparative_speeds(text, model, draft_model)

    elif args.subcommand == 'latency' and args.draft:
        print(f"Testing {args.model} with draft {args.draft}")
        print('-' * 20)
        gen_ids, ms_per_token = time_ssp(args.model, args.draft)
        print_results(ms_per_token, gen_ids, args.model)

    elif args.subcommand == 'latency':
        print(f"Testing {args.model}")
        print('-' * 20)
        model = create_model(**models_params[args.model])
        gen_ids, ms_per_token = time_model(model)
        print_results(ms_per_token, gen_ids, args.model)

    elif args.subcommand == 'eval':
        print(f"Eval of {args.model} on multiplication task (seed {args.seed})"
              + (f" with draft {args.draft}" if args.draft else ""))
        print('-' * 20)
        model = create_model(**models_params[args.model])
        if args.draft:
            draft_model = create_model(**models_params[args.draft])
        else:
            draft_model = None
        results = evals.measure_model_score(
            model, tokenizer, args.nb_prompts, args.seed, draft_model)
        evals.print_results(results, args.model, args.draft)

    else:
        # No subcommand given: show usage
        parser.print_help()