trainer.py
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from model import UNetWithSkipv2


class UNetTrainer(pl.LightningModule):
    """PyTorch Lightning trainer class for the pixel-wise classification
    task of active fire detection."""
    def __init__(
        self,
        input_ch=10,
        use_act=None,
        enc_ch=(32, 64, 128, 256, 512, 1024),
        lr=1e-4,
        tb_log_pred_gt=False,
    ):
        """Initialize the UNetTrainer class.

        Args:
            input_ch (int, optional): number of input channels. Defaults to 10.
            use_act (nn.Module, optional): activation function. Defaults to None.
            enc_ch (tuple, optional): encoder channels. Defaults to (32, 64, 128, 256, 512, 1024).
            lr (float, optional): learning rate. Defaults to 1e-4.
            tb_log_pred_gt (bool, optional): whether to plot predictions and annotations in TensorBoard. Defaults to False.
        """  # noqa: E501
        super().__init__()
        self.mse_err = torchmetrics.MeanSquaredError()
        self.mae_err = torchmetrics.MeanAbsoluteError()
        self.iou_err = torchmetrics.JaccardIndex(task="binary")
        self.acc = torchmetrics.classification.Accuracy(task="binary")
        self.lr = lr
        self.tb_log_pred_gt = tb_log_pred_gt
        self.loss = nn.BCEWithLogitsLoss()
        self.model = UNetWithSkipv2(
            input_ch=input_ch,
            use_act=use_act,
            encoder_channels=enc_ch,
        )
        self.validation_preds = []
        self.validation_targets = []

    def forward(self, x):
        """Run the forward pass."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Define the training step."""
        x, y, _ = batch
        x = x.to(memory_format=torch.channels_last)
        x = x.float()
        y = y.float()
        preds = self(x)
        loss = self.loss(preds, y)
        # The model outputs logits (BCEWithLogitsLoss), so apply a sigmoid
        # before thresholding at 0.5 to obtain the binary mask.
        pred_mask = (torch.sigmoid(preds) > 0.5).float()
        mse_error = self.mse_err(pred_mask, y)
        iou_error = self.iou_err(pred_mask, y)
        accuracy = self.acc(pred_mask, y)
        metrics = {
            "train_loss": loss,
            "train_mse_err": mse_error,
            "train_iou_err": iou_error,
            "train_accuracy": accuracy,
        }
        self.log_dict(metrics, logger=True, prog_bar=True, sync_dist=True)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Define the validation step."""
        x, y, _ = batch
        x = x.to(memory_format=torch.channels_last)
        x = x.float()
        y = y.float()
        preds = self(x)
        loss = self.loss(preds, y)
        # Threshold sigmoid probabilities at 0.5, as in training_step.
        pred_mask = (torch.sigmoid(preds) > 0.5).float()
        mse_error = self.mse_err(pred_mask, y)
        iou_error = self.iou_err(pred_mask, y)
        accuracy = self.acc(pred_mask, y)
        self.validation_preds.append(pred_mask)
        self.validation_targets.append(y)
        metrics = {
            "val_loss": loss,
            "val_mse_err": mse_error,
            "val_iou_err": iou_error,
            "val_accuracy": accuracy,
        }
        self.log_dict(metrics, logger=True, prog_bar=True, sync_dist=True)

    def on_validation_epoch_end(self):
        """Define the prediction and annotation plotting step after validation."""
        if self.tb_log_pred_gt:
            preds = torch.cat(self.validation_preds, dim=0)
            targets = torch.cat(self.validation_targets, dim=0)
            grid_preds = torchvision.utils.make_grid(preds)
            grid_targets = torchvision.utils.make_grid(targets)
            self.logger.experiment.add_image(
                "predictions", grid_preds, self.current_epoch
            )
            self.logger.experiment.add_image(
                "targets", grid_targets, self.current_epoch
            )
        # Clear the buffers unconditionally so they do not grow across epochs.
        self.validation_preds.clear()
        self.validation_targets.clear()

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Define the test step."""
        x, y, _ = batch
        x = x.to(memory_format=torch.channels_last)
        x = x.float()
        y = y.float()
        preds = self(x)
        loss = self.loss(preds, y)
        # Threshold sigmoid probabilities at 0.5, as in training_step.
        pred_mask = (torch.sigmoid(preds) > 0.5).float()
        mse_error = self.mse_err(pred_mask, y)
        iou_error = self.iou_err(pred_mask, y)
        accuracy = self.acc(pred_mask, y)
        metrics = {
            "test_loss": loss,
            "test_mse_err": mse_error,
            "test_iou_err": iou_error,
            "test_accuracy": accuracy,
        }
        self.log_dict(metrics, logger=True, prog_bar=True, sync_dist=True)

    def configure_optimizers(self, use_lr_scheduler=True):
        """Configure the optimizer and learning-rate scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if use_lr_scheduler:
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer=optimizer,
                T_0=3,
                T_mult=1,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "interval": "epoch",
                },
            }
        else:
            return optimizer
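

# Minimal usage sketch, assuming a TensorBoardLogger (on_validation_epoch_end
# calls self.logger.experiment.add_image) and batches shaped as 3-tuples
# (image, mask, meta), as expected by the *_step methods. `FireDataModule`
# is a hypothetical datamodule name, not something defined in this module.
if __name__ == "__main__":
    from pytorch_lightning.loggers import TensorBoardLogger

    model = UNetTrainer(input_ch=10, lr=1e-4, tb_log_pred_gt=True)
    logger = TensorBoardLogger(save_dir="logs", name="unet_fire")
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="auto",
        logger=logger,
    )
    # datamodule = FireDataModule(...)  # hypothetical; supply real data here
    # trainer.fit(model, datamodule=datamodule)
    # trainer.test(model, datamodule=datamodule)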