forked from meta-llama/llama
-
Notifications
You must be signed in to change notification settings - Fork 18
/
example-chat.py
108 lines (90 loc) · 3.09 KB
/
example-chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
# Original copyright by tloen
# https://github.com/tloen/llama-int8/blob/main/example.py
#
# Axis along which each parameter family is split across the *.pth checkpoint
# files: 0 = rows, -1 = columns. ``None`` means the tensor is not split; it is
# taken whole from the first checkpoint and ignored in later ones.
_KEY_TO_DIM = {
    "w1": 0,
    "w2": -1,
    "w3": 0,
    "wo": -1,
    "wq": 0,
    "wk": 0,
    "wv": 0,
    "output": 0,
    "tok_embeddings": -1,
    "ffn_norm": None,
    "attention_norm": None,
    "norm": None,
    "rope": None,
}


def _copy_shard(parameter, shard, dim, index):
    """Copy shard number ``index`` into its slice of ``parameter`` along ``dim`` (0 or -1)."""
    size = shard.size(dim)
    start, stop = size * index, size * (index + 1)
    if dim == 0:
        parameter.data[start:stop, :] = shard
    else:  # dim == -1
        parameter.data[:, start:stop] = shard


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build a LLaMA generator from a directory of sharded checkpoints.

    Args:
        ckpt_dir: Directory holding ``params.json`` and one or more ``*.pth``
            files. The files are merged in sorted filename order, shard ``i``
            filling slice ``i`` of each split parameter.
        tokenizer_path: Path to the tokenizer model file.
        max_seq_len: Passed to ``ModelArgs`` as ``max_seq_len``.
        max_batch_size: Passed to ``ModelArgs`` as ``max_batch_size``.

    Returns:
        A ``LLaMA`` generator wrapping the merged model, moved to the CPU.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` has no ``*.pth`` files or no
            ``params.json``.
        KeyError: If a model parameter's family is not listed in
            ``_KEY_TO_DIM`` or the parameter is missing from a checkpoint.
    """
    print("Creating model...")
    start_time = time.time()
    ckpt_path = Path(ckpt_dir)
    checkpoints = sorted(ckpt_path.glob("*.pth"))
    if not checkpoints:
        # Without this check the merge loop below never runs and an
        # uninitialised model would be returned without any warning.
        raise FileNotFoundError(f"No *.pth checkpoint files found in {ckpt_dir!r}")
    with open(ckpt_path / "params.json", "r", encoding="utf-8") as f:
        params = json.load(f)
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    # The tokenizer's word count overrides any vocab size in params.json.
    model_args.vocab_size = tokenizer.n_words
    model = Transformer(model_args)

    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            # Family = second-to-last dotted component, e.g. "a.b.wq.weight" -> "wq".
            short_name = parameter_name.split(".")[-2]
            dim = _KEY_TO_DIM[short_name]
            if dim is None:
                if i == 0:
                    parameter.data = checkpoint[parameter_name]
            else:
                _copy_shard(parameter, checkpoint[parameter_name], dim, i)
            # Drop the checkpoint's reference as soon as the tensor is consumed.
            del checkpoint[parameter_name]
        del checkpoint
    model.to("cpu")
    generator = LLaMA(model, tokenizer)
    print(f"Loaded model in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str = './model',
    tokenizer_path: str = './tokenizer/tokenizer.model',
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    max_gen_len: int = 256,
):
    """Run an interactive prompt loop against a locally loaded LLaMA model.

    Each non-blank line read from stdin is sent to ``generator.generate`` as
    a single-element batch and every returned result is printed. Blank lines
    are skipped. The loop ends cleanly on end-of-input (e.g. Ctrl-D or a
    closed pipe) or on Ctrl-C at the prompt.

    Args:
        ckpt_dir: Directory containing ``params.json`` and ``*.pth`` files.
        tokenizer_path: Path to the tokenizer model file.
        temperature: Forwarded to ``generate`` as ``temperature``.
        top_p: Forwarded to ``generate`` as ``top_p``.
        max_seq_len: Forwarded to ``load`` when building the model.
        max_batch_size: Forwarded to ``load`` when building the model.
        max_gen_len: Forwarded to ``generate`` as ``max_gen_len``; defaults
            to 256, the value that was previously hard-coded.
    """
    generator = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
    while True:
        try:
            prompt = input("prompt> ")
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted and
            # KeyboardInterrupt on Ctrl-C; leave the loop instead of
            # dumping a traceback. The newline keeps the shell prompt tidy.
            print()
            break
        if not prompt.strip():
            continue
        results = generator.generate(
            [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
        )
        for result in results:
            print(result)
if __name__ == "__main__":
    # Hand main() to python-fire, which maps its keyword arguments to
    # command-line flags (e.g. --ckpt_dir, --temperature).
    fire.Fire(component=main)