-
Notifications
You must be signed in to change notification settings - Fork 275
/
eval.py
175 lines (146 loc) · 6.29 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Written by Yukang Chen
# Some code based on https://github.com/epfml/landmark-attention
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import torch
import argparse
import random
import numpy as np
from tqdm import tqdm
import transformers
from peft import PeftModel
from llama_attn_replace import replace_llama_attn
def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--batch_size', type=int, default=32, help='batch size during inference')
parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
parser.add_argument('--cache_dir', type=str, default="./cache")
parser.add_argument('--seq_len', type=int, default=2048, help='context length during evaluation')
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
parser.add_argument('--peft_model', type=str, default=None, help='')
parser.add_argument('--flash_attn', type=bool, default=True, help='')
parser.add_argument('--data_path', type=str, default="./test.bin", help='')
args = parser.parse_args()
return args
def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
all_ix = list(range(0, len(data) - seq_length, sliding_window))
all_ix.pop()
for idx in range(0, len(all_ix), batch_size):
ix = all_ix[idx:idx+batch_size]
assert all([idx + seq_length + 1 <= len(data) for idx in ix])
x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
if device != 'cpu':
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
yield x, y
def iceildiv(x, y):
return (x + y - 1) // y
def evaluate(model, data, batch_size, device, seq_length, sliding_window=256, use_cache=False):
stats = {}
model.eval()
loss_list_val, acc_list = [], []
loss_step_list_val = []
with torch.no_grad():
print(f"Using seq length {seq_length}")
torch.set_printoptions(sci_mode=False)
for idx, (x, y) in tqdm(
enumerate(
get_as_batch(
data['val'],
seq_length,
batch_size,
device=device,
sliding_window=sliding_window
)
),
total=iceildiv(
iceildiv(len(data['val']), sliding_window),
batch_size
)
):
val_loss = 0.
acc = 0.
cnt = 0
for part_idx, i in enumerate(range(0, x.shape[1], seq_length)):
part_len = x[:, i:i + seq_length].shape[1]
outputs = model(
input_ids=x[:, i:i + seq_length],
labels=x[:, i:i+seq_length].contiguous(),
use_cache=use_cache)
val_loss = outputs.loss * part_len + val_loss
acc = ((outputs.logits.argmax(-1) == y[:, i:i+seq_length]).float().sum()) + acc
cnt += part_len
while len(loss_step_list_val) <= part_idx:
loss_step_list_val.append([])
loss_step_list_val[part_idx].append(outputs.loss.item())
val_loss /= cnt
acc /= cnt
loss_list_val.append(val_loss.item())
acc_list.append(acc.item())
stats['val_acc'] = torch.as_tensor(acc_list).mean().item()
stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item()
stats['val_perplexity'] = 2.71828 ** stats['val_loss']
stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1))
return stats
def main(args):
device = "cuda:0"
seed = 2
torch.cuda.set_device(device)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
print(f"Num validation tokens: {len(data['val'])}")
print("data path", args.data_path)
print("base model", args.base_model)
print("peft model", args.peft_model)
if args.flash_attn:
replace_llama_attn(use_flash_attn=True, use_full=True)
# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
args.base_model,
cache_dir=args.cache_dir,
)
context_size = args.context_size if args.context_size > 0 else args.seq_len
orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
if orig_ctx_len and context_size > orig_ctx_len:
scaling_factor = float(math.ceil(context_size / orig_ctx_len))
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
args.base_model,
config=config,
cache_dir=args.cache_dir,
torch_dtype=torch.float16,
device_map="auto",
)
model.resize_token_embeddings(32001)
if args.peft_model:
trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
if os.path.isfile(trainable_params):
model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
else:
raise ValueError("Trainable input embedding and normalization are required.")
model = PeftModel.from_pretrained(
model,
args.peft_model,
device_map="auto",
torch_dtype=torch.float16,
)
stats = evaluate(model, data, args.batch_size, device, args.seq_len, sliding_window=256)
print(stats)
if __name__ == "__main__":
args = parse_config()
main(args)