-
Notifications
You must be signed in to change notification settings - Fork 0
/
script.py
85 lines (68 loc) · 3.91 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from model import Encoder, Decoder, Decoder3
from training import train
from torch import nn
from dataset.data_generator import Data_Generator, train_augmentations, val_augmentations
from torch.utils.tensorboard import SummaryWriter
import click
import os
@click.group()
def cli():
    # Root click command group; subcommands (e.g. `training`) attach via @cli.command().
    # NOTE: intentionally no docstring — click would surface it as --help text.
    pass
@cli.command()
@click.option('--experiment_name', type=str, default='', show_default=True)
@click.option('--model_name', required=True, type=str, show_default=True)
@click.option('--resume', is_flag=True, show_default=True)
@click.option('--batch_size', required=True, type=int, default=16, show_default=True)
@click.option('--gpu_id', type=str, default='0', show_default=True)
@click.option('--learning_rate', type=float, default=1e-3, show_default=True)
@click.option('--epochs', type=int, default=20, show_default=True)
@click.option('--checkpoint_path', type=str, default='checkpoints/', show_default=True)
@click.option('--checkpoint_name', type=str)
@click.option('--dataset_name', required=True, show_default=True)
def training(experiment_name, model_name, resume, batch_size, gpu_id, learning_rate,
             epochs, checkpoint_path, checkpoint_name, dataset_name):
    """Train a forgery-localization model on one of the supported datasets.

    model_name: 'dct' | 'sb' | 'dct_sb' — selects which auxiliary input streams
        the Data_Generator produces and which decoder architecture is built.
    dataset_name: 'synthetic' | 'casia2' | 'ifs_tc' — selects hard-coded data roots.
    resume: load model weights from --checkpoint_name before training.
    """
    # Log the full run configuration for reproducibility.
    print(locals())

    writer = SummaryWriter('runs/' + experiment_name)
    checkpoint_path = os.path.join(checkpoint_path, experiment_name)
    os.makedirs(checkpoint_path, exist_ok=True)

    if gpu_id is not None:
        device = torch.device('cuda:' + str(gpu_id) if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device('cpu')

    # Exclusive dispatch on model_name. The original non-exclusive `if` chain left
    # flag_dct/flag_sb/model undefined for unknown names, producing a NameError
    # far from the cause; fail fast with a clear parameter error instead.
    if model_name == 'dct':
        flag_dct, flag_sb = True, False
        model = Decoder()
    elif model_name == 'sb':
        flag_dct, flag_sb = False, True
        model = Decoder()
    elif model_name == 'dct_sb':
        flag_dct, flag_sb = True, True
        model = Decoder3()
    else:
        raise click.BadParameter(
            "unknown model_name %r; expected 'dct', 'sb' or 'dct_sb'" % (model_name,))

    if resume:
        # Guard: torch.load(None) would otherwise crash with an obscure error.
        if checkpoint_name is None:
            raise click.BadParameter('--checkpoint_name is required when --resume is set')
        checkpoint = torch.load(checkpoint_name, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    # Dataset roots/CSVs are hard-coded to this machine's layout.
    # NOTE(review): casia2 deliberately omits inverse=True, matching the original.
    if dataset_name == 'synthetic':
        root = '/home/siopi/drive2/Synthetic manipulation dataset/'
        csv = '/home/siopi/drive2/Synthetic manipulation dataset/synthetic_sb.csv'
        train_data = Data_Generator(root, csv, transform=train_augmentations,
                                    dct=flag_dct, sb=flag_sb, inverse=True)
        val_data = Data_Generator(root, csv, transform=val_augmentations, split=['val'],
                                  dct=flag_dct, sb=flag_sb, inverse=True)
    elif dataset_name == 'casia2':
        root = '/home/siopi/drive2/CASIA2/'
        csv = '/home/siopi/drive2/CASIA 2/CASIA 2.0/casia2.csv'
        train_data = Data_Generator(root, csv, transform=train_augmentations,
                                    dct=flag_dct, sb=flag_sb)
        val_data = Data_Generator(root, csv, transform=val_augmentations, split=['val'],
                                  dct=flag_dct, sb=flag_sb)
    elif dataset_name == 'ifs_tc':
        root = '/home/siopi/drive2/IFS-TC/'
        csv = '/home/siopi/drive2/IFS-TC/ifstc.csv'
        train_data = Data_Generator(root, csv, transform=train_augmentations,
                                    dct=flag_dct, sb=flag_sb, inverse=True)
        val_data = Data_Generator(root, csv, transform=val_augmentations, split=['val'],
                                  dct=flag_dct, sb=flag_sb, inverse=True)
    else:
        raise click.BadParameter(
            "unknown dataset_name %r; expected 'synthetic', 'casia2' or 'ifs_tc'"
            % (dataset_name,))

    train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=batch_size,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(dataset=val_data, shuffle=False, batch_size=batch_size,
                            num_workers=8, pin_memory=True)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): model is never moved to `device` here — presumably train()
    # handles the .to(device) transfer; confirm against training.train.
    train(model, epochs, device, train_loader, val_loader, optimizer, criterion,
          checkpoint_path=checkpoint_path, writer=writer, dct=flag_dct, sb=flag_sb)
if __name__ == '__main__':
    # Dispatch to the click command group when run as a script.
    cli()