-
Notifications
You must be signed in to change notification settings - Fork 15
/
infer.py
87 lines (63 loc) · 2.54 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import os
import os.path as osp
import random
import sys
import time
import cv2
import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR, StepLR
from torch.utils.data import DataLoader, Dataset
import change_detection_pytorch as cdp
from change_detection_pytorch.datasets import LEVIR_CD_Dataset, SVCD_Dataset
from change_detection_pytorch.datasets.PRCV_CD import PRCV_CD_Dataset
from change_detection_pytorch.utils.lr_scheduler import GradualWarmupScheduler
def seed_torch(seed):
    """Seed every random number generator this script may touch.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, the current CUDA device and all CUDA devices). It also
    stores the seed in the ``PYTHONHASHSEED`` environment variable and
    configures cuDNN for deterministic, non-autotuned kernels.

    Args:
        seed: integer seed handed to every generator.
    """
    # Only child processes started after this point see this value; the
    # running interpreter's hash seed was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Trade cuDNN's autotuning speed-up for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Seed all RNGs before any model or data object is built.
seed_torch(seed=1024)

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

# Test split without annotations (ann_dir=None, test_mode=True).
# To run on the labelled validation split instead, use the root
# '/cache/train_val/val_set' with ann_dir='/cache/train_val/val_set/label'
# and keep the remaining arguments unchanged.
test_dataset = PRCV_CD_Dataset(
    '/cache/test_set/test_set',
    sub_dir_1='image1',
    sub_dir_2='image2',
    img_suffix='.png',
    ann_dir=None,
    size=512,
    debug=False,
    test_mode=True,
)

# One sample per batch, in dataset order, loaded in the main process.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
model_path_1 = './final_models/final_unet.pth'
save_dir = './res'


def _to_mask(y_pred):
    """Convert the model output for a single-image batch into a uint8 mask.

    Args:
        y_pred: value returned by the model. If it is not a tensor (e.g. a
            list/tuple of outputs), its last element is used. Class scores
            are reduced with argmax along dim 1, and only the first sample
            of the batch is kept (the loader is built with batch_size=1).

    Returns:
        A 2-D ``np.uint8`` array that is 255 where the predicted class index
        is non-zero and 0 elsewhere.
    """
    if not isinstance(y_pred, torch.Tensor):
        y_pred = y_pred[-1]
    # Take sample 0 explicitly: squeeze() would also drop a spatial axis of
    # size 1. argmax already yields integers, so no rounding is needed.
    labels = torch.argmax(y_pred, dim=1)[0].cpu().numpy()
    # Non-zero classes map to 255, the value 8-bit saturation of
    # `label * 255` produces. Emitting uint8 directly means cv2.imwrite
    # needs no implicit type conversion.
    return np.where(labels > 0, 255, 0).astype(np.uint8)


# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(save_dir, exist_ok=True)

# map_location allows a checkpoint saved on GPU to load on a CPU-only host,
# and .to(DEVICE) puts the weights on the same device as the inputs below.
# NOTE(review): this unpickles a whole model object. On PyTorch >= 2.6
# torch.load defaults to weights_only=True and will reject such a file;
# pass weights_only=False there, and only for trusted checkpoints.
model = torch.load(model_path_1, map_location=DEVICE)
model = model.to(DEVICE)

start = time.time()
with torch.no_grad():
    model.eval()
    for (x1, x2, filename) in test_loader:
        x1 = x1.float().to(DEVICE)
        x2 = x2.float().to(DEVICE)
        # Call the module rather than .forward() so registered hooks run.
        y_pred = _to_mask(model(x1, x2))
        # splitext strips only the final extension, so stems that contain
        # dots are preserved.
        filename = osp.splitext(filename[0])[0] + '.png'
        out_path = osp.join(save_dir, filename)
        if not cv2.imwrite(out_path, y_pred):
            raise OSError(f'failed to write prediction to {out_path}')
end = time.time()
print('time: ', end - start)