train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset_utils import WaveUIRTrainDataset
from net.WaveUIR import WaveUIR
from utils.schedulers import LinearWarmupCosineAnnealingLR
import numpy as np
import wandb
from options import options as opt
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint


class WaveUIRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = WaveUIR(dim=32, num_blocks=[2, 3, 3, 4], num_refinement_blocks=2, task_num=opt.task_num)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop; it is independent of forward().
        ([clean_name, class_id], degrad_patch, clean_patch) = batch
        # restored, prompt_weights = self.net(degrad_patch)
        restored = self.net(degrad_patch)
        loss = self.loss_fn(restored, clean_patch)
        # Logged to TensorBoard by default via the trainer's logger.
        self.log("train_loss", loss)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        # Step the scheduler once per epoch, indexed by the current epoch.
        scheduler.step(self.current_epoch)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=opt.lr)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_epochs=15, max_epochs=150)
        return [optimizer], [scheduler]
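
# Note: the scheduler above ramps the learning rate up linearly over the first
# 15 epochs, then decays it with a cosine curve out to epoch 150. The helper
# below is a minimal reference sketch of that curve only; the actual behavior
# is defined in utils/schedulers.py and may differ (e.g. a nonzero final LR).
def _reference_warmup_cosine_lr(epoch, base_lr, warmup_epochs=15, max_epochs=150):
    import math
    if epoch < warmup_epochs:
        # Linear warmup: LR climbs from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine annealing from base_lr down to 0 over the remaining epochs.
    t = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))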


def main():
    print("Options")
    print(opt)
    logger = TensorBoardLogger(save_dir="logs/")
    trainset = WaveUIRTrainDataset(opt)
    checkpoint_callback = ModelCheckpoint(dirpath=opt.ckpt_dir, every_n_train_steps=4444,
                                          save_last=True, save_top_k=-1)
    trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True, shuffle=True,
                             drop_last=True, num_workers=opt.num_workers)
    model = WaveUIRModel()
    trainer = pl.Trainer(max_epochs=opt.epochs, accelerator="gpu", devices=opt.num_gpus,
                         strategy="ddp_find_unused_parameters_true", logger=logger,
                         callbacks=[checkpoint_callback])
    trainer.fit(model=model, train_dataloaders=trainloader)


if __name__ == '__main__':
    main()
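
# Usage sketch: launch from the repo root. The flag names below are assumptions
# about what options.py exposes (opt.batch_size, opt.lr, opt.epochs,
# opt.num_gpus, opt.ckpt_dir, opt.num_workers are all read above); check
# options.py for the actual argument names and defaults.
#
#   python train.py --batch_size 8 --lr 2e-4 --epochs 150 --num_gpus 1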