main.py
import time
import comet_ml
import hydra
import lightning as pl
import torch
from dataset import EarthViewNEONDatamodule
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CometLogger
from lightning_model import DepthAnythingV2Module
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="default", version_base=None)
def main(args: DictConfig):
    # Reproducibility and Tensor Core-friendly float32 matmul precision.
    pl.seed_everything(42)
    torch.set_float32_matmul_precision("medium")

    data_module = EarthViewNEONDatamodule(**args.dataset)
    model = DepthAnythingV2Module(**args.model)

    # Use a timestamp as the run id; replaced by the Comet experiment id when logging is enabled.
    experiment_id = time.strftime("%Y%m%d-%H%M%S")
    logger = False
    if args.logger:
        logger = CometLogger(
            project_name="depth-any-canopy",
            workspace="",
            experiment_name="",
            save_dir="comet-logs",
            offline=False,
        )
        experiment_id = logger.experiment.id

    # Keep the three best checkpoints by validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"checkpoints/{experiment_id}",
        filename="depth-any-canopy-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    early_stopping = EarlyStopping(
        monitor="val_loss", patience=10, mode="min", verbose=True
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")

    callbacks = [checkpoint_callback, early_stopping]
    # The learning-rate monitor only writes somewhere when a logger is attached.
    if logger:
        callbacks.append(lr_monitor)

    trainer = pl.Trainer(
        **args.trainer,
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=50,
        precision="32-true",
        limit_val_batches=50,
        val_check_interval=500,
    )
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    main()
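
For orientation, below is a minimal sketch of the config structure main() expects from configs/default.yaml, written with OmegaConf for illustration. The top-level keys mirror what the script actually reads (args.logger, args.dataset, args.model, args.model.encoder, args.trainer); the concrete values shown are assumptions, not the repository's real defaults.

# Illustrative sketch only; every value here is an assumption.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "logger": False,                 # True enables the CometLogger branch
        "dataset": {},                   # kwargs forwarded to EarthViewNEONDatamodule(**args.dataset)
        "model": {"encoder": "vitl"},    # kwargs forwarded to DepthAnythingV2Module(**args.model)
        "trainer": {"max_epochs": 10},   # kwargs forwarded to pl.Trainer(**args.trainer)
    }
)
print(OmegaConf.to_yaml(cfg))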