forked from chunmeifeng/T2Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ixi_config.yaml
47 lines (39 loc) · 4.04 KB
/
ixi_config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Data parameters

# Dataset type. Currently only IXI is supported.
dataset: "IXI"

# NOTE(review): undocumented in the original; presumably the root directory
# of the HDF5 data - confirm against the data loader.
data_dir: "/home/jc3/Data/IXI_T2/h5/"

# Training files directory; should contain HDF5 preprocessed data.
# NOTE(review): the path points at a "mat" folder - confirm the expected
# file format.
train_data_dir: "/home/jc3/Data/IXI_T2/mat/train"

# Validation files directory; should contain HDF5 preprocessed data.
# NOTE(review): the path points at a "mat" folder - confirm the expected
# file format.
val_data_dir: "/home/jc3/Data/IXI_T2/mat/val"

# Directory where checkpoints and TensorBoard data are saved.
output_dir: "./logs"

# Sampling mask percentage. Sampling masks of 256x256 for 20%, 30% and 50%
# are provided with the code.
sampling_percentage: 30

# Number of slices used as input (3 means the predicted slice plus the
# previous slice and the next slice).
num_input_slices: 1

# Input image size (256x256 for IXI).
img_size: 256
# img_size2: 640

# Slices to use for training data.
slice_range: [20, 120]
# Load checkpoint

# 0 starts a new training run; otherwise set this to the checkpoint path
# from which the network weights are loaded.
load_cp: 0

# 0 - load only the model weights.
# 1 - load the weights plus the epoch number, optimizer and scheduler state.
resume_training: 1
# Network parameters

# 1 - use bilinear upsampling, 0 - use up-conv.
bilinear: 1

# Discriminator center crop size (128x128), used to avoid classifying
# blank patches.
crop_center: 128
# Training parameters

# Learning rate.
lr: 0.001

# Number of epochs.
epochs_n: 50

# Batch size. Reduce if out of memory: a batch of 32 images of 256x256
# needs ~13GB of memory.
batch_size: 64

# 1 - use GAN training.
# 0 - no GAN (no discriminator training and no adversarial loss).
GAN_training: 1

# Loss weighting, in the order listed below. The losses are weighted to be
# roughly at the same scale.
loss_weights:
  - 1000  # image-space L2
  - 1000  # image-space L1
  - 5  # k-space L2
  - 0.1  # GAN loss

# NOTE(review): undocumented in the original; presumably the [min, max]
# bounds of a noise value - confirm against the training code.
minmax_noise_val: [-0.01, 0.01]
# TensorBoard

# Write losses and scalars to TensorBoard.
tb_write_losses: 1

# Write images to TensorBoard.
tb_write_images: 1
# Runtime

# 'cuda' for GPU training, 'cpu' for CPU training (not recommended!).
device: 'cuda'

# GPU ID to use.
gpu_id: '0'

# Number of training dataset workers. Reduce if you are getting a shared
# memory error.
train_num_workers: 16

# Number of validation dataset workers. Reduce if you are getting a shared
# memory error.
val_num_workers: 4
# Predict parameters

# Save predicted images.
save_prediction: 1

# Path where predictions are saved. Written as an absolute path, like the
# other machine paths in this file; without the leading "/" it would be
# resolved relative to the current working directory.
save_path: "/home/jc3/multiSR/JS_fastMRI/save_predictions/T2T1_Unet"

# Visualize predicted images.
visualize_images: 1

# Model checkpoint to use for prediction.
model: "model_path.pth"

# Test set files directory; should contain HDF5 preprocessed data.
predict_data_dir: "/home/jc3/Data/IXI_T2/h5/test"
# jc3/multiSR/JS_fastMRI/SR_fastMRI-master/experimental/i