train_supervised.py
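"""Supervised training of a siamese change detection network on SpaceNet7.

Builds the network, optimizer, and loss criterion from the experiment config,
then runs a standard supervised loop with periodic evaluation and wandb logging.
"""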
import sys
import os
import timeit
import torch
from torch import optim
from torch.utils import data as torch_data
from tabulate import tabulate
import wandb
import numpy as np
from pathlib import Path
from utils import networks, datasets, loss_functions, evaluation, experiment_manager


def run_training(cfg):
    # note: 'device' is a module-level global, set in the __main__ block below
    run_config = {
        'CONFIG_NAME': cfg.NAME,
        'device': device,
        'epochs': cfg.TRAINER.EPOCHS,
        'learning rate': cfg.TRAINER.LR,
        'batch size': cfg.TRAINER.BATCH_SIZE,
    }
    table = {
        'run config name': run_config.keys(),
        ' ': run_config.values(),
    }
    print(tabulate(table, headers='keys', tablefmt='fancy_grid'))
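
    # network, AdamW optimizer (weight decay 0.01), and loss criterion are all built from the config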
    net = networks.create_network(cfg)
    net.to(device)
    optimizer = optim.AdamW(net.parameters(), lr=cfg.TRAINER.LR, weight_decay=0.01)
    criterion = loss_functions.get_criterion(cfg.MODEL.LOSS_TYPE)

    # reset the generators
    dataset = datasets.SpaceNet7CDDataset(cfg=cfg, run_type='training')
    print(dataset)
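
    # data loading: single-process in debug mode; drop_last keeps every batch
    # full-sized, and pin_memory speeds up host-to-GPU transfers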
    dataloader_kwargs = {
        'batch_size': cfg.TRAINER.BATCH_SIZE,
        'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER,
        'shuffle': cfg.DATALOADER.SHUFFLE,
        'drop_last': True,
        'pin_memory': True,
    }
    dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs)

    # unpacking cfg
    epochs = cfg.TRAINER.EPOCHS
    save_checkpoints = cfg.SAVE_CHECKPOINTS
    steps_per_epoch = len(dataloader)

    # tracking variables
    global_step = epoch_float = 0
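
    # main loop: one full pass over the dataloader per epoch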
    for epoch in range(1, epochs + 1):
        print(f'Starting epoch {epoch}/{epochs}.')
        start = timeit.default_timer()
        loss_set = []
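
        # supervised step: forward pass on both timestamps, change loss, backprop, update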
        for i, batch in enumerate(dataloader):
            net.train()
            optimizer.zero_grad()

            x_t1 = batch['x_t1'].to(device)
            x_t2 = batch['x_t2'].to(device)
            logits, _, _ = net(x_t1, x_t2)

            gt_change = batch['y_change'].to(device)
            loss = criterion(logits, gt_change)
            loss.backward()
            optimizer.step()

            loss_set.append(loss.item())
            global_step += 1
            epoch_float = global_step / steps_per_epoch
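
            # periodic logging: evaluate on samples of the train/val sets and
            # push the averaged loss to wandb, then reset the running stats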
            if global_step % cfg.LOG_FREQ == 0:
                print(f'Logging step {global_step} (epoch {epoch_float:.2f}).')

                # evaluation on sample of training and validation set
                evaluation.model_evaluation(net, cfg, device, 'training', epoch_float, global_step)
                evaluation.model_evaluation(net, cfg, device, 'validation', epoch_float, global_step)

                # logging
                time = timeit.default_timer() - start
                wandb.log({
                    'loss': np.mean(loss_set),
                    'labeled_percentage': 100,  # fully supervised: all training data is labeled
                    'time': time,
                    'step': global_step,
                    'epoch': epoch_float,
                })
                start = timeit.default_timer()
                loss_set = []
            # end of batch

        # after a full pass, global_step is an exact multiple of steps_per_epoch
        assert epoch == epoch_float
        print(f'epoch float {epoch_float} (step {global_step}) - epoch {epoch}')

        # evaluation at the end of an epoch
        evaluation.model_evaluation(net, cfg, device, 'training', epoch_float, global_step)
        evaluation.model_evaluation(net, cfg, device, 'validation', epoch_float, global_step)
        evaluation.model_evaluation(net, cfg, device, 'test', epoch_float, global_step)

        if epoch in save_checkpoints and not cfg.DEBUG:
            print('saving network', flush=True)
            networks.save_checkpoint(net, optimizer, epoch, global_step, cfg)


if __name__ == '__main__':
    args = experiment_manager.default_argument_parser().parse_known_args()[0]
    cfg = experiment_manager.setup_cfg(args)

    # make training deterministic
    torch.manual_seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('=== Running on device:', device)
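
    # wandb logging is fully disabled for debug runs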
    wandb.init(
        name=cfg.NAME,
        config=cfg,
        project='siamese_ssl',
        tags=['ssl', 'cd', 'siamese', 'spacenet7'],
        mode='online' if not cfg.DEBUG else 'disabled',
    )

    try:
        run_training(cfg)
    except KeyboardInterrupt:
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
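
# Example invocation (the config name and flag below are hypothetical; the
# actual CLI options are defined by experiment_manager.default_argument_parser,
# so check that parser for the real flags):
#   python train_supervised.py -c configs/spacenet7_supervised.yaml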