inference.py
import argparse
import json
import os

import torch
from torch.nn import functional as F

from model import KPT, KPTConfig
from tokenizer import Tokenizer

# Fix the RNG seeds so sampling is reproducible across runs.
torch.manual_seed(6)
torch.cuda.manual_seed(6)

def run(
    input_ids: torch.Tensor,
    tokenizer: Tokenizer,
    model: KPT,
    max_length: int = 100,
    num_sequences: int = 5,
) -> None:
    # TODO: stop at end-token
    # Duplicate the prompt so num_sequences samples are generated in one batch.
    tokens = input_ids.unsqueeze(0).repeat(num_sequences, 1)
    while tokens.size(-1) < max_length:
        with torch.no_grad():
            # (batch_size, context_length, vocab_size)
            logits = model(tokens)
            # Keep only the logits at the last position: (batch_size, vocab_size)
            logits = logits[:, -1, :]
            # Convert logits to probabilities.
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling: restrict sampling to the 100 most likely tokens.
            topk_probs, topk_indices = torch.topk(probs, 100, dim=-1)
            # Sample a position within the top-k distribution (multinomial
            # renormalizes, so the selected probabilities need not sum to 1).
            sampled_token = torch.multinomial(topk_probs, 1)
            # Map the sampled position back to a vocabulary id: (batch_size, 1)
            tok_col = torch.gather(topk_indices, -1, sampled_token)
            # Append the new token to every sequence in the batch.
            tokens = torch.cat((tokens, tok_col), dim=1)
    print("input > ", tokenizer.decode(input_ids.tolist()))
    for i in range(num_sequences):
        _tokens = tokens[i, :max_length].tolist()
        decoded = tokenizer.decode(_tokens)
        print(">", decoded)
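
# A toy walk-through of the top-k sampling step above, with made-up numbers
# (k=2 over a 4-token vocabulary here; the script itself uses k=100):
#   probs                 = [0.1, 0.5, 0.2, 0.2]
#   torch.topk(probs, 2) -> topk_probs   = [0.5, 0.2]
#                           topk_indices = [1, 2]
#   torch.multinomial(topk_probs, 1) samples a position, say 0
#   torch.gather(topk_indices, -1, 0) maps it back to vocabulary id 1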

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint", type=str, help="path to model checkpoint", required=True
    )
    parser.add_argument(
        "--type",
        default="small",
        choices=["small", "large"],
        help="type of model: small or large",
    )
    parser.add_argument(
        "--tokenizer", type=str, help="path to tokenizer .bpe file", required=True
    )
    parser.add_argument("--input", type=str, help="input text", required=True)
    parser.add_argument(
        "--max-length",
        type=int,
        default=100,
        help="total generation length in tokens, including the input tokens (default: 100)",
    )
    args = parser.parse_args()

    input_text = args.input.strip()
    max_length = args.max_length
    checkpoint_path = args.checkpoint
    tokenizer_bpe_file_path = args.tokenizer
    model_type = args.type

    if input_text == "":
        raise ValueError("Input cannot be empty")
    if max_length < 0:
        raise ValueError("max_length cannot be negative")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError("Checkpoint not found")
    if not os.path.isfile(tokenizer_bpe_file_path):
        raise FileNotFoundError("Tokenizer bpe file not found")
    # Select the config for the requested model size from config.json.
    with open("config.json", "r") as f:
        model_configs = json.load(f)[model_type]
    model_configs = KPTConfig(**model_configs)
    model = KPT(model_configs)
    # TODO: add an argument for device
    model.to("cuda")
    model = torch.compile(model)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model"])
    # Disable dropout and other training-only behaviour before sampling.
    model.eval()

    tokenizer = Tokenizer(tokenizer_bpe_file_path)
    input_ids = tokenizer.encode(input_text)
    input_ids = torch.tensor(input_ids, dtype=torch.long, device="cuda")
    run(input_ids, tokenizer, model, max_length)
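
# Example invocation (paths are illustrative; config.json must be in the
# working directory and a CUDA device available):
#
#   python inference.py \
#       --checkpoint checkpoints/model.pt \
#       --tokenizer tokenizer.bpe \
#       --input "Once upon a time" \
#       --max-length 100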