-
Notifications
You must be signed in to change notification settings - Fork 3
/
TRAINING_CONFIG.py
31 lines (23 loc) · 1011 Bytes
/
TRAINING_CONFIG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
# ---------------------------------------------------------------------------
# TRAINING CONFIG
# Module-level settings; other scripts are expected to read these names as-is.
# ---------------------------------------------------------------------------
batch_size = 1
lr = 2e-4
num_epochs = 50
# NOTE(review): presumably the learning-rate scheduler step interval --
# confirm against the training script.
step_size = 400

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Training data: input images and their labels ---------------------------
raw_image_path = "./data/input/"
clear_image_path = "./data/label/"
# Keep the training image size at 256x256: per the original author, the
# customized UNet architecture is tailored to this size, and it trains faster
# while still producing good results.
train_img_size = 256

# --- Training checkpoints (snapshots) ---------------------------------------
snapshots_folder = "./snapshots/unetSSIM"
# NOTE(review): presumably "save a snapshot every N epochs" -- confirm.
snapshot_freq = 5
model_name = "unetSSIM"

# --- Test data: input images, output location, and test image size ----------
test_image_path = "./data/test_imgs/"
output_images_path = "./data/test_output/unetssim/"
test_img_size = 512

# Checkpoint file to load when resuming training.
# Original author's warning: DO NOT enter MODEL.CKPT files here (those belong
# in test_model_path below).
# NOTE(review): this path names the 'unetDROPn' run while model_name and
# snapshots_folder use 'unetSSIM' -- confirm that is intended.
ckpt_path = "./snapshots/unetDROPn/model_epoch_49_unetDROPn.ckpt"

# Model file to load for TESTING (this is where MODEL.CKPT files go).
test_model_path = "./model_ckpt/deep_seann_FINALMODEL.ckpt"