-
Notifications
You must be signed in to change notification settings - Fork 2
/
validate.py
66 lines (54 loc) · 1.9 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import time
from pathlib import Path
from statistics import mean
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from config import (CRITERION, LARGE_BATCH_SIZE, NUMBER_PATCH_PER_IMAGE,
PADDING, PATCH_SIZE)
from config import TEST_BATCH_SIZE as BATCH_SIZE
from config import TRAIN_DATASET_DIR as DATASET_DIR
from config import TEST_MODEL as MODEL
from config import TEST_MODEL_WEIGTS as MODEL_WEIGHTS
from config import TRAIN_IMAGE_SIZE as IMAGE_SIZE
from datasets import RoadsDatasetValidation
from models.resnet import ResNet
from models.unet import UNet
def validate(model, dataloader, criterion, model_weights=None):
    """Evaluate ``model`` on ``dataloader`` and print the mean batch loss.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is switched to eval mode
            and, when CUDA is available, moved to the GPU.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``"image"`` and ``"groundtruth"``.
        criterion: Callable invoked as ``criterion(output, groundtruths)``;
            it must return a single-element (scalar) loss tensor.
        model_weights: Optional argument forwarded to ``torch.load`` to
            obtain a state dict that is loaded into ``model`` before
            evaluation. Skipped when ``None``.

    Returns:
        The unweighted mean of the per-batch losses, as a Python float.

    Raises:
        statistics.StatisticsError: If ``dataloader`` yields no batches.
    """
    if model_weights is not None:
        # Deserialize onto the CPU so a checkpoint saved from a GPU run can
        # still be loaded on a CPU-only machine; the model is moved to the
        # GPU below when one is available.
        state_dict = torch.load(model_weights, map_location="cpu")
        model.load_state_dict(state_dict)

    cuda = torch.cuda.is_available()
    if cuda:
        model = model.to(device="cuda")
        print("CUDA available")
    else:
        print("NO CUDA")

    model.eval()
    global_loss = []

    # No gradients are needed for validation; disabling autograd avoids
    # building (and retaining) a computation graph for every batch.
    with torch.no_grad():
        for sample_batched in dataloader:
            images = sample_batched["image"]
            groundtruths = sample_batched["groundtruth"]
            if cuda:
                images = images.to(device="cuda")
                groundtruths = groundtruths.to(device="cuda")

            output = model(images)
            loss = criterion(output, groundtruths)
            # Store a plain float: statistics.mean cannot average tensors,
            # and keeping tensors would pin their memory for the whole run.
            global_loss.append(loss.item())

    mean_loss = mean(global_loss)
    print("[Validation Loss: {:03.2f}]".format(mean_loss))
    return mean_loss
if __name__ == "__main__":
    # Script entry point: build the validation dataset and loader from the
    # values imported from ``config``, then run a single validation pass.
    # NOTE(review): DATASET_DIR and IMAGE_SIZE are the TRAIN_* config values
    # (see the import aliases), and ``large_patch_size`` is given PATCH_SIZE —
    # confirm both are intentional.
    roads_dataset = RoadsDatasetValidation(
        root_dir=DATASET_DIR,
        image_initial_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        large_patch_size=PATCH_SIZE,
        number_patch_per_image=NUMBER_PATCH_PER_IMAGE,
    )

    # Keep the sample order fixed so repeated runs see identical batches.
    loader = data.DataLoader(roads_dataset, batch_size=BATCH_SIZE, shuffle=False)

    validate(MODEL, loader, CRITERION, model_weights=MODEL_WEIGHTS)