forked from huawei-noah/multi_hyp_cc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hold_out.py
110 lines (90 loc) · 4.5 KB
/
hold_out.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 0-Clause License.
#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 0-Clause License for more details.
import argparse
import os
import random
import shutil
import time
import warnings
import sys
import json
import torch
import torch.backends.cudnn as cudnn
from core.worker import Worker
from core.utils import summary_angular_errors
from core.print_logger import PrintLogger
from core.cache_manager import CacheManager
# Command-line interface. Positional arguments identify the experiment
# (configuration file + dataset + training file list); optional arguments
# control hardware selection, resumption and output locations.
# Fixes: typos in user-facing help text ("contraining" -> "containing",
# "ouput" -> "output", duplicated word in --pretrainedmodel help).
parser = argparse.ArgumentParser(description='Color Constancy: Cross validation')
parser.add_argument('configurationfile',
                    help='path to configuration file')
parser.add_argument('dataset', help='dataset class name')
parser.add_argument('subdataset', help='subdataset name')
parser.add_argument('trainfiles', help='text file containing the files to train')
parser.add_argument('--valfile', help='text file containing the files to validate', type=str)
parser.add_argument('--testfile', help='text file containing the files to test', type=str)
parser.add_argument('-gpu', type=int, help='GPU id to use.')
parser.add_argument('-j', '--workers', default=0, type=int,
                    help='number of data loading workers (default: 0)')
parser.add_argument('--resume', action='store_true',
                    help='resume from previous execution')
parser.add_argument('--pretrainedmodel', help='path to pretrained model file')
parser.add_argument('-e', '--evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--save_fullres', action='store_true',
                    help='save full resolution prediction images')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--outputfolder', default='./output/', type=str,
                    help='path for the output folder. ')
parser.add_argument('--datapath', default='data/paths.json', type=str,
                    help='path to json file that specifies the directories of the datasets. ')
def generate_results(res, prefix=None):
    """Summarize and print angular-error statistics for a list of results.

    Args:
        res: iterable of result objects, each exposing an ``error`` attribute
            (an angular error value).
        prefix: optional label (e.g. 'test') printed before the statistics.

    Prints one line: ``[prefix] stat1: v1 stat2: v2 ...`` with values
    formatted to 4 decimal places.
    """
    errors = [r.error for r in res]
    results = summary_angular_errors(errors)
    if prefix is not None:
        print(prefix, end=' ')
    # iterate items() once instead of keys() + a second dict lookup
    for name, value in results.items():
        print(f"{name}: {value:.4f}", end=' ')
    print()
def main():
    """Script entry point: set up the experiment, run it, report errors."""
    args = parser.parse_args()

    # Optional deterministic mode: seeding also forces the cuDNN
    # deterministic setting, trading speed for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Experiment settings (epochs, loss, ...) and machine-specific data paths.
    with open(args.configurationfile, 'r') as conf_file:
        conf = json.load(conf_file)
    with open(args.datapath, 'r') as paths_file:
        data_conf = json.load(paths_file)

    # Wipe any previous run of this experiment unless resuming or evaluating,
    # then (re)create the output folder.
    output_dir = os.path.join(args.outputfolder, args.dataset,
                              args.subdataset, conf['name'])
    if not args.evaluate and not args.resume and os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    args.outputfolder = output_dir

    # Keep a copy of the configuration file next to the results.
    shutil.copy(args.configurationfile,
                os.path.join(output_dir, os.path.basename(args.configurationfile)))

    # Mirror stdout/stderr to log files inside the output directory.
    sys.stdout = PrintLogger(os.path.join(output_dir, 'stdout.txt'), sys.stdout)
    sys.stderr = PrintLogger(os.path.join(output_dir, 'stderr.txt'), sys.stderr)

    # Hold-out experiments have no cross-validation folds; fold #0 is
    # always used.
    cache = CacheManager(conf)
    worker = Worker(0, conf, data_conf, cache, args)
    res, _ = worker.run()

    # Some datasets ship no validation ground truth: nothing to report then.
    if res:
        # print angular-error statistics (mean, median, etc...)
        generate_results(res, 'test')
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()