-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_model.py
73 lines (57 loc) · 2.6 KB
/
evaluate_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from model import Encoder, Decoder, Decoder3
from dataset.data_generator import Data_Generator, val_augmentations
from torch import nn
from utils import Metrics
from training import evaluate
import click
# Root click group: sub-commands (e.g. `evaluating`) register on it via
# `@cli.command()`. The callback itself does nothing.
@click.group()
def cli():
    return None
# --model_name -> (feed DCT input, feed SB input, decoder class to build).
_MODEL_CONFIGS = {
    'dct': (True, False, Decoder),
    'sb': (False, True, Decoder),
    'dct_sb': (True, True, Decoder3),
}

# --dataset_name -> Data_Generator arguments. Paths are machine-specific and
# kept exactly as in the original script.
_DATASET_CONFIGS = {
    'casia1': {
        'root': '/home/siopi/drive2/CASIA2/CASIA 1.0 dataset/',
        'csv': '/home/siopi/drive2/CASIA2/CASIA 1.0 dataset/casia1.csv',
        'split': ['train', 'val'],
        'extra': {},
    },
    'ifs_tc': {
        'root': '/home/siopi/drive2/IFS-TC/',
        'csv': '/home/siopi/drive2/IFS-TC/ifstc.csv',
        'split': ['test'],
        'extra': {'inverse': True},
    },
    'columbia': {
        'root': '/home/siopi/drive2/Columbia/',
        'csv': '/home/siopi/drive2/Columbia/columbia_dataset.csv',
        'split': ['test'],
        'extra': {},
    },
}


def _extract_state_dict(loaded):
    """Return the parameter mapping stored in a loaded checkpoint object."""
    # NOTE(review): the training script's checkpoint layout is not visible
    # from this file. Accept a bare state dict as well as the common wrapper
    # layouts; confirm the key name against the code that saves checkpoints.
    if isinstance(loaded, dict):
        for key in ('model_state_dict', 'state_dict', 'model'):
            nested = loaded.get(key)
            if isinstance(nested, dict):
                return nested
    return loaded


@cli.command()
@click.option('--experiment_name', type=str, default='', show_default=True)
@click.option('--model_name', required=True, type=str, show_default=True)
@click.option('--batch_size', required=True, type=int, default=16, show_default=True)
@click.option('--gpu_id', type=str, default='0', show_default=True)
@click.option('--checkpoint', type=str, show_default=True)
@click.option('--dataset_name', required=True, show_default=True)
def evaluating(experiment_name, model_name, batch_size, gpu_id, checkpoint, dataset_name):
    """Evaluate a trained model checkpoint on one of the known datasets.
    \f
    Args:
        experiment_name: Free-form label; only echoed with the other options.
        model_name: One of 'dct', 'sb' or 'dct_sb'.
        batch_size: Batch size used by the evaluation DataLoader.
        gpu_id: CUDA device index; CPU is used when CUDA is unavailable or
            when this is None.
        checkpoint: Path to the saved weights to evaluate.
        dataset_name: One of 'casia1', 'ifs_tc' or 'columbia'.

    Raises:
        click.BadParameter: If model_name or dataset_name is not recognised.
        click.UsageError: If no checkpoint path was supplied.
    """
    # Echo the received options first (must run before other locals exist).
    kwargs = locals()
    print(kwargs)

    # Validate up front so that mistakes fail with a clear CLI error instead
    # of an UnboundLocalError (or a torch.load(None) crash) further down.
    if model_name not in _MODEL_CONFIGS:
        raise click.BadParameter(
            'unknown model %r; expected one of %s' % (model_name, sorted(_MODEL_CONFIGS)),
            param_hint='--model_name',
        )
    if dataset_name not in _DATASET_CONFIGS:
        raise click.BadParameter(
            'unknown dataset %r; expected one of %s' % (dataset_name, sorted(_DATASET_CONFIGS)),
            param_hint='--dataset_name',
        )
    if not checkpoint:
        raise click.UsageError('--checkpoint is required to evaluate a trained model')

    if gpu_id is not None and torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_id))
    else:
        device = torch.device('cpu')

    flag_dct, flag_sb, model_cls = _MODEL_CONFIGS[model_name]
    model = model_cls()

    dataset_cfg = _DATASET_CONFIGS[dataset_name]
    eval_data = Data_Generator(
        dataset_cfg['root'],
        dataset_cfg['csv'],
        transform=val_augmentations,
        split=list(dataset_cfg['split']),
        dct=flag_dct,
        sb=flag_sb,
        **dataset_cfg['extra'],
    )
    eval_loader = DataLoader(
        dataset=eval_data,
        shuffle=False,
        batch_size=batch_size,
        num_workers=16,
        pin_memory=True,
    )

    # The original loaded the checkpoint but never applied it, so the metrics
    # were computed on randomly initialised weights. Load onto CPU first and
    # move the model to the target device afterwards.
    loaded = torch.load(checkpoint, map_location='cpu')
    if isinstance(loaded, nn.Module):
        # Checkpoint was saved as a whole pickled module.
        model = loaded
    else:
        model.load_state_dict(_extract_state_dict(loaded))
    model = model.to(device)

    criterion = nn.BCELoss().to(device)
    val_metrics = Metrics()
    val_iterator = iter(eval_loader)
    val_metrics = evaluate(model, criterion, val_iterator, device, val_metrics, flag_dct, flag_sb)
    print(val_metrics)
# Script entry point: dispatch to the click group only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    cli()