-
Notifications
You must be signed in to change notification settings - Fork 65
/
config.py
126 lines (105 loc) · 3.48 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Experiment Configuration"""
import os
import re
import glob
import itertools
import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
# Global Sacred settings: keep the config writable and switch off
# Sacred's output capturing.
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

# The experiment object that the config scope and hook below attach to.
ex = Experiment('PANet')
ex.captured_out_filter = apply_backspaces_and_linefeeds

# Register every ``*.py`` file found directly inside these folders as a source
# file of the experiment (folders are scanned in the order listed).
source_folders = ['.', './dataloaders', './models', './util']
sources_to_save = [
    py_path
    for folder in source_folders
    for py_path in glob.glob(f'{folder}/*.py')
]
for source_file in sources_to_save:
    ex.add_source_file(source_file)
@ex.config
def cfg():
    """Default configuration values for the 'PANet' experiment.

    ``mode`` selects which group of settings is defined. In 'train' mode the
    dataset, model, task and optimizer settings are written out literally. In
    'test' mode the dataset, model flags, label set and task shape are parsed
    from the ``snapshot`` path string instead. After either branch,
    ``exp_str`` (a name describing the run) and ``path`` (directories and
    data splits) are built.

    Raises:
        ValueError: If ``mode`` is neither 'train' nor 'test', or if the
            test-mode ``snapshot`` string contains neither 'VOC' nor 'COCO'.
    """
    # NOTE(review): the local names in this function are presumably collected
    # by Sacred as config entries, so renaming or adding locals would change
    # the config -- confirm against the Sacred docs before refactoring.
    input_size = (417, 417)
    seed = 1234
    # NOTE(review): spelled 'visable'; presumably read under this exact key
    # elsewhere, so check the consumers before renaming.
    cuda_visable = '0, 1, 2, 3, 4, 5, 6, 7'
    gpu_id = 0
    mode = 'test' # 'train' or 'test'

    if mode == 'train':
        dataset = 'VOC' # 'VOC' or 'COCO'
        n_steps = 30000
        label_sets = 0
        batch_size = 1
        lr_milestones = [10000, 20000, 30000]
        align_loss_scaler = 1
        ignore_label = 255
        print_interval = 100
        save_pred_every = 10000
        # Model flags; the names of the truthy ones become part of exp_str.
        model = {
            'align': True,
        }
        task = {
            'n_ways': 1,
            'n_shots': 1,
            'n_queries': 1,
        }
        optim = {
            'lr': 1e-3,
            'momentum': 0.9,
            'weight_decay': 0.0005,
        }
    elif mode == 'test':
        notrain = False
        # The text of this path is parsed below to recover the dataset, model
        # flags, label set and task shape, so it must contain those markers.
        snapshot = './runs/PANet_VOC_sets_0_1way_1shot_[train]/1/snapshots/30000.pth'
        n_runs = 5
        n_steps = 1000
        batch_size = 1
        scribble_dilation = 0
        bbox = False
        scribble = False
        # Set dataset config from the snapshot string
        # ('VOC' is checked first, so it wins if both substrings are present).
        if 'VOC' in snapshot:
            dataset = 'VOC'
        elif 'COCO' in snapshot:
            dataset = 'COCO'
        else:
            raise ValueError('Wrong snapshot name !')
        # Set model config from the snapshot string: each flag is True exactly
        # when its name occurs as a substring of the snapshot path.
        # NOTE(review): the loop variable `key` stays a local of this scope
        # after the loop -- it may end up as a config entry; confirm.
        model = {}
        for key in ['align',]:
            model[key] = key in snapshot
        # Set label_sets from the snapshot string: only the single character
        # right after the first '_sets_' is used, i.e. one digit.
        label_sets = int(snapshot.split('_sets_')[1][0])
        # Set task config from the snapshot string: match '<digits>way' and
        # '<digits>shot', then slice off the 3- and 4-character suffixes.
        # A snapshot string without these markers makes re.search return None,
        # so the .group call fails with AttributeError.
        task = {
            'n_ways': int(re.search("[0-9]+way", snapshot).group(0)[:-3]),
            'n_shots': int(re.search("[0-9]+shot", snapshot).group(0)[:-4]),
            'n_queries': 1,
        }
    else:
        raise ValueError('Wrong configuration for "mode" !')

    # Run name, joined with '_': dataset, names of the truthy model flags,
    # 'sets_<label_sets>', then '<n_ways>way_<n_shots>shot_[<mode>]'.
    exp_str = '_'.join(
        [dataset,]
        + [key for key, value in model.items() if value]
        + [f'sets_{label_sets}', f'{task["n_ways"]}way_{task["n_shots"]}shot_[{mode}]'])

    # Filesystem locations; the 'VOC' and 'COCO' sub-dicts are keyed by the
    # same strings that `dataset` can take.
    path = {
        'log_dir': './runs',
        'init_path': './pretrained_model/vgg16-397923af.pth',
        'VOC':{'data_dir': '../../data/Pascal/VOCdevkit/VOC2012/',
               'data_split': 'trainaug',},
        'COCO':{'data_dir': '../../data/COCO/',
                'data_split': 'train',},
    }
@ex.config_hook
def add_observer(config, command_name, logger):
    """Attach a file-storage observer whose directory is named after the run.

    The directory name starts as ``<ex.path>_<exp_str>``. In 'test' mode one
    suffix is appended for each of the ``notrain``, ``scribble`` and ``bbox``
    options that is truthy, in that order. The observer is created for that
    directory under ``config['path']['log_dir']`` and appended to
    ``ex.observers``.

    Args:
        config: Configuration mapping; must provide 'exp_str', 'mode' and
            'path', plus 'notrain', 'scribble' and 'bbox' in 'test' mode.
        command_name: Not used by this hook.
        logger: Not used by this hook.

    Returns:
        The ``config`` object that was passed in.
    """
    name_parts = [f'{ex.path}_{config["exp_str"]}']

    if config['mode'] == 'test':
        # (config key, directory-name suffix), looked up in this order.
        test_suffixes = (
            ('notrain', '_notrain'),
            ('scribble', '_scribble'),
            ('bbox', '_bbox'),
        )
        for flag, suffix in test_suffixes:
            if config[flag]:
                name_parts.append(suffix)

    run_dir = os.path.join(config['path']['log_dir'], ''.join(name_parts))
    ex.observers.append(FileStorageObserver.create(run_dir))
    return config