# constants.py
import datetime
############################
## PREPARE ##
############################
crops_p_img = 10 # Number of samples/crops taken per HR image (to get the target output size)
augment_img = True # Augment data with flips (each image will generate an extra flipped image)
############################
## SAVE/LOAD ##
############################
load_model = True # Load a previously saved model from disk, or create a new one?
save_dir = 'save' # Folder name where the model will be saved, relative to root
model_name = 'sobel_model.h5' # Name of the model that is to be loaded and/or saved
# TODO: integrate these two pieces of feedback
model_json = 'model_architecture.json'
weights = 'model_weights.h5'
def get_model_save_path():
    return save_dir + '/' + model_name
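# e.g. with the defaults above, get_model_save_path() returns 'save/sobel_model.h5'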
############################
## MODEL ##
############################
scale_fact = 4 # resolution multiplication factor
res_blocks = 3 # number of residual blocks in the network (+1)
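# e.g. with scale_fact = 4, a 16x16 low-resolution input crop maps to a 64x64 output (see img_width/img_height below)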
############################
## TRAINING ##
############################
# Adjust "crops_p_img", "img_height", "img_width" and "batch_size" to maximize GPU memory usage.
# "augment_img" will double the number of pixels calculated below.
# Number of pixels per batch seen by the GPU: "crops_p_img" x "batch_size" x "img_width" x "img_height"
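# e.g. with the defaults below: 10 crops x 2 images x 64 x 64 = 81,920 pixels per batch (163,840 with augment_img)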
img_width = 64 # width of the network's output (tune together with batch_size to maximize GPU memory usage)
img_height = 64 # this size divided by scale_fact gives the network's input size
img_depth = 3 # number of channels (RGB)
epochs = 15 # number of passes over the training data
batch_size = 2 # number of images to be cropped and fed per batch
verbosity = 2 # message feedback (0, 1 or 2): higher means more verbose
val_split = 0.1 # fraction of the dataset used for validation
hr_img_path = 'dataset/DIV2K/DIV2K/DIV2K_train_HR/' # Where the training dataset is.
#hr_img_path = 'pictures/HR/' # Use this when you want to test the initialization of the filters.
sample_path = 'pictures/HR/' # Path used for the ModelDiagnoser.
############################
## EVALUATION ##
############################
add_callbacks = True # TensorBoard and visualization/diagnostic functionalities (slows down training)
log_dir = './logs' # directory where the Callbacks logs will be stored
tests_path = 'input/' # path to the folder containing the HR images to test with
def get_log_path():
    log_path = log_dir + '/' + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Current TensorBoard log directory is: " + log_path)
    return log_path
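# e.g. get_log_path() returns something like './logs/2019-01-01_12-00-00' (timestamp taken at call time)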
############################
## PREDICTION ##
############################
input_width = 64 # width of the input used for prediction
input_height = 64 # height of the input used for prediction
overlap = 16 # number of overlapping pixels between prediction tiles, used to remove erroneous edges
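# e.g. with input_width = 64 and overlap = 16, successive tiles would presumably be taken 48 px apart so the overlapping borders can be blended or trimmed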