forked from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_pipeline.py
121 lines (100 loc) · 3.96 KB
/
training_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from typing import List, Optional
import hydra
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src import utils
# Module-level logger used by the pipeline below; obtained from the project's
# `utils.get_logger` helper, passing this module's `__name__`.
log = utils.get_logger(__name__)
def _instantiate_targets(section, kind: str) -> List:
    """Instantiate every entry of a config section that declares a ``_target_``.

    Args:
        section: Mapping of entry name to entry config (e.g. ``config.callbacks``).
            May be ``None`` or empty, in which case nothing is instantiated.
        kind (str): Label inserted in the log message ("callback" or "logger").

    Returns:
        List: Objects created by ``hydra.utils.instantiate``, in config order.
    """
    instances: List = []

    # A section set to null (or left empty) simply yields no objects
    if not section:
        return instances

    for _, entry_conf in section.items():
        # Null or scalar entries cannot carry `_target_`, so they are skipped
        if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
            log.info(f"Instantiating {kind} <{entry_conf._target_}>")
            instances.append(hydra.utils.instantiate(entry_conf))

    return instances


def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline. Can additionally evaluate model on a testset, using best
    weights achieved during training.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization, i.e. the value stored in
        ``trainer.callback_metrics`` under ``config.optimized_metric`` (``None`` if absent).

    Raises:
        Exception: If ``optimized_metric`` is set in the config but is not present in
            ``trainer.callback_metrics``.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # Convert relative ckpt path to absolute path if necessary
    ckpt_path = config.trainer.get("resume_from_checkpoint")
    if ckpt_path and not os.path.isabs(ckpt_path):
        config.trainer.resume_from_checkpoint = os.path.join(
            hydra.utils.get_original_cwd(), ckpt_path
        )

    # Init lightning datamodule
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    # Init lightning model
    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    # Init lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        callbacks = _instantiate_targets(config.callbacks, "callback")

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        logger = _instantiate_targets(config.logger, "logger")

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Send some parameters from config to all lightning loggers
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Train the model
    if config.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)

    # Get metric score for hyperparameter optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric and optimized_metric not in trainer.callback_metrics:
        raise Exception(
            "Metric for hyperparameter optimization not found! "
            "Make sure the `optimized_metric` in `hparams_search` config is correct!"
        )
    score = trainer.callback_metrics.get(optimized_metric)

    # Test the model
    if config.get("test"):
        # Use the best checkpoint unless nothing was trained or this is a fast dev run,
        # in which case the model is tested with its current weights
        ckpt_path = "best"
        if not config.get("train") or config.trainer.get("fast_dev_run"):
            ckpt_path = None
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Print path to best checkpoint
    if not config.trainer.get("fast_dev_run") and config.get("train"):
        # `trainer.checkpoint_callback` may be None when no checkpoint callback is
        # configured; skip the message instead of crashing after a successful run
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is not None:
            log.info(f"Best model ckpt at {checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    return score