-
Notifications
You must be signed in to change notification settings - Fork 8
/
app_gui.py
464 lines (402 loc) · 13.4 KB
/
app_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
'''
Created on :2021/02/18 20:27:53
@author :Caihao (Chris) Cui
@file :app_gui.py
@content :xxx xxx xxx
@version :0.1
@License : (C)Copyright 2020 MIT
'''
"""
A simple Gooey example. One required field, one optional.
"""
# from __future__ import print_function
from matplotlib import style
from predict import predict
from predict import metricComputation
from train import train
from gooey import Gooey, GooeyParser
import os
import json
import utils
import dataset
from model import FCNN
from datetime import datetime
from datetime import date
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import cv2 as cv
from PIL import Image
from matplotlib import pyplot as plt
import matplotlib
matplotlib.use("tkagg")
style.use("ggplot")
# %% GUI Design
# add tran function
def _load_stored_args(args_file):
    """Return the argument values saved by a previous run, or ``{}``.

    A missing, unreadable or malformed settings file -- or one whose
    top-level JSON value is not an object -- yields an empty dict, so the
    GUI still starts, just without remembered defaults.
    """
    try:
        with open(args_file, encoding="utf-8") as data_file:
            stored = json.load(data_file)
    except (OSError, ValueError):
        # OSError: the file is missing, is a directory, or is unreadable.
        # ValueError: json.JSONDecodeError (corrupted or truncated file).
        return {}
    return stored if isinstance(stored, dict) else {}


@Gooey(
    program_name="Deep Learning Aerial Image Labelling",
    default_size=(640, 680),
    advanced=True,
    # Gooey drives its progress bar from output lines containing "NN%".
    progress_regex=r"(\d+)%",
    tabbed_groups=True,
    navigation="Tabbed",
    menu=[
        {
            "name": "File",
            "items": [
                {
                    "type": "AboutDialog",
                    "menuTitle": "About",
                    "name": "DL Aerial Image Labelling",
                    "description": "ConvNets for Aerial Image Labelling: Test Case",
                    "version": "1.0.0",
                    "copyright": "2021",
                    "website": "https://cuicaihao.com",
                    "developer": "Chris.Cui",
                    "license": "MIT",
                },
                {
                    "type": "MessageDialog",
                    "menuTitle": "Information",
                    "caption": "My Message",
                    "message": "Hello Deep Learning, this is demo.",
                },
                {
                    "type": "Link",
                    "menuTitle": "Visit My GitHub",
                    "url": "https://github.com/cuicaihao",
                },
            ],
        },
        {
            "name": "Help",
            "items": [
                {
                    "type": "Link",
                    "menuTitle": "Documentation",
                    "url": "https://github.com/cuicaihao/aerial-image-segmentation",
                }
            ],
        },
    ],
)
def parse_args():
    """Build the Gooey form, parse the user's choices and remember them.

    Values saved by the previous run in ``<script name>-args.json`` (looked
    up relative to the current working directory) pre-fill the five
    "Data IO" fields. After parsing, every argument value is written back
    to that file so the next launch starts from the same selections.

    Returns:
        The namespace returned by ``GooeyParser.parse_args``.
    """
    # The settings file is named after this script, e.g. "app_gui-args.json".
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    args_file = "{}-args.json".format(script_name)
    stored_args = _load_stored_args(args_file)

    settings_msg = (
        "example demonstrating aerial image labelling "
        "for house, road, and buildings."
    )
    parser = GooeyParser(description=settings_msg)

    # --- Data IO tab: input images and output locations -------------------
    IO_files_group = parser.add_argument_group(
        "Data IO", gooey_options={"show_border": False, "columns": 1}
    )
    IO_files_group.add_argument(
        "input_RGB",
        type=str,
        metavar="Input RGB Image",
        action="store",
        default=stored_args.get("input_RGB"),
        help="string of RGB image file path",
        widget="FileChooser",
    )
    IO_files_group.add_argument(
        "input_GT",
        type=str,
        metavar="Input Ground Truth Image",
        action="store",
        widget="FileChooser",
        default=stored_args.get("input_GT"),
        help="string of Ground Truth (GT) image file path",
    )
    IO_files_group.add_argument(
        "output_model_path",
        type=str,
        metavar="Output/Reload Model File",
        default=stored_args.get("output_model_path"),
        help="saved file path",
        widget="FileChooser",
    )
    IO_files_group.add_argument(
        "output_loss_plot",
        metavar="Output Dev History Plot",
        type=str,
        default=stored_args.get("output_loss_plot"),
        help="save the training error curves",
        widget="FileChooser",
    )
    IO_files_group.add_argument(
        "output_images",
        metavar="Output Image Folder",
        type=str,
        default=stored_args.get("output_images"),
        help="string of output image file path",
        widget="DirChooser",
    )

    # --- Model Option tab: run metadata and training hyper-parameters ------
    config_group = parser.add_argument_group(
        "Model Option", gooey_options={"show_border": False, "columns": 2}
    )
    config_group.add_argument(
        "--running_time",
        metavar="Time (hh:mm:ss)",
        default=datetime.now().strftime("%H:%M:%S"),
        help="App Started Time",
        widget="TimeChooser",
    )
    config_group.add_argument(
        "--running_date",
        metavar="Date (yyyy-mm-dd)",
        default=date.today().strftime("%Y-%m-%d"),
        help="App Started Date",
        widget="DateChooser",
    )
    config_group.add_argument(
        "--use_gpu",
        default=False,
        metavar="Enable GPU",
        help="Use GPU in Training (default: False)",
        action="store_true",
        widget="CheckBox",
    )
    config_group.add_argument(
        "--use_pretrain",
        default=False,
        metavar="Use Pretrained Model",
        help="Use Pre-Trained Model (default: False)",
        action="store_true",
        widget="CheckBox",
    )
    config_group.add_argument(
        "--tile_size_height",
        metavar="Tile Size: Height",
        default=250,
        type=int,
        help="input tile size (height)",
        widget="Dropdown",
        choices=[100, 250, 500],
    )
    config_group.add_argument(
        "--tile_size_width",
        metavar="Tile Size: Width",
        default=250,
        type=int,
        help="input tile size (width)",
        widget="Dropdown",
        choices=[100, 250, 500],
    )
    config_group.add_argument(
        "--epochs",
        metavar="Epoch Number",
        default=1,
        type=int,
        help="epoch number: positive integer",
        widget="IntegerField",
    )
    config_group.add_argument(
        "--batch_size",
        metavar="Batch Size",
        default=4,
        type=int,
        help="batch size (?, 3, W, H)",
        choices=[1, 2, 4, 8, 16, 32, 64],
        widget="Dropdown",
    )
    config_group.add_argument(
        "--learning_rate",
        metavar="Learning Rate",
        default=1e-4,
        type=float,
        help="Adam learning rate (0.0, 1.0) ",
        choices=[0.1, 1e-2, 1e-3, 1e-4],
        widget="Dropdown",
    )
    config_group.add_argument(
        "--weight_decay",
        metavar="Weight Decay",
        default=5e-3,
        type=float,
        help="model weight decay/l2-norm regularization",
        widget="DecimalField",
    )

    args = parser.parse_args()
    # Persist every argument value (vars() turns the namespace into a dict)
    # so the next run can pre-fill the form.
    with open(args_file, "w", encoding="utf-8") as data_file:
        json.dump(vars(args), data_file)
    return args
def dev_model(args):  # modified from __main__ in train.py
    """Train an FCNN on the selected RGB / ground-truth pair and save it.

    When ``args.use_pretrain`` is set, training starts from the model stored
    at ``args.output_model_path``; otherwise a freshly constructed ``FCNN``
    is trained. Afterwards the model is saved to ``args.output_model_path``
    and the loss curve to ``args.output_loss_plot``.

    Args:
        args: namespace produced by ``parse_args``.
    """
    # Unpack the GUI options.
    INPUT_IMAGE_PATH = args.input_RGB
    LABEL_IMAGE_PATH = args.input_GT
    WEIGHTS_FILE_PATH = args.output_model_path
    LOSS_PLOT_PATH = args.output_loss_plot
    use_gpu = args.use_gpu
    use_pretrain = args.use_pretrain
    epochs = args.epochs
    batch_size = args.batch_size
    tile_size = (args.tile_size_height, args.tile_size_width)
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay

    device = utils.device(use_gpu=use_gpu)
    model = FCNN()
    if use_pretrain:
        # Continue from the previously saved model instead of a fresh one.
        model = utils.load_entire_model(model, WEIGHTS_FILE_PATH, use_gpu)
        print("use pretrained model!")

    train_loader = dataset.training_loader(
        image_path=INPUT_IMAGE_PATH,
        label_path=LABEL_IMAGE_PATH,
        batch_size=batch_size,
        tile_size=tile_size,
        shuffle=True,
    )
    model, stats = train(
        model=model,
        train_loader=train_loader,
        device=device,
        epochs=epochs,
        batch_size=batch_size,
        tile_size=tile_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    # Persist the trained model and the loss history plot.
    utils.save_entire_model(model, WEIGHTS_FILE_PATH)
    stats.save_loss_plot(LOSS_PLOT_PATH)
    print("[>>>] Passed!")
def _read_for_display(image_path):
    """Load *image_path* with OpenCV and return it in RGB order, or ``None``.

    ``cv.imread`` returns channels in BGR order and ``None`` when the file is
    missing or unreadable; matplotlib's ``imshow`` expects RGB, so the
    channels are swapped before returning.
    """
    if not image_path:
        return None
    image = cv.imread(image_path)
    if image is None:
        return None
    return cv.cvtColor(image, cv.COLOR_BGR2RGB)


def _show_image_grid(images, titles):
    """Show (at most four) images in a 2x2 matplotlib grid and block until closed.

    Entries of *images* that are ``None`` are drawn as an empty panel with a
    "not available" note instead of making ``imshow`` raise.
    """
    plt.figure(num=None, figsize=(7, 7), dpi=80, facecolor="w", edgecolor="k")
    for position, (image, title) in enumerate(zip(images, titles), start=1):
        axes = plt.subplot(2, 2, position)
        if image is None:
            axes.text(
                0.5, 0.5, "not available",
                ha="center", va="center", transform=axes.transAxes,
            )
        else:
            axes.imshow(image, cmap="gray", vmin=0, vmax=255)
        axes.set_title(title)
        axes.set_xticks([])
        axes.set_yticks([])
    plt.show()


def dev_predit(args):
    """Predict the house class for ``args.input_RGB`` with a saved model.

    Steps:
      1. load the FCNN stored at ``args.output_model_path``;
      2. run ``predict`` over ``dataset.full_image_loader`` (GUI tile size);
      3. save the two images returned by ``utils.overlay_class_prediction``
         as ``prediction.png`` and ``mask.png`` in ``args.output_images``
         (the folder is created if missing);
      4. print metrics comparing ``args.input_GT`` with the saved mask;
      5. display RGB / GT / mask / training-loss plot in a 2x2 figure.

    Args:
        args: namespace produced by ``parse_args``.

    Returns:
        (pred_image_path, pred_mask_path): paths of the two saved images.
    """
    use_gpu = args.use_gpu
    tile_size = (args.tile_size_height, args.tile_size_width)
    INPUT_IMAGE_PATH = args.input_RGB
    LABEL_IMAGE_PATH = args.input_GT
    WEIGHTS_FILE_PATH = args.output_model_path
    LOSS_PLOT_PATH = args.output_loss_plot
    OUTPUT_IMAGE_PATH = args.output_images

    # Load the model and the full-image data.
    device = utils.device(use_gpu=use_gpu)
    model = FCNN()
    model = utils.load_entire_model(model, WEIGHTS_FILE_PATH, use_gpu)
    # NOTE(review): the original author flagged this loader call with
    # "this is issue !!!" without further detail -- cause unknown.
    loader = dataset.full_image_loader(
        INPUT_IMAGE_PATH, LABEL_IMAGE_PATH, tile_size=tile_size
    )
    prediction = predict(
        model, loader, device=device, class_label=utils.ClassLabel.house
    )

    # Save the outputs.
    input_image = utils.input_image(INPUT_IMAGE_PATH)
    pred_image, mask_image = utils.overlay_class_prediction(
        input_image, prediction)
    if OUTPUT_IMAGE_PATH:
        os.makedirs(OUTPUT_IMAGE_PATH, exist_ok=True)
    pred_image_path = os.path.join(OUTPUT_IMAGE_PATH, "prediction.png")
    pred_image.save(pred_image_path)
    pred_mask_path = os.path.join(OUTPUT_IMAGE_PATH, "mask.png")
    mask_image.save(pred_mask_path)
    print("(i) Prediction image saved at {}".format(pred_image_path))
    print("(ii) Mask image saved at {}".format(pred_mask_path))

    # Metrics: compare the ground truth with the mask just written to disk.
    with Image.open(LABEL_IMAGE_PATH) as gt_file:
        gt_labels = np.array(gt_file, dtype=np.int32)
    with Image.open(pred_mask_path) as mask_file:
        mask_labels = np.array(mask_file, dtype=np.int32)
    metricComputation(gt_labels, mask_labels)

    # Show the inputs and results; the loss plot may not exist when this run
    # only predicts, in which case its panel is left empty.
    display_paths = (INPUT_IMAGE_PATH, LABEL_IMAGE_PATH, pred_mask_path,
                     LOSS_PLOT_PATH)
    images = [_read_for_display(path) for path in display_paths]
    _show_image_grid(images, ["RGB", "GT", "Prediction", "Training Loss"])
    return pred_image_path, pred_mask_path
def config_checking(conf):
    """Tell whether *conf* is usable before training/prediction starts.

    Only the epoch count is validated: a negative value is rejected, while
    0 is accepted (``main`` then skips training and only predicts).

    Returns:
        True when the configuration is acceptable, False otherwise.
    """
    negative_epochs = conf.epochs < 0
    return not negative_epochs
def main():
    """Entry point: collect the GUI options, then train and/or predict.

    Prints every chosen option plus start/end timestamps. If the options
    pass ``config_checking``, a positive epoch count trains (and saves) a
    model first, and prediction always runs afterwards; otherwise
    "Wrong Option" is printed.
    """
    separator = "=" * 40
    timestamp_format = "%Y/%m/%d %H:%M:%S"

    conf = parse_args()
    print(separator)
    start_time = datetime.now().strftime(timestamp_format)
    print(f"model start: {start_time}")
    print(separator)
    for option_name, option_value in vars(conf).items():
        print(f"{option_name}:{option_value}")

    if not config_checking(conf):
        print("Wrong Option")
    else:
        if conf.epochs > 0:
            # Train first so prediction uses the freshly saved model.
            dev_model(conf)
        dev_predit(conf)

    print(separator)
    end_time = datetime.now().strftime(timestamp_format)
    print(f"model start: {start_time} end: {end_time}.")
    print(separator)
if __name__ == "__main__":
    # main() opens the Gooey form (through parse_args) and then runs
    # training and/or prediction with the submitted options.
    main()
    print("\r\r\r")

# Usage on macOS: launch with `pythonw app_gui.py`.
# Running wxPython (the GUI toolkit Gooey uses) with plain `python` from
# inside a virtual environment on a Mac raises an error; see
# https://wiki.wxpython.org/wxPythonVirtualenvOnMac