This repository has been archived by the owner on Aug 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
390 lines (333 loc) · 16.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
#encoding: utf-8

import torch
from random import shuffle
#from torch import nn
from torch.cuda.amp import GradScaler
from torch.optim import Adam as Optimizer

from loss.base import LabelSmoothingLoss
from lrsch import GoogleLR as LRScheduler
from parallel.base import DataParallelCriterion
from parallel.optm import MultiGPUGradScaler
from parallel.parallelMT import DataParallelMT
from transformer.NMT import NMT
from utils.base import free_cache, get_logger, mkdir, set_random_seed
from utils.contpara import get_model_parameters
from utils.fmt.base import iter_to_str
from utils.fmt.base4torch import load_emb, parse_cuda
from utils.h5serial import h5File
from utils.init.base import init_model_params
from utils.io import load_model_cpu, save_model, save_states
from utils.state.holder import Holder
from utils.state.pyrand import PyRandomState
from utils.state.thrand import THRandomState
from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode
from utils.tqdm import tqdm
from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam
from utils.train.dss import dynamic_sample

import cnfg.base as cnfg
from cnfg.ihyp import *
from cnfg.vocab.base import pad_id
def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
    """Make one pass over the batch keys in ``tl`` (a full epoch, or the rest of one).

    Each key indexes a pre-batched pair in ``td["src"]`` / ``td["tgt"]``. The target
    sequence is split into decoder input (all but the last position) and gold output
    (all but the first position). The summed loss of every batch is back-propagated
    immediately. An optimizer step (followed by ``lrsch.step()``) is taken only once at
    least ``tokens_optm`` non-padding target tokens have been accumulated, so gradients
    accumulate across batches.

    Args:
        td: opened training file with "src" and "tgt" groups.
        tl: list of batch keys to visit, in order.
        ed, nd: dev-set file and number of dev batches, forwarded to ``eva`` for
            periodic evaluation.
        optm, lrsch: optimizer and learning-rate scheduler.
        model, lossf: model and summed loss function (possibly data-parallel wrappers).
        mv_device: device batches are moved to; falsy keeps them where they were loaded.
        logger: logger for progress and save messages.
        done_tokens: target tokens accumulated (but not yet stepped) before this call.
        multi_gpu, multi_gpu_optimizer: flags forwarded to ``optm_step`` / ``save_model``.
        tokens_optm: accumulated-token threshold that triggers an optimizer step.
        nreport: log (and optionally evaluate) every ``nreport`` batches; None disables.
        save_every: checkpoint interval. It counts remaining optimizer steps when
            ``remain_steps`` is set, and batches otherwise.
        chkpf: checkpoint path. Its last 3 characters (e.g. ".h5") are replaced by
            "_<id>.h5" when ``num_checkpoint > 1``.
        state_holder, statesf: training-state holder and the file it is saved to
            (``statesf`` None disables state saving).
        num_checkpoint, cur_checkid: number of rotating checkpoints and the next id to use.
        report_eva: whether periodic reports also run ``eva`` and save new best models.
        remain_steps: remaining optimizer-step budget; None means unbounded.
        save_loss: collect the per-batch loss per target token, keyed by batch key.
        save_checkp_epoch: enable the periodic checkpoint saves in this call.
        scaler: gradient scaler for mixed precision; None disables autocast.

    Returns:
        ``(average loss per target token, pending token count, next checkpoint id,
        remaining steps, per-batch loss dict or None)``.

    Side effects: updates the module globals ``minerr``, ``minloss`` and ``namin``, and
    reads ``wkdir``, ``save_auto_clean`` and ``ntrain``.
    """
    sum_loss = part_loss = 0.0
    sum_wd = part_wd = 0
    _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None
    # best dev error/loss and the no-improvement counter are shared with the epoch loop at module level
    global minerr, minloss, wkdir, save_auto_clean, namin
    model.train()
    # cur_b: 1-based index of the current batch; _ls: per-batch losses, only when save_loss
    cur_b, _ls = 1, {} if save_loss else None
    src_grp, tgt_grp = td["src"], td["tgt"]
    for i_d in tqdm(tl, mininterval=tqdm_mininterval):
        seq_batch = torch.from_numpy(src_grp[i_d][()])
        seq_o = torch.from_numpy(tgt_grp[i_d][()])
        lo = seq_o.size(1) - 1
        if mv_device:
            seq_batch = seq_batch.to(mv_device, non_blocking=True)
            seq_o = seq_o.to(mv_device, non_blocking=True)
        seq_batch, seq_o = seq_batch.long(), seq_o.long()
        # decoder input: target without its last position
        oi = seq_o.narrow(1, 0, lo)
        # gold output: target shifted left by one position
        ot = seq_o.narrow(1, 1, lo).contiguous()
        with torch_autocast(enabled=_use_amp):
            output = model(seq_batch, oi)
            loss = lossf(output, ot)
            if multi_gpu:
                # the data-parallel criterion yields one loss per device; reduce them
                loss = loss.sum()
        loss_add = loss.data.item()
        # Scaling the summed loss down by the number of tokens was advised by
        # https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw; the author considers that
        # unreasonable, so the summed loss is back-propagated unscaled.
        #loss /= wd_add
        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()
        # number of non-padding target tokens in this batch
        wd_add = ot.ne(pad_id).int().sum().item()
        # drop references so the batch's tensors can be freed before the next one
        loss = output = oi = ot = seq_batch = seq_o = None
        sum_loss += loss_add
        if save_loss:
            _ls[i_d] = loss_add / wd_add
        sum_wd += wd_add
        _done_tokens += wd_add
        # gradient accumulation: step only once enough target tokens have been seen
        if _done_tokens >= tokens_optm:
            optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none)
            _done_tokens = 0
            if _cur_rstep is not None:
                # step-budget mode: checkpoint every save_every remaining steps
                if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
                    if num_checkpoint > 1:
                        # rotate among num_checkpoint files: <chkpf without ext>_<id>.h5
                        _fend = "_%d.h5" % (_cur_checkid)
                        _chkpf = chkpf[:-3] + _fend
                        _cur_checkid = (_cur_checkid + 1) % num_checkpoint
                    else:
                        _chkpf = chkpf
                    save_model(model, _chkpf, multi_gpu, print_func=logger.info)
                    if statesf is not None:
                        # NOTE(review): tl[cur_b - 1:] still contains the batch just processed,
                        # so a resumed run trains on it again; confirm this is intended.
                        save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
                _cur_rstep -= 1
                if _cur_rstep <= 0:
                    # budget exhausted: stop before advancing the scheduler
                    break
            # the scheduler advances once per optimizer step
            lrsch.step()
        if nreport is not None:
            part_loss += loss_add
            part_wd += wd_add
            if cur_b % nreport == 0:
                if report_eva:
                    _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp)
                    logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,))
                    if (_eeva < minerr) or (_leva < minloss):
                        save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None)
                        if statesf is not None:
                            save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
                        logger.info("New best model saved")
                        # an improvement resets the epoch loop's early-stop counter
                        namin = 0
                        if _eeva < minerr:
                            minerr = _eeva
                        if _leva < minloss:
                            minloss = _leva
                    free_cache(mv_device)
                    # eva() switched the model to eval mode; switch back before training on
                    model.train()
                else:
                    logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,))
                part_loss = 0.0
                part_wd = 0
        # epoch mode (no step budget): checkpoint every save_every batches, except after the last one
        if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain):
            if num_checkpoint > 1:
                _fend = "_%d.h5" % (_cur_checkid)
                _chkpf = chkpf[:-3] + _fend
                _cur_checkid = (_cur_checkid + 1) % num_checkpoint
            else:
                _chkpf = chkpf
            #save_model(model, _chkpf, isinstance(model, nn.DataParallel), print_func=logger.info)
            save_model(model, _chkpf, multi_gpu, print_func=logger.info)
            if statesf is not None:
                save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
        cur_b += 1
    # report the tail that did not fill a whole nreport window
    if part_wd != 0.0:
        logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,))
    return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls
def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):
    """Score ``model`` on the development batches ``"0"`` .. ``str(nd - 1)`` of ``ed``.

    The model is switched to eval mode (and left there) and run under
    ``torch_inference_mode``. For every batch, the target sequence is split into
    decoder input (all but the last position) and gold output (all but the first
    position). The loss returned by ``lossf`` is accumulated, and argmax predictions
    are compared with the gold tokens; ``pad_id`` positions are ignored.

    Args:
        ed: opened dev-set file exposing "src" and "tgt" groups keyed by batch index string.
        nd: number of dev batches to read.
        model, lossf: model and loss (possibly data-parallel wrappers).
        mv_device: device batches are moved to; falsy keeps them on CPU.
        multi_gpu: when true, the loss is reduced with ``.sum()``, and ``output`` is
            iterated as per-device pieces whose predictions are moved to ``mv_device``
            and concatenated along dim 0.
        use_amp: run the forward pass under autocast.

    Returns:
        ``(accumulated loss / non-pad target tokens, token error rate in percent)``.

    Raises:
        ZeroDivisionError: if the dev batches contain no non-pad target tokens.
    """
    total_tokens = 0
    right_tokens = 0
    total_loss = 0.0
    model.eval()
    src_data = ed["src"]
    tgt_data = ed["tgt"]
    with torch_inference_mode():
        for batch_no in tqdm(range(nd), mininterval=tqdm_mininterval):
            key = str(batch_no)
            src = torch.from_numpy(src_data[key][()])
            tgt = torch.from_numpy(tgt_data[key][()])
            steps = tgt.size(1) - 1
            if mv_device:
                src = src.to(mv_device, non_blocking=True)
                tgt = tgt.to(mv_device, non_blocking=True)
            src = src.long()
            tgt = tgt.long()
            # gold output is the target shifted left by one position
            gold = tgt.narrow(1, 1, steps).contiguous()
            with torch_autocast(enabled=use_amp):
                output = model(src, tgt.narrow(1, 0, steps))
                loss = lossf(output, gold)
                if not multi_gpu:
                    pred = output.argmax(-1)
                else:
                    loss = loss.sum()
                    # collect per-device predictions on the main device
                    pred = torch.cat([shard.argmax(-1).to(mv_device, non_blocking=True) for shard in output], 0)
            total_loss += loss.data.item()
            keep = gold.ne(pad_id)
            total_tokens += keep.int().sum().item()
            right_tokens += (pred.eq(gold) & keep).int().sum().item()
            # release references before the next batch is loaded
            keep = pred = loss = output = gold = src = tgt = None
    total_tokens = float(total_tokens)
    return total_loss / total_tokens, (total_tokens - right_tokens) / total_tokens * 100.0
def hook_lr_update(optm, flags=None):
    """Delegate to ``reset_Adam(optm, flags)`` after a learning-rate change.

    ``flags`` is forwarded positionally and unchanged. In the visible source the
    only caller is commented-out code in the epoch loop, which passes ``use_ams``.
    The result of ``reset_Adam`` is discarded; this function returns None.
    """
    reset_Adam(optm, flags)
def init_fixing(module):
    """Call ``module.fix_init()`` if the module defines it; otherwise do nothing.

    Intended for ``nn.Module.apply`` so every submodule gets a post-initialization hook.
    """
    if not hasattr(module, "fix_init"):
        return
    module.fix_init()
def load_fixing(module):
    """Call ``module.fix_load()`` if the module defines it; otherwise do nothing.

    Intended for ``nn.Module.apply`` after loading pre-trained weights.
    """
    if not hasattr(module, "fix_load"):
        return
    module.fix_load()
# ---- run configuration, read once from cnfg.base ----
rid = cnfg.run_id
earlystop = cnfg.earlystop
maxrun = cnfg.maxrun
tokens_optm = cnfg.tokens_optm
# target tokens accumulated since the last optimizer step; carried across train() calls
done_tokens = 0
batch_report = cnfg.batch_report
report_eva = cnfg.report_eva
use_ams = cnfg.use_ams
# path of a saved training-state file to resume from (None: no resume)
cnt_states = cnfg.train_statesf
save_auto_clean = cnfg.save_auto_clean
overwrite_eva = cnfg.overwrite_eva
save_every = cnfg.save_every
start_chkp_save = cnfg.epoch_start_checkpoint_save
epoch_save = cnfg.epoch_save
# optimizer-step budget; None means only maxrun (epochs) bounds training
remain_steps = cnfg.training_steps
# working directory <exp_dir><data_id>/<group_id>/<run_id>/ for models, checkpoints, states and the log
wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/"))
mkdir(wkdir)
# rolling checkpoint file (only with save_every) and training-state file (only with save_train_state)
chkpf = None
statesf = None
if save_every is not None:
    chkpf = wkdir + "checkpoint.h5"
if cnfg.save_train_state:
    statesf = wkdir + "train.states.t7"
logger = get_logger(wkdir + "train.log")
use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid)
# the multi-GPU optimizer option only takes effect when parse_cuda reports multi-GPU training
multi_gpu_optimizer = multi_gpu and cnfg.multi_gpu_optimizer
set_random_seed(cnfg.seed, use_cuda)
# pre-batched training / development sets; "ndata" is the number of stored batches
td = h5File(cnfg.train_data, "r")
vd = h5File(cnfg.dev_data, "r")
ntrain = td["ndata"][()].item()
nvalid = vd["ndata"][()].item()
# "nword" holds the sizes used below as source (first) and target (last) vocabulary sizes
nword = td["nword"][()].tolist()
nwordi, nwordt = nword[0], nword[-1]
# batch keys are the stringified indices "0" .. str(ntrain - 1)
tl = [str(i) for i in range(ntrain)]
# ---- model, loss, device placement, optimizer and scheduler ----
logger.info("Design models with seed: %d" % torch.initial_seed())
mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
fine_tune_m = cnfg.fine_tune_m
mymodel = init_model_params(mymodel)
# let each submodule that defines fix_init() run its post-initialization hook
mymodel.apply(init_fixing)
if fine_tune_m is not None:
    logger.info("Load pre-trained model from: " + fine_tune_m)
    mymodel = load_model_cpu(fine_tune_m, mymodel)
    mymodel.apply(load_fixing)
#lossf = NLLLoss(ignore_index=pad_id, reduction="sum")
# summed (not averaged) loss: train() and eva() divide by token counts themselves
lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction="sum", forbidden_index=cnfg.forbidden_indexes)
if cnfg.src_emb is not None:
    logger.info("Load source embedding from: " + cnfg.src_emb)
    load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb)
if cnfg.tgt_emb is not None:
    logger.info("Load target embedding from: " + cnfg.tgt_emb)
    load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb)
if cuda_device:
    mymodel.to(cuda_device, non_blocking=True)
    lossf.to(cuda_device, non_blocking=True)
# mixed precision only on CUDA; the multi-GPU optimizer uses its own grad scaler, GradScaler otherwise
use_amp = cnfg.use_amp and use_cuda
scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None
if multi_gpu:
    #mymodel = nn.DataParallel(mymodel, device_ids=cuda_devices, output_device=cuda_device.index)
    mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False)
    lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True)
if multi_gpu:
    optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters)
else:
    # lr will be overwritten by LRScheduler before it is used
    optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams)
optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none)
# lrsch.step() will be called automatically by the constructor
lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs)
lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs)
# optimizer / scheduler / RNG states, only built when training state is saved or resumed
state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)})
# ---- baseline evaluation, optional resume ("epoch 0"), dynamic sampling setup ----
num_checkpoint = cnfg.num_checkpoint
cur_checkid = 0
# best (lowest) average training loss seen so far
tminerr = inf_default
# dev loss / error of the starting model: the bar a "new best model" must beat
minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,))
if fine_tune_m is None:
    save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info)
    logger.info("Initial model saved")
else:
    # resuming requires the saved model (as fine_tune_m) plus the saved training state
    if cnt_states is not None:
        logger.info("Loading training states")
        # state_holder is not None here because cnt_states is set.
        # NOTE(review): torch.load unpickles the file; only resume from state files you trust.
        _remain_states = state_holder.load_state_dict(torch.load(cnt_states))
        remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"]
        # finish the interrupted epoch on its remaining batches if they were saved
        if "training_list" in _remain_states:
            _ctl = _remain_states["training_list"]
        else:
            shuffle(tl)
            _ctl = tl
        tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, _ctl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
        _ctl = _remain_states = None
        vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
        logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,))
        save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
        if statesf is not None:
            save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
        logger.info("New best model saved")
# dynamic sampling is enabled only for a fraction strictly between 0 and 1:
# dss_ws = number of batches to sample per epoch, dss_rm = number derived from the dss_rm fraction
# (presumably batches to drop) -- both are passed to dynamic_sample in the epoch loop
if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
    dss_ws = int(cnfg.dss_ws * ntrain)
    _Dws = {}
    _prev_Dws = {}
    _crit_inc = {}
    if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0:
        dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws))
    else:
        dss_rm = 0
else:
    dss_ws = 0
    dss_rm = 0
    _Dws = None
# ---- epoch loop ----
# consecutive epochs without a dev improvement (train() also resets it on a mid-epoch improvement)
namin = 0
for i in range(1, maxrun + 1):
    shuffle(tl)
    free_cache(use_cuda)
    # per-batch losses are collected only with dynamic sampling; periodic checkpoints start at epoch start_chkp_save
    terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
    vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
    logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,))
    # ties count as improvements here (the mid-epoch check in train() uses strict "<")
    if (vprec <= minerr) or (vloss <= minloss):
        save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None)
        if statesf is not None:
            save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
        logger.info("New best model saved")
        namin = 0
        if vprec < minerr:
            minerr = vprec
        if vloss < minloss:
            minloss = vloss
    else:
        # no dev improvement: keep the model if training loss improved, else optionally snapshot the epoch
        if terr < tminerr:
            tminerr = terr
            save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
            if statesf is not None:
                save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
        elif epoch_save:
            save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info)
            if statesf is not None:
                save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
        namin += 1
        if namin >= earlystop:
            # apply gradients accumulated since the last step before stopping
            if done_tokens > 0:
                optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
                lrsch.step()
                done_tokens = 0
            logger.info("early stop")
            break
    if remain_steps is not None and remain_steps <= 0:
        logger.info("Last training step reached")
        break
    if dss_ws > 0:
        # _Dws maps batch key -> loss per token for this epoch; _crit_inc holds the relative
        # loss decrease since the previous epoch, which dynamic_sample turns into the next batch list
        if _prev_Dws:
            for _key, _value in _Dws.items():
                if _key in _prev_Dws:
                    _ploss = _prev_Dws[_key]
                    _crit_inc[_key] = (_ploss - _value) / _ploss
            tl = dynamic_sample(_crit_inc, dss_ws, dss_rm)
        _prev_Dws = _Dws
    #oldlr = getlr(optimizer)
    #lrsch.step(terr)
    #newlr = getlr(optimizer)
    #if updated_lr(oldlr, newlr):
    #logger.info("".join(("lr update from: ", ",".join(iter_to_str(oldlr)), ", to: ", ",".join(iter_to_str(newlr)))))
    #hook_lr_update(optimizer, use_ams)
# apply any gradients still accumulated when the loop ended
if done_tokens > 0:
    optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
    lrsch.step()
    #done_tokens = 0
save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info)
if statesf is not None:
    save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
logger.info("model saved")
td.close()
vd.close()