"""Script for training KWT model"""
import os
from argparse import ArgumentParser

import torch
import wandb
import yaml
from torch import nn

from config_parser import get_config
from utils.dataset import get_loader
from utils.loss import LabelSmoothingLoss
from utils.misc import calc_step, count_params, get_model, log, seed_everything
from utils.opt import get_optimizer
from utils.scheduler import WarmUpLR, get_scheduler
from utils.trainer import evaluate, train


def training_pipeline(config, args):
"""
Initiates and executes all the steps involved with KWT model training
:param config: KWT configuration
"""
config["exp"]["save_dir"] = os.path.join(config["exp"]["exp_dir"], config["exp"]["exp_name"])
os.makedirs(config["exp"]["save_dir"], exist_ok=True)
######################################
# save hyperparameters for current run
######################################
config_str = yaml.dump(config)
print("Using settings:\n", config_str)
with open(os.path.join(config["exp"]["save_dir"], "settings.txt"), "w+") as f:
f.write(config_str)
#####################################
# initialize training items
#####################################
# data
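    # each *_list_file is expected to contain one sample path per line,
    # hence the newline split below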
with open(config["train_list_file"], "r") as f:
train_list = f.read().rstrip().split("\n")
with open(config["val_list_file"], "r") as f:
val_list = f.read().rstrip().split("\n")
with open(config["test_list_file"], "r") as f:
test_list = f.read().rstrip().split("\n")
trainloader = get_loader(train_list, config, train=True)
valloader = get_loader(val_list, config, train=False)
testloader = get_loader(test_list, config, train=False)
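    # train=True presumably toggles shuffling/augmentation inside get_loader;
    # the val/test loaders iterate deterministically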
# model
model = get_model(config["hparams"]["model"])
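    # optionally warm-start / resume from a checkpoint passed via --ckpt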
if args.ckpt:
ckpt = torch.load(args.ckpt, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
print(f"Loaded checkpoint {args.ckpt}.")
model = model.to(config["hparams"]["device"])
print(f"Created model with {count_params(model)} parameters.")
# loss
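    # label smoothing softens the one-hot target: the true class gets
    # 1 - smoothing and the remainder is spread over the other classes
    # (standard formulation; see utils/loss.py for the exact one used here)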
if config["hparams"]["l_smooth"]:
criterion = LabelSmoothingLoss(num_classes=config["hparams"]["model"]["num_classes"], smoothing=config["hparams"]["l_smooth"])
else:
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = get_optimizer(model, config["hparams"]["optimizer"])
# lr scheduler
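    # schedules are expressed in iterations (epochs * len(trainloader)):
    # warmup covers n_warmup epochs, the main scheduler the remaining
    # max_epochs - n_warmup epochs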
schedulers = {
"warmup": None,
"scheduler": None
}
if config["hparams"]["scheduler"]["n_warmup"]:
schedulers["warmup"] = WarmUpLR(optimizer, total_iters=len(trainloader) * config["hparams"]["scheduler"]["n_warmup"])
if config["hparams"]["scheduler"]["scheduler_type"] is not None:
total_iters = len(trainloader) * max(1, (config["hparams"]["scheduler"]["max_epochs"] - config["hparams"]["scheduler"]["n_warmup"]))
schedulers["scheduler"] = get_scheduler(optimizer, config["hparams"]["scheduler"]["scheduler_type"], total_iters)
#####################################
# Training Run
#####################################
print("Initiating training.")
train(model, optimizer, criterion, trainloader, valloader, schedulers, config)
#####################################
# Final Test
#####################################
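    # calc_step (see utils/misc.py) is assumed to map (epoch, steps_per_epoch,
    # batch_index) to a global step; here it yields the step of the last
    # training batch, so test metrics align with the training curves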
final_step = calc_step(config["hparams"]["n_epochs"] + 1, len(trainloader), len(trainloader) - 1)
# evaluating the final state (last.pth)
test_acc, test_loss = evaluate(model, criterion, testloader, config["hparams"]["device"])
log_dict = {
"test_loss_last": test_loss,
"test_acc_last": test_acc
}
log(log_dict, final_step, config)
# evaluating the best validation state (best.pth)
    ckpt = torch.load(os.path.join(config["exp"]["save_dir"], "best.pth"), map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
print("Best ckpt loaded.")
test_acc, test_loss = evaluate(model, criterion, testloader, config["hparams"]["device"])
log_dict = {
"test_loss_best": test_loss,
"test_acc_best": test_acc
}
    log(log_dict, final_step, config)


def main(args):
"""
Calls training pipeline and sets up wandb logging if used
:param args: input arguments
"""
config = get_config(args.conf)
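    # seed_everything is assumed to seed the python, numpy and torch RNGs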
seed_everything(config["hparams"]["seed"])
if args.id:
config["exp"]["exp_name"] = config["exp"]["exp_name"] + args.id
if config["exp"]["wandb"]:
if config["exp"]["wandb_api_key"] is not None:
with open(config["exp"]["wandb_api_key"], "r") as f:
os.environ["WANDB_API_KEY"] = f.read()
elif os.environ.get("WANDB_API_KEY", False):
print("Found API key from env variable.")
else:
wandb.login()
        with wandb.init(project=config["exp"]["proj_name"], name=config["exp"]["exp_name"], config=config["hparams"]):
            training_pipeline(config, args)
    else:
        training_pipeline(config, args)


if __name__ == "__main__":
    parser = ArgumentParser(description="Driver code for KWT training.")
parser.add_argument("--conf", type=str, required=True, help="Path to config.yaml file.")
parser.add_argument("--ckpt", type=str, required=False, help="Path to checkpoint file.", default=None)
parser.add_argument("--id", type=str, required=False, help="Obtional experiment identifier.", default=None)
args = parser.parse_args()
main(args)
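# Example invocation (paths are illustrative):
#   python train.py --conf configs/kwt1.yaml --ckpt runs/kwt1/last.pth --id run1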