forked from victoresque/pytorch-template

train.py

import numpy as np
import torch
import torch.distributed as dist
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pathlib import Path
from srcs.trainer import Trainer
from srcs.utils import get_logger

# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
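# Note: cudnn.deterministic=True together with benchmark=False trades some speed
# for run-to-run reproducibility of convolution results on the same hardware.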


def train_worker(config):
    logger = get_logger('train')

    # setup data_loader instances
    data_loader, valid_data_loader = instantiate(config.data_loader)

    # build model. print its structure and the number of trainable parameters.
    model = instantiate(config.arch)
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    logger.info(model)
    logger.info(f'Trainable parameters: {sum(p.numel() for p in trainable_params)}')
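
    # hydra.utils.instantiate builds each object from a config node whose
    # `_target_` key names the class or factory to call, e.g. (hypothetical
    # values; the real ones live under conf/):
    #   arch:
    #     _target_: srcs.model.SomeModel
    #     num_classes: 10
    # instantiates srcs.model.SomeModel(num_classes=10). The data_loader node
    # above presumably targets a factory returning (train_loader, valid_loader).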

    # get function handles of loss and metrics
    criterion = instantiate(config.loss)
    metrics = {
        met_name: instantiate(met)
        for met_name, met in config.metrics.items()
    }

    # build optimizer, learning rate scheduler.
    optimizer = instantiate(config.optimizer, model.parameters())
    lr_scheduler = instantiate(config.lr_scheduler, optimizer)
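    # extra positional args to instantiate() (model.parameters(), optimizer) are
    # forwarded to the target constructor, so the optimizer/scheduler config nodes
    # only need to supply the remaining keyword arguments.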

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)
    trainer.train()


def init_worker(rank, ngpus, working_dir, config):
    # initialize training config
    config = OmegaConf.create(config)
    config.local_rank = rank
    config.cwd = working_dir
    # prevent access to non-existing keys
    OmegaConf.set_struct(config, True)
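
    # one spawned process per GPU joins an NCCL process group; the rendezvous
    # address is hard-coded to localhost, so this setup assumes single-node training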
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:34567',
        world_size=ngpus,
        rank=rank)
    torch.cuda.set_device(rank)

    # start training in this worker process
    train_worker(config)


@hydra.main(config_path='conf/', config_name='train', version_base='1.1')
def main(config):
    n_gpu = torch.cuda.device_count()
    assert n_gpu, 'Can\'t find any GPU device on this machine.'

    working_dir = str(Path.cwd().relative_to(hydra.utils.get_original_cwd()))
    if config.resume is not None:
        config.resume = hydra.utils.to_absolute_path(config.resume)
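
    # serialize the config to a plain YAML string so it can be passed through
    # torch.multiprocessing.spawn and rebuilt with OmegaConf.create in each worker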
    config = OmegaConf.to_yaml(config, resolve=True)
    torch.multiprocessing.spawn(init_worker, nprocs=n_gpu, args=(n_gpu, working_dir, config))


if __name__ == '__main__':
    # pylint: disable=no-value-for-parameter
    main()
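
# Example launch, assuming the Hydra config tree under conf/train.yaml defines the
# keys referenced above (arch, data_loader, loss, metrics, optimizer, lr_scheduler,
# resume); the override below is illustrative, not guaranteed to match your config:
#   python train.py resume=/path/to/checkpoint.pth
# Hydra (version_base 1.1) changes the working directory per run, which is why
# working_dir is recorded relative to the original cwd and resume is made absolute.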