-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
47 lines (37 loc) · 1.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import argparse
import torch
from torch.utils.data import DataLoader
from datf.models import build_model
from datf.datasets import build_dataset
from configs.config_args import parse_train_configs
from datf.utils.trainer import ModelTrainer
from datf.optimizers import build_optimizer
from datf.losses import build_criterion
def train(args):
    """Build the model, datasets, and optimizer from the parsed config and run training.

    Args:
        args: configuration namespace from ``parse_train_configs``. Reads
            ``device``, ``batch_size``, ``num_workers``, ``exp_path`` and
            ``num_epochs``; optionally ``ckpt`` (checkpoint file path) and
            ``ploss_criterion`` (presence toggles building an auxiliary loss).
    """
    device = args.device

    model = build_model(cfg=args)
    # Optionally restore parameters from a checkpoint file.
    if hasattr(args, "ckpt") and len(args.ckpt):
        model.load_params_from_file(filename=args.ckpt)
        print("[LOG] Loaded checkpoint")

    # Send model to device. nn.Module.to() moves parameters in place for
    # module instances, so rebinding the loop variable is harmless here.
    if isinstance(model, list):
        for m in model:
            m = m.to(device)
    # NOTE(review): a non-list model is never moved to `device` here —
    # presumably the trainer or model wrapper handles it; confirm, otherwise
    # add `model = model.to(device)`.

    # Auxiliary loss criterion, built only when the config asks for it.
    ploss_criterion = None
    if hasattr(args, "ploss_criterion"):
        ploss_criterion = build_criterion(cfg=args)

    train_dataset, val_dataset, collate_fn = build_dataset(cfg=args)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True,
                              collate_fn=collate_fn, num_workers=args.num_workers)
    valid_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True,
                              collate_fn=collate_fn, num_workers=1)
    print(f'Train Examples: {len(train_dataset)} | Valid Examples: {len(val_dataset) if val_dataset else "None" }')

    # The optimizer is built over the wrapped module (`model.model`), not the
    # wrapper itself.
    optimizer_list = build_optimizer(cfg=args, model=model.model)

    trainer = ModelTrainer(model, train_loader, valid_loader, optimizer_list,
                           exp_path=args.exp_path, cfg=args, device=device,
                           ploss_criterion=ploss_criterion)
    trainer.train(args.num_epochs)
if __name__ == "__main__":
    # Parse CLI/config options into a namespace; kept as a module-level
    # global because train() also reads `cfg` directly.
    cfg = parse_train_configs()
    train(cfg)