-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_test_only.py
109 lines (86 loc) · 3.63 KB
/
main_test_only.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import os
import torch
import PIL
import torchvision
import ast
from omegaconf import OmegaConf
from utils import setup_logger, get_current_time
from baseline_trainer import trainer_init
from methods.wdiff import WDiff
def str2bool(v):
    """Interpret a command-line token as a boolean.

    Args:
        v: Either a ``bool`` (returned as-is) or a string such as
            "yes"/"true"/"t"/"y"/"1" or "no"/"false"/"f"/"n"/"0"
            (matched case-insensitively).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    # Ordered (result, accepted spellings) pairs; the first match wins.
    spellings_by_result = (
        (True, ("yes", "true", "t", "y", "1")),
        (False, ("no", "false", "f", "n", "0")),
    )
    for result, spellings in spellings_by_result:
        if token in spellings:
            return result
    raise argparse.ArgumentTypeError("Boolean value expected.")
def str2none(v):
    """Normalise an optional integer setting that may be spelled "none".

    Args:
        v: ``None``, an ``int``, or the string "none" in any letter case.

    Returns:
        ``v`` unchanged when it is ``None`` or an ``int`` (including ``bool``,
        which is an ``int`` subclass); ``None`` when it is the string "none".

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if v is None or isinstance(v, int):
        return v
    # Exact comparison: a membership test against the bare string "none"
    # (``in ("none")`` is a str, not a tuple) would also accept substrings
    # such as "", "no" or "on".
    if isinstance(v, str) and v.lower() == "none":
        return None
    raise argparse.ArgumentTypeError(
        "None or integer value expected, got {!r}.".format(v))
def parse_opts(opt):
    """Evaluate a command-line string as a Python literal.

    Only literal syntax (numbers, strings, tuples, lists, dicts, sets,
    booleans, ``None``) is accepted; arbitrary expressions are rejected.

    Args:
        opt: Source text of the literal, e.g. ``"[1, 2]"`` or ``"0.5"``.

    Returns:
        The Python object the literal denotes.

    Raises:
        ValueError or SyntaxError: if ``opt`` is not a valid literal.
    """
    value = ast.literal_eval(opt)
    return value
def _parse_override_value(key, value):
    """Convert one command-line override value into a Python object.

    Args:
        key: Last component of the dotted config key (e.g. ``"lr"`` for
            ``trainer.lr``).
        value: Raw string taken from the command line.

    Returns:
        ``value`` unchanged for the keys ``data_dir`` and ``backbone``;
        ``None`` when ``value`` is "none" in any letter case; otherwise
        ``ast.literal_eval(value)``.

    Raises:
        ValueError or SyntaxError: if ``value`` is not a valid Python literal.
    """
    # These keys are kept as raw strings (presumably paths / model names that
    # are passed unquoted), so they must not go through literal_eval.
    if key in ['data_dir', 'backbone']:
        return value
    # Exact comparison: a membership test against the bare string "none"
    # would also accept substrings such as "", "no" or "on".
    if value.lower() == "none":
        return None
    return ast.literal_eval(value)


def _opts_to_config(opts):
    """Build an OmegaConf config from ``KEY VALUE [KEY VALUE ...]`` tokens.

    Dotted keys create nested nodes, so ``["trainer.lr", "0.01"]`` becomes
    ``{"trainer": {"lr": 0.01}}``. When a key is repeated, the last value wins.

    Args:
        opts: Flat list of command-line tokens alternating key and value.

    Returns:
        An OmegaConf config holding only the overrides.

    Raises:
        ValueError: if ``opts`` contains an odd number of tokens.
    """
    if len(opts) % 2 != 0:
        raise ValueError(
            "Config overrides must be KEY VALUE pairs, got an odd number of "
            "tokens: {}".format(opts)
        )
    overrides = OmegaConf.create()
    for key, value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split('.')
        node = overrides
        # Walk / create intermediate nodes for every dotted prefix.
        for part in parents:
            if part not in node:
                node[part] = OmegaConf.create()
            node = node[part]
        node[leaf] = _parse_override_value(leaf, value)
    return overrides


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Implementation of W-Diff')
    parser.add_argument('--cfg', default='./configs/eval_fix/cfg_yearbook.yaml', metavar='FILE',
                        help='path to config file', type=str)
    # NOTE(review): torch.load below needs a checkpoint file; this default
    # looks like a directory — confirm, or always pass --model_path.
    parser.add_argument('--model_path', default='./checkpoints/moons/', metavar='FILE',
                        help='path to model checkpoints', type=str)
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER,
                        )
    args = parser.parse_args()

    # YAML config, overridden by any KEY VALUE pairs given after the flags.
    configs = OmegaConf.load(args.cfg)
    configs_from_opts = _opts_to_config(args.opts or [])
    cfg = OmegaConf.merge(configs, configs_from_opts)
    cfg.trainer.dim_bottleneck_f = str2none(cfg.trainer.dim_bottleneck_f)

    os.makedirs(cfg.log.log_dir, exist_ok=True)
    logger = setup_logger("main_test_only", cfg.log.log_dir, 0,
                          filename=get_current_time() + "_" + cfg.log.log_name)
    logger.info("PIL.version = {}".format(PIL.__version__))
    logger.info("torch.version = {}".format(torch.__version__))
    logger.info("torchvision.version = {}".format(torchvision.__version__))
    logger.info("Running with config:\n{}".format(cfg))

    dataset, criterion, network, diffusion_model, optimizer, scheduler = trainer_init(cfg)
    if cfg.trainer.method == "wdiff":
        trainer = WDiff(cfg, logger, dataset, network, diffusion_model, criterion, optimizer, scheduler)
    else:
        raise ValueError("Unsupported trainer.method: {!r}".format(cfg.trainer.method))

    print("loading model checkpoints from {}".format(args.model_path))
    # NOTE(review): no map_location is passed, so tensors are restored to the
    # device recorded in the checkpoint; on recent torch versions torch.load
    # also defaults to weights_only=True — confirm this checkpoint loads.
    model_infos = torch.load(args.model_path)
    # Move modules to the CPU, load the saved state, then move them to the GPU.
    trainer.network = trainer.network.cpu()
    trainer.diffusion_model = trainer.diffusion_model.cpu()
    trainer.diffusion_model.load_state_dict(model_infos['diffusion_model'])
    trainer.network.enc.load_state_dict(model_infos['enc'])
    for item in model_infos['reference_point_queue']:
        trainer.network.reference_point_queue.put_item(item)
    trainer.network.cuda()
    trainer.diffusion_model.cuda()
    torch.cuda.empty_cache()
    trainer.evaluate_offline()