-
Notifications
You must be signed in to change notification settings - Fork 70
/
predict.py
93 lines (76 loc) · 2.69 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import CustomObjectScope
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from utils import *
from train import tf_dataset
def read_image(x):
    """Load an image from disk and preprocess it for ``model.predict``.

    The image is read as 3-channel colour with OpenCV (BGR channel order).
    Intensities are shifted so the global median lands on 127 and clipped to
    [0, 255]. They are then scaled to [0, 1] as float32, and a leading batch
    axis of size 1 is added.

    Args:
        x: Path to the image file.

    Returns:
        float32 array of shape (1, H, W, 3).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``x``.
    """
    image = cv2.imread(x, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the problem surfaces later as a confusing error
        # in the arithmetic below.
        raise FileNotFoundError(f"Could not read image: {x}")
    # Median-centre the intensities (median -> 127) to reduce global
    # brightness differences between images.
    image = np.clip(image - np.median(image) + 127, 0, 255)
    image = image / 255.0
    image = image.astype(np.float32)
    # Add a batch axis so the array can be fed straight to model.predict.
    image = np.expand_dims(image, axis=0)
    return image
def read_mask(y):
    """Load a grayscale mask from disk as a float32 array in [0, 1].

    Args:
        y: Path to the mask image file.

    Returns:
        float32 array of shape (H, W, 1), i.e. the 8-bit grayscale values
        divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``y``.
    """
    mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None on failure rather than raising.
        raise FileNotFoundError(f"Could not read mask: {y}")
    mask = mask.astype(np.float32)
    mask = mask / 255.0
    # Add a trailing channel axis: (H, W) -> (H, W, 1).
    mask = np.expand_dims(mask, axis=-1)
    return mask
def mask_to_3d(mask):
    """Replicate a single-channel mask into three identical channels.

    Singleton axes are squeezed away first, so inputs such as (H, W),
    (H, W, 1) or (1, H, W, 1) all give an (H, W, 3) array (assuming H and
    W are both greater than 1). The result can be concatenated next to a
    colour image.
    """
    plane = np.squeeze(mask)
    stacked = np.array([plane] * 3)  # shape (3, H, W)
    return stacked.transpose(1, 2, 0)  # move channels last -> (H, W, 3)
def parse(y_pred):
    """Return ``y_pred`` as a new float32 array with a trailing channel axis.

    For example, a (H, W) prediction map becomes a (H, W, 1) float32 array.
    The result is always a fresh copy and never a view of the input.

    Args:
        y_pred: Array-like prediction values.

    Returns:
        float32 ndarray of shape ``y_pred.shape + (1,)``.
    """
    # An expand_dims(-1) followed by [..., -1] cancels out, so only the
    # float32 cast (which copies) and the final expand_dims matter.
    y_pred = np.array(y_pred, dtype=np.float32)
    return np.expand_dims(y_pred, axis=-1)
def evaluate_normal(model, x_data, y_data, save_dir="results"):
    """Predict masks for each test image and save side-by-side comparisons.

    For each (image path, mask path) pair, writes ``<save_dir>/<i>.png``.
    Left to right, separated by 10-pixel white bars, the PNG contains:
      * the preprocessed input image (as returned by ``read_image``),
      * the ground-truth mask,
      * the prediction's second-to-last channel, ``[..., -2]``,
      * the prediction's last channel, ``[..., -1]``.
    The prediction maps are written as-is (scaled by 255); no threshold is
    applied.

    Args:
        model: Object with a Keras-style ``predict`` method. Assumes
            ``model.predict(x)[0]`` has at least two channels on its last
            axis. TODO confirm against the model definition.
        x_data: Sequence of input image paths.
        y_data: Sequence of ground-truth mask paths, paired with ``x_data``
            by position.
        save_dir: Existing directory to write the PNGs into. Defaults to
            "results", the originally hard-coded location.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` differ in length.
        OSError: If OpenCV fails to write an output image.
    """
    # zip() would silently drop unpaired items, hiding a data mismatch.
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data has {len(x_data)} items but y_data has {len(y_data)}"
        )
    for i, (x_path, y_path) in tqdm(enumerate(zip(x_data, y_data)), total=len(x_data)):
        x = read_image(x_path)
        y = read_mask(y_path)
        _, h, _, _ = x.shape

        # Run inference once and slice both channels from the same output.
        # Calling model.predict twice doubled the cost per image.
        prediction = model.predict(x)[0]
        y_pred1 = parse(prediction[..., -2])
        y_pred2 = parse(prediction[..., -1])

        # 10-px white vertical separator between panels.
        line = np.ones((h, 10, 3)) * 255.0
        all_images = [
            x[0] * 255.0, line,
            mask_to_3d(y) * 255.0, line,
            mask_to_3d(y_pred1) * 255.0, line,
            mask_to_3d(y_pred2) * 255.0,
        ]
        panel = np.concatenate(all_images, axis=1)

        out_path = os.path.join(save_dir, f"{i}.png")
        # cv2.imwrite returns False instead of raising on failure (e.g.
        # missing directory); surface that rather than losing output.
        if not cv2.imwrite(out_path, panel):
            raise OSError(f"cv2.imwrite failed to write {out_path}")
# Additive smoothing term for the Dice coefficient. It keeps the ratio
# defined (equal to 1) when both inputs are all zeros.
smooth = 1.


def dice_coef(y_true, y_pred):
    """Compute the soft Dice coefficient between two tensors.

    Both inputs are flattened and summed over *all* elements, batch axis
    included, so one scalar is returned for the whole batch:
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Args:
        y_true: Ground-truth tensor. It is cast to ``y_pred``'s dtype, so
            integer masks are accepted.
        y_pred: Predicted tensor with the same number of elements as
            ``y_true``.

    Returns:
        Scalar tensor with the Dice coefficient.
    """
    # tf.reshape avoids building a new keras Flatten layer on every call.
    # Because everything is summed afterwards, flattening to 1-D gives the
    # same result as Flatten's (batch, -1) layout.
    y_pred_f = tf.reshape(y_pred, [-1])
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
def dice_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus ``dice_coef(y_true, y_pred)``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient
if __name__ == "__main__":
    # Fix seeds so any stochastic operations are reproducible between runs.
    np.random.seed(42)
    tf.random.set_seed(42)

    create_dir("results/")

    batch_size = 8
    test_path = "../1/new_data/test/"
    test_x = sorted(glob(os.path.join(test_path, "image", "*.jpg")))
    # NOTE: masks are JPEGs, so lossy compression means their pixel values
    # may not be exactly 0/255.
    test_y = sorted(glob(os.path.join(test_path, "mask", "*.jpg")))

    # Images and masks are paired purely by sorted filename order, so fail
    # early if nothing was found or the counts disagree.
    if not test_x:
        raise FileNotFoundError(
            f"No test images found in {os.path.join(test_path, 'image')}"
        )
    if len(test_x) != len(test_y):
        raise ValueError(
            f"Found {len(test_x)} test images but {len(test_y)} masks"
        )

    test_dataset = tf_dataset(test_x, test_y, batch=batch_size)
    # Ceiling division: one extra step covers a final partial batch.
    test_steps = (len(test_x) + batch_size - 1) // batch_size

    model = load_model_weight("files/model.h5")
    model.evaluate(test_dataset, steps=test_steps)
    evaluate_normal(model, test_x, test_y)