# demo.py (forked from mit-han-lab/llm-awq)
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
from attributedict.collections import AttributeDict
from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
import tinychat.utils.constants
from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
from tinychat.utils.tune import device_warmup, tune_all_wqlinears
import os
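# Pin the demo to GPU 0; edit CUDA_VISIBLE_DEVICES to target a different GPU.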
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# opt_params in TinyLLMEngine
gen_params = AttributeDict(
    [
        ("seed", -1),  # RNG seed
        ("n_threads", 1),  # TODO: fix this
        ("n_predict", 512),  # new tokens to predict
        ("n_parts", -1),  # amount of model parts (-1: determine from model dimensions)
        ("n_ctx", 512),  # context size
        ("n_batch", 512),  # batch size for prompt processing (must be >=32 to use BLAS)
        ("n_keep", 0),  # number of tokens to keep from initial prompt
        ("n_vocab", 50272),  # vocabulary size
        # sampling parameters
        ("logit_bias", dict()),  # logit bias for specific tokens: <int, float>
        ("top_k", 40),  # <= 0 to use vocab size
        ("top_p", 0.95),  # 1.0 = disabled
        ("tfs_z", 1.00),  # 1.0 = disabled
        ("typical_p", 1.00),  # 1.0 = disabled
        ("temp", 0.70),  # 1.0 = disabled
        ("repeat_penalty", 1.10),  # 1.0 = disabled
        ("repeat_last_n", 64),  # last n tokens to penalize (0 = disable penalty, -1 = context size)
        ("frequency_penalty", 0.00),  # 0.0 = disabled
        ("presence_penalty", 0.00),  # 0.0 = disabled
        ("mirostat", 0),  # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
        ("mirostat_tau", 5.00),  # target entropy
        ("mirostat_eta", 0.10),  # learning rate
    ]
)
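# AttributeDict allows attribute-style access (e.g. gen_params.n_predict);
# n_predict and n_vocab are overridden in __main__ below.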
def stream_output(output_stream):
    print("ASSISTANT: ", end="", flush=True)
    pre = 0
    for outputs in output_stream:
        output_text = outputs["text"]
        output_text = output_text.strip().split(" ")
        now = len(output_text) - 1
        if now > pre:
            print(" ".join(output_text[pre:now]), end=" ", flush=True)
            pre = now
    print(" ".join(output_text[pre:]), flush=True)
    if "timing" in outputs and outputs["timing"] is not None:
        timing = outputs["timing"]
        context_tokens = timing["context_tokens"]
        context_time = timing["context_time"]
        total_tokens = timing["total_tokens"]
        generation_time_list = timing["generation_time_list"]
        generation_tokens = len(generation_time_list)
        average_speed = (context_time + np.sum(generation_time_list)) / (
            context_tokens + generation_tokens
        )
        print("=" * 50)
        print("Speed of Inference")
        print("-" * 50)
        # print(f"Context Stage : {context_time/context_tokens * 1000:.2f} ms/token")
        print(
            f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token"
        )
        # print(f"Average Speed : {average_speed * 1000:.2f} ms/token")
        print("=" * 50)
        # print("token num:", total_tokens)
        # print("Model total Time = ", (context_time + np.sum(generation_time_list)) * 1000, "ms")
    return " ".join(output_text)
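# Example invocation (the paths below are just the argparse defaults; point them
# at your own HF checkpoint and pre-quantized AWQ weights):
#   python demo.py --model_type llama \
#       --model_path /data/llm/checkpoints/vicuna-hf/vicuna-7b \
#       --precision W4A16 --q_group_size 128 \
#       --load_quant /data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt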
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type", type=str, default="LLaMa", help="type of the model"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/data/llm/checkpoints/vicuna-hf/vicuna-7b",
        help="path to the model",
    )
    parser.add_argument(
        "--precision", type=str, default="W4A16", help="compute precision"
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--q_group_size", type=int, default=128)
    parser.add_argument(
        "--load_quant",
        type=str,
        default="/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt",
        help="path to the pre-quantized 4-bit weights",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=2048,
        help="maximum sequence length for the KV cache",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=1,
        help="maximum batch size for the KV cache",
    )
    args = parser.parse_args()
    assert args.model_type.lower() in [
        "llama",
        "falcon",
        "mpt",
    ], "We only support llama, falcon, and mpt for now"
    assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 for now"
    gen_params.n_predict = 512
    gen_params.n_vocab = 32000
    tinychat.utils.constants.max_batch_size = args.max_batch_size
    tinychat.utils.constants.max_seq_len = args.max_seq_len
    # TODO (Haotian): a more elegant implementation here.
    # We need to update these global variables before models use them.
    from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM
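    # Monkey-patch PyTorch's weight initializers to no-ops: the randomly
    # initialized weights would be overwritten when the checkpoint is loaded,
    # so skipping init speeds up model construction.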
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    if "mpt" in config.__class__.__name__.lower():
        # config.init_device = "meta"
        tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_name, trust_remote_code=True
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_path, use_fast=False, trust_remote_code=True
        )
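    # Build an uninitialized fp16 model skeleton: HF's weight init is disabled
    # and the default dtype is set to half, since the actual weights (quantized
    # or fp16) are loaded further below.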
    modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    model_type_dict = {
        "llama": LlamaForCausalLM,
        "falcon": FalconForCausalLM,
        "mpt": MPTForCausalLM,
    }
    if args.precision == "W4A16":
        if args.model_type.lower() == "llama":
            model = model_type_dict["llama"](config).half()
            model = load_awq_llama_fast(
                model, args.load_quant, 4, args.q_group_size, args.device
            )
        else:
            model = model_type_dict[args.model_type.lower()](config).half()
            model = load_awq_model(
                model, args.load_quant, 4, args.q_group_size, args.device
            )
    else:
        loaded_model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        model = model_type_dict[args.model_type.lower()](config).half().to(args.device)
        model.load_state_dict(loaded_model.state_dict())
    # device warm up
    device_warmup(args.device)
    # autotune split_k_iters
    # tune_all_wqlinears(model)

    # TODO (Haotian): Verify if the StreamGenerator still works for the unmodified falcon impl.
    stream_generator = StreamGenerator

    # Optimize AWQ quantized model
    if args.precision == "W4A16" and args.model_type.lower() == "llama":
        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
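        # These helpers swap the LLaMA attention, RMSNorm, and MLP blocks for
        # TinyChat's fused 4-bit kernels to speed up W4A16 inference.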
        make_quant_attn(model, args.device)
        make_quant_norm(model)
        make_fused_mlp(model)

    model_prompter = get_prompter(args.model_type, args.model_path)
    stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
    count = 0
    while True:
        # Get input from the user; an empty line exits the chat loop.
        input_prompt = input("USER: ")
        if input_prompt == "":
            print("EXIT...")
            break
        model_prompter.insert_prompt(input_prompt)
        output_stream = stream_generator(
            model,
            tokenizer,
            model_prompter.model_input,
            gen_params,
            device=args.device,
            stop_token_ids=stop_token_ids,
        )
        outputs = stream_output(output_stream)
        model_prompter.update_template(outputs)
        count += 1