val_multitask_include_d.py (forked from GZHermit/tensorflow-GAN4Segmentation)

# -*- coding: utf-8 -*-
import os
import time

import numpy as np
import tensorflow as tf

from models.discriminator_multitask import choose_discriminator
from models.generator_multitask import choose_generator
from utils.data_handle import load_weight
from utils.image_process import inv_preprocess, decode_labels
from utils.image_reader import read_labeled_image_list


def convert_to_scaling(score_map, num_classes, label_batch, tau=0.9):
    """Turn a hard label map into a 'scaled' soft ground truth for the discriminator.

    The annotated class keeps at least probability tau; the remaining mass is
    spread over the other classes in proportion to the generator's scores.
    """
    # Highest softmax score per pixel: [N, H, W]
    score_map_max = tf.reduce_max(score_map, axis=3, keep_dims=False)
    # Probability assigned to the annotated class: at least tau
    y_il = tf.maximum(score_map_max, tf.fill(tf.shape(label_batch)[:-1], tau))
    _s_il = 1.0 - score_map_max
    _y_il = 1.0 - y_il
    # Rescale the non-maximal scores so each pixel's distribution still sums to 1
    a = tf.expand_dims(tf.div(_y_il, _s_il), axis=3)
    y_ic = tf.concat([a for _ in range(num_classes)], axis=3)
    y_ic = tf.multiply(score_map, y_ic)
    b = tf.expand_dims(y_il, axis=3)
    y_il_ = tf.concat([b for _ in range(num_classes)], axis=3)
    # The one-hot ground truth selects where to place y_il; elsewhere use y_ic
    lab_hot = tf.squeeze(tf.one_hot(label_batch, num_classes, dtype=tf.float32), axis=3)
    gt_batch = tf.where(tf.equal(lab_hot, 1.), y_il_, y_ic)
    gt_batch = tf.clip_by_value(gt_batch, 0., 1.)
    return gt_batch
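
# A worked example of the scaling above (toy numbers, not from the repo): for
# one pixel with num_classes=3, tau=0.9 and softmax scores [0.5, 0.3, 0.2],
# true class 0: y_il = max(0.5, 0.9) = 0.9, the other classes are rescaled by
# (1 - 0.9) / (1 - 0.5) = 0.2, and the soft target becomes [0.9, 0.06, 0.04].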


def convert_to_calculateloss(raw_output, label_batch, num_classes):
    """Flatten logits and labels, dropping pixels whose label id is out of range
    (e.g. the 255 'ignore' label used by PASCAL VOC)."""
    raw_output = tf.image.resize_bilinear(raw_output, tf.shape(label_batch)[1:3])
    raw_groundtruth = tf.reshape(tf.squeeze(label_batch, squeeze_dims=[3]), [-1, ])
    raw_prediction = tf.reshape(raw_output, [-1, num_classes])
    # Keep only the pixels with a valid class id
    indices = tf.squeeze(tf.where(tf.less_equal(raw_groundtruth, num_classes - 1)), 1)
    label = tf.cast(tf.gather(raw_groundtruth, indices), tf.int32)  # [?, ]
    logits = tf.gather(raw_prediction, indices)  # [?, num_classes]
    return label, logits
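
# Sketch of the filtering above (toy numbers): with num_classes=21 and the
# flattened labels [0, 15, 255, 7], tf.where keeps indices [0, 1, 3], so the
# 255 'ignore' pixels never reach the loss or the metrics.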


def get_validate_data(image_name, label_name, img_mean):
    img = tf.read_file(image_name)
    img = tf.image.decode_jpeg(img, channels=3)
    # Reorder RGB -> BGR to match the Caffe-style pretrained weights
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Scaled copy of the image for the generator's second input
    img_normal = (img + img_mean) / 255.
    img_normal = tf.expand_dims(img_normal, axis=0)
    img -= img_mean  # mean-subtracted copy for the VGG branch
    img = tf.expand_dims(img, axis=0)
    label = tf.read_file(label_name)
    label = tf.image.decode_png(label, channels=1)
    label = tf.expand_dims(label, axis=0)
    return img, img_normal, label
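
# Usage sketch (hypothetical paths, for illustration only):
#   img, img_normal, label = get_validate_data(
#       tf.constant('VOC2012/JPEGImages/2007_000033.jpg'),
#       tf.constant('VOC2012/SegmentationClass/2007_000033.png'), img_mean)
# In val() below, the file names are fed through placeholders instead, so a
# single graph serves every validation image.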


def val(args):
    ## set hyperparameters
    # BGR channel means of the pretrained VGG model (Caffe convention)
    img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
    tf.set_random_seed(args.random_seed)

    ## load data
    image_list, label_list, png_list = read_labeled_image_list(args.data_dir, is_val=True,
                                                               valid_image_store_path=args.valid_image_store_path)
    num_val = len(image_list)
    # File names are fed at run time, one validation image per step
    image_name = tf.placeholder(dtype=tf.string)
    label_name = tf.placeholder(dtype=tf.string)
    png_name = tf.placeholder(dtype=tf.string)
    image_batch, image_normal_batch, label_batch = get_validate_data(image_name, label_name, img_mean)
    print("data load completed!")
    ## load model
    g_net, g_net_x = choose_generator(args.g_name, image_batch, image_normal_batch)
    score_map = g_net.get_output()
    fk_batch = tf.nn.softmax(score_map, dim=-1)
    # Downscale the labels to the score-map resolution, then soften them
    gt_batch = tf.image.resize_nearest_neighbor(label_batch, tf.shape(score_map)[1:3])
    gt_batch = convert_to_scaling(fk_batch, args.num_classes, gt_batch)
    # Image features from the generator's VGG branch, also fed to the discriminator
    x_batch = g_net_x.get_appointed_layer('generator/image_conv5_3')
    d_fk_net, d_gt_net = choose_discriminator(args.d_name, fk_batch, gt_batch, x_batch)
    d_fk_pred = d_fk_net.get_output()  # discriminator score for the fake (generated) segmentation
    d_gt_pred = d_gt_net.get_output()  # discriminator score for the scaled ground truth
    predict_batch = g_net.topredict(score_map, tf.shape(label_batch)[1:3])
    predict_img = tf.write_file(png_name,
                                tf.image.encode_png(tf.cast(tf.squeeze(predict_batch, axis=0), dtype=tf.uint8)))
    labels, logits = convert_to_calculateloss(score_map, label_batch, args.num_classes)
    pre_labels = tf.argmax(logits, 1)
    print("model load completed!")
    iou, iou_op = tf.metrics.mean_iou(labels, pre_labels, args.num_classes, name='iou')
    acc, acc_op = tf.metrics.accuracy(labels, pre_labels)
    m_op = tf.group(iou_op, acc_op)  # update both streaming metrics in one run
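    # Note (general TF1 mechanics, not repo-specific): tf.metrics.mean_iou and
    # tf.metrics.accuracy are streaming metrics whose running counters live in
    # local variables, which is why tf.local_variables_initializer() is run
    # before the loop below; each sess.run of m_op accumulates one image.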
    image = tf.py_func(inv_preprocess, [image_batch, args.save_num_images, img_mean], tf.uint8)
    label = tf.py_func(decode_labels, [label_batch, ], tf.uint8)
    pred = tf.py_func(decode_labels, [predict_batch, ], tf.uint8)
    tf.summary.image(name='img_collection_val', tensor=tf.concat([image, label, pred], 2))
    tf.summary.scalar(name='iou_val', tensor=iou)
    tf.summary.scalar(name='acc_val', tensor=acc)
    tf.summary.scalar('fk_score', tf.reduce_mean(tf.sigmoid(d_fk_pred)))
    tf.summary.scalar('gt_score', tf.reduce_mean(tf.sigmoid(d_gt_pred)))
    sum_op = tf.summary.merge_all()
    sum_writer = tf.summary.FileWriter(args.log_dir, max_queue=20)
    sess = tf.Session()
    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    sess.run(global_init)
    sess.run(local_init)
    saver = tf.train.Saver(var_list=tf.global_variables())
    _ = load_weight(args.restore_from, saver, sess)
    if not os.path.exists(args.valid_image_store_path):
        os.makedirs(args.valid_image_store_path)

    print("validation beginning")
    for step in range(num_val):
        it = time.time()
        feed_dict = {image_name: image_list[step], label_name: label_list[step], png_name: png_list[step]}
        _, _, iou_val = sess.run([m_op, predict_img, iou], feed_dict)
        if step % 50 == 0 or step == num_val - 1:
            summ = sess.run(sum_op, feed_dict)
            sum_writer.add_summary(summ, step)
            print("step: {}, time: {}, iou: {}".format(step, time.time() - it, iou_val))
    print("end......")