runner.py
import logging
import os
import random
from abc import ABC
from argparse import Namespace
from pathlib import Path
from shutil import copy2
from typing import Union

import comet_ml  # noqa
import hydra
import numpy as np
import torch
from dotenv import load_dotenv
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CometLogger, Logger
from vital.data.data_module import VitalDataModule
from vital.system import VitalSystem
from vital.utils.config import (
    instantiate_config_node_leaves,
    instantiate_results_processor,
    register_omegaconf_resolvers,
)
from vital.utils.saving import resolve_model_checkpoint_path

logger = logging.getLogger(__name__)
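
# Sketch of the top-level config fields this runner reads (illustrative only; the authoritative defaults live in the
# Hydra config tree rooted at `config/vital_default.yaml`, and the values below are placeholders):
#
#   seed: null                      # random seed; auto-derived per job in Hydra multirun mode when left unset
#   float32_matmul_precision: high  # forwarded to `torch.set_float32_matmul_precision`
#   ckpt: null                      # optional checkpoint, resolved by `resolve_model_checkpoint_path`
#   weights_only: false             # if a checkpoint is given, restore only its `state_dict`
#   strict: true                    # strictness of checkpoint loading
#   resume: false                   # resume `trainer.fit` from the checkpoint instead of just loading it
#   train: true                     # run `trainer.fit`
#   test: true                      # run `trainer.test`
#   predict: false                  # run `trainer.predict`
#   trainer: {...}                  # kwargs for the Lightning `Trainer`
#   data: {...}                     # datamodule config; may define `postprocessing` and `predict` nodes
#   task: {...}                     # system config (handles instantiating the model and optimizer)
#   logger: {...}                   # optional custom logger config (requires a `_target_`)
#   callbacks: {...}                # optional callback configs
#   results_processors: {...}       # optional results processor configs
#   comet_tags: [...]               # optional tags added to the Comet experiment
#   best_model_save_path: null      # optional explicit path where to copy the best checkpoint
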
class VitalRunner(ABC):
    """Abstract runner that runs the main training/val loop, etc. using the Lightning `Trainer`."""

    @classmethod
    def main(cls) -> None:
        """Runs the requested experiment."""
        # Set up the environment
        cls.pre_run_routine()

        # Run the system with config loaded by @hydra.main
        cls.run_system()

    @classmethod
    def pre_run_routine(cls) -> None:
        """Sets up the environment before running the training/testing."""
        # Load environment variables from `.env` file if it exists
        # Load before hydra main to allow for setting environment variables with ${oc.env:ENV_NAME}
        load_dotenv()

        register_omegaconf_resolvers()
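
    # For example (hypothetical names), a `.env` file containing `COMET_API_KEY=...` lets any config file reference
    # the value as `${oc.env:COMET_API_KEY}` without hard-coding secrets.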

    @staticmethod
    @hydra.main(version_base=None, config_path="config", config_name="vital_default")
    def run_system(cfg: DictConfig) -> None:
        """Handles the training and evaluation of a model.

        Note: Must be static because of the `hydra.main` decorator and config pass-through.

        Args:
            cfg: Configuration to run the experiment.
        """
        cfg = VitalRunner._check_cfg(cfg)

        # Global torch config making it possible to use performant matrix multiplications on Ampere and newer CUDA
        # GPUs. The `vital` default value is `high`, the middle ground between performance and precision. For more
        # details:
        # https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        torch.set_float32_matmul_precision(cfg.float32_matmul_precision)
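
        # `torch.set_float32_matmul_precision` accepts "highest", "high" or "medium"; exposing it through the config
        # lets individual runs override the `vital` default from the command line.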

        if cfg.ckpt:
            ckpt_path = resolve_model_checkpoint_path(cfg.ckpt)

        if (job_num := HydraConfig.get().job.get("num")) is not None and cfg.seed is None:
            # If Hydra is in multirun mode and no seed is specified by the user (meaning we don't want to reproduce a
            # previous experiment and only care about "true" randomness), use the job number as part of the seed to
            # make sure to get a different seed for each job. This is a patch since (for some reason I could not
            # figure out) some jobs get seeded with the same seed.
            # The seed is generated under the conditions that:
            # i) it is different for each trial (make sure that even if the same initial seed is returned by
            #    `randint`, the seed will be different for each job)
            # ii) it is within the range of values accepted by numpy [np.iinfo(np.uint32).min, np.iinfo(np.uint32).max]
            s_min, s_max = np.iinfo(np.uint32).min, np.iinfo(np.uint32).max
            seed = random.randint(s_min, s_max)
            cfg.seed = ((seed + job_num) % (s_max - s_min)) + s_min

        cfg.seed = seed_everything(cfg.seed, workers=True)
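
        # Worked example with illustrative numbers: for `s_min=0`, `s_max=4294967295` (the `np.uint32` range), an
        # initial `seed=4294967290` and `job_num=10`, the expression yields `((4294967290 + 10) % 4294967295) + 0 = 5`,
        # so the per-job seed wraps back into the valid range while staying distinct from the other jobs' seeds.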

        experiment_logger = VitalRunner.configure_logger(cfg)

        # Instantiate post-processing objects
        postprocessors = []
        if isinstance(postprocessing_node := cfg.data.get("postprocessing"), DictConfig):
            postprocessors = instantiate_config_node_leaves(postprocessing_node, "post-processing")

        # Instantiate the different types of callbacks from the configs
        callbacks = []
        if isinstance(callbacks_node := cfg.get("callbacks"), DictConfig):
            callbacks.extend(instantiate_config_node_leaves(callbacks_node, "callback"))
        if isinstance(results_processors_node := cfg.get("results_processors"), DictConfig):
            callbacks.extend(
                instantiate_config_node_leaves(
                    results_processors_node, "results processor", instantiate_fn=instantiate_results_processor
                )
            )
        if isinstance(predict_node := cfg.data.get("predict"), DictConfig):
            logger.info("Instantiating prediction writer")
            prediction_writer_kwargs = {}
            if postprocessors:
                prediction_writer_kwargs["postprocessors"] = postprocessors
            callbacks.append(hydra.utils.instantiate(predict_node, **prediction_writer_kwargs))
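
        # Illustrative `callbacks` node (hypothetical entries; any leaf with a Hydra-resolvable `_target_` works):
        #   callbacks:
        #     model_checkpoint:
        #       _target_: pytorch_lightning.callbacks.ModelCheckpoint
        #       monitor: val_loss
        #     early_stopping:
        #       _target_: pytorch_lightning.callbacks.EarlyStopping
        #       monitor: val_loss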

        trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=experiment_logger, callbacks=callbacks)

        trainer.logger.log_hyperparams(Namespace(**cfg))  # Save config to logger

        if isinstance(trainer.logger, CometLogger):
            experiment_logger.experiment.log_asset_folder(".hydra", log_file_name=True)
            if cfg.get("comet_tags", None):
                experiment_logger.experiment.add_tags(list(cfg.comet_tags))

        # Instantiate datamodule
        datamodule: VitalDataModule = hydra.utils.instantiate(cfg.data, _recursive_=False)

        # Instantiate system (which will handle instantiating the model and optimizer)
        model: VitalSystem = hydra.utils.instantiate(
            cfg.task, choices=cfg.choices, data_params=datamodule.data_params, _recursive_=False
        )

        if cfg.ckpt:  # Load pretrained model if a checkpoint is provided
            if cfg.weights_only:
                logger.info(f"Loading weights from {ckpt_path}")
                model.load_state_dict(torch.load(ckpt_path, map_location=model.device)["state_dict"], strict=cfg.strict)
            else:
                logger.info(f"Loading model from {ckpt_path}")
                model = model.load_from_checkpoint(ckpt_path, data_params=datamodule.data_params, strict=cfg.strict)

        if cfg.train:
            if cfg.resume:
                # Resume training from the resolved checkpoint (rather than only loading its weights)
                trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
            else:
                trainer.fit(model, datamodule=datamodule)

            if not cfg.trainer.get("fast_dev_run", False):
                # Copy best model checkpoint to a predictable path + online tracker (if used)
                best_model_path = VitalRunner._best_model_path(model.log_dir, cfg)
                if trainer.checkpoint_callback is not None:
                    copy2(trainer.checkpoint_callback.best_model_path, str(best_model_path))
                    # Ensure we use the best weights (and not the latest ones) by loading back the best model
                    model = model.load_from_checkpoint(str(best_model_path))
                else:  # If the checkpoint callback is not used, save the current model
                    trainer.save_checkpoint(best_model_path)

                if isinstance(trainer.logger, CometLogger):
                    last_model_path = None
                    if trainer.checkpoint_callback is not None:
                        best_model_path = trainer.checkpoint_callback.best_model_path
                        last_model_path = trainer.checkpoint_callback.last_model_path
                    trainer.logger.experiment.log_model("best-model", best_model_path)
                    # Also log the `ModelCheckpoint`'s last checkpoint, if it is configured to save one
                    if last_model_path:
                        trainer.logger.experiment.log_model("last-model", last_model_path)

        if cfg.test:
            trainer.test(model, datamodule=datamodule)

        if cfg.predict:
            trainer.predict(model, datamodule=datamodule)

    @staticmethod
    def _check_cfg(cfg: DictConfig) -> DictConfig:
        """Parses the config, making custom checks on the values of the parameters in the process.

        Args:
            cfg: Full configuration for the experiment.

        Returns:
            Validated config for a system run.
        """
        # If no output dir is specified, default to the working directory
        if not cfg.trainer.get("default_root_dir", None):
            with open_dict(cfg):
                cfg.trainer.default_root_dir = os.getcwd()

        return cfg

    @staticmethod
    def configure_logger(cfg: DictConfig) -> Union[bool, Logger]:
        """Initializes the Lightning logger.

        Args:
            cfg: Full configuration for the experiment.

        Returns:
            Logger for the Lightning `Trainer`.
        """
        experiment_logger = True  # Default to True (TensorBoard)
        skip_logger = False
        # Configure custom logger only if the user specified a custom config
        if "logger" in cfg and isinstance(cfg.logger, DictConfig):
            if "_target_" not in cfg.logger:
                logger.warning("No _target_ in logger config. Cannot instantiate custom logger")
                skip_logger = True
            if cfg.trainer.get("fast_dev_run", False):
                logger.warning(
                    "Not instantiating custom logger because having `fast_dev_run=True` makes Lightning skip logging. "
                    "To test the logger, launch a full run."
                )
                skip_logger = True
            if not skip_logger and "_target_" in cfg.logger:
                if "comet" in cfg.logger._target_:
                    experiment_logger = hydra.utils.instantiate(cfg.logger)
                elif "tensorboard" in cfg.logger._target_:
                    # If no save_dir is passed, use the default logger and let the Trainer set save_dir
                    if cfg.logger.get("save_dir", None):
                        experiment_logger = hydra.utils.instantiate(cfg.logger)
        return experiment_logger
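
    # Illustrative `logger` nodes (values are placeholders; field names follow the Lightning logger constructors):
    #   logger:
    #     _target_: pytorch_lightning.loggers.CometLogger
    #     project_name: my-project
    # or:
    #   logger:
    #     _target_: pytorch_lightning.loggers.TensorBoardLogger
    #     save_dir: /path/to/logs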

    @staticmethod
    def _best_model_path(log_dir: Path, cfg: DictConfig) -> Path:
        """Defines the path where to copy the best model checkpoint after training.

        Args:
            log_dir: Lightning's directory for the current run.
            cfg: Full configuration for the experiment.

        Returns:
            Path where to copy the best model checkpoint after training.
        """
        if cfg.get("best_model_save_path", None):
            return Path(cfg.best_model_save_path)  # Return the save path from the config, if available
        else:
            model = cfg.choices["task/model"]
            name = f"{cfg.choices.data}_{cfg.choices.task}"
            if model is not None:  # Some systems do not have a model (e.g. autoencoders)
                name = f"{name}_{model}"
            return log_dir / f"{name}.ckpt"
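
    # For example, with Hydra choices `data=camus`, `task=segmentation` and `task/model=enet` (hypothetical choice
    # names), the best checkpoint would be copied to `<log_dir>/camus_segmentation_enet.ckpt`.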

def main():
    """Run the script."""
    VitalRunner.main()


if __name__ == "__main__":
    main()
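
# Example invocations (hypothetical config group values; the available groups live under `config/`):
#   python runner.py data=<dataset> task=<task> trainer.max_epochs=50 seed=42
#   python runner.py --multirun data=<dataset> task=<task> seed=null trainer.accelerator=gpu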