forked from dragen1860/MAML-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
omniglot_train.py
95 lines (75 loc) · 3.63 KB
/
omniglot_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch, os
import numpy as np
from omniglotNShot import OmniglotNShot
import argparse
from meta import Meta
def _to_device(device, *arrays):
    """Convert each NumPy array to a torch tensor placed on ``device``.

    Returns a tuple with one tensor per input array, in the same order.
    """
    return tuple(torch.from_numpy(arr).to(device) for arr in arrays)


def main(args):
    """Meta-train MAML on Omniglot and periodically report test accuracy.

    Builds the learner described by ``config`` (four conv/relu/bn stages, a
    flatten and a linear ``n_way`` head), wraps it in ``Meta`` and runs
    ``args.epoch`` meta-training steps. Every 50 steps the value returned by
    ``maml(...)`` is printed as the training accuracy. Every 500 steps
    (including step 0) the model is fine-tuned on individual test tasks and
    the task-averaged accuracies are printed.

    Args:
        args: parsed command-line namespace. This function reads ``n_way``,
            ``k_spt``, ``k_qry``, ``imgsz``, ``task_num`` and ``epoch``. The
            whole namespace is also forwarded to ``Meta``, which presumably
            reads the learning-rate / update-step options (confirm in meta.py).
    """
    # Fixed seeds for torch (CPU and all CUDA devices) and NumPy.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Learner architecture consumed by Meta: four conv -> relu -> bn stages,
    # then flatten and a linear classifier over n_way classes.
    # NOTE(review): the parameter lists presumably follow the learner's layout
    # conv2d [ch_out, ch_in, kh, kw, stride, padding] and
    # linear [out_features, in_features]. Confirm in the learner module.
    # Under that reading a 28x28 input shrinks 28 -> 13 -> 6 -> 2 -> 1, so
    # flatten yields exactly the 64 features the linear layer expects. Other
    # --imgsz values would need a different linear input size.
    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64]),
    ]

    # Fall back to CPU when no GPU is visible, instead of crashing in .to('cuda').
    # NOTE(review): Meta / the learner may still assume CUDA internally.
    # That code is not visible here.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maml = Meta(args, config).to(device)

    # Count of trainable scalar parameters (tensor elements, not tensors).
    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # Target number of test tasks per evaluation round.
    test_tasks = 1000

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = _to_device(device, *db_train.next())

        # Presumably one meta-update over the sampled task batch (see
        # Meta.forward). The returned value is logged as training accuracy.
        train_accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', train_accs)

        if step % 500 == 0:
            test_accs = []
            # Always run at least one batch. Otherwise, when task_num exceeds
            # test_tasks, the mean below would be taken over an empty list,
            # printing nan and emitting a RuntimeWarning.
            for _ in range(max(1, test_tasks // args.task_num)):
                x_spt, y_spt, x_qry, y_qry = _to_device(device, *db_train.next('test'))
                # Fine-tune and evaluate one task at a time.
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_accs.append(
                        maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one))

            # Stack per-task results (originally annotated [b, update_step+1])
            # and average over tasks. The float16 cast only rounds the printed values.
            test_accs = np.array(test_accs).mean(axis=0).astype(np.float16)
            print('Test acc:', test_accs)
if __name__ == '__main__':
    # Script entry point: register every CLI option from the table below, in
    # order, then hand the parsed namespace to main().
    # Each row is (flag, value type, help text, default).
    _cli_options = (
        ('--epoch', int, 'epoch number', 40000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 28),
        ('--imgc', int, 'imgc', 1),
        ('--task_num', int, 'meta batch size, namely task num', 32),
        ('--meta_lr', float, 'meta-level outer learning rate', 1e-3),
        ('--update_lr', float, 'task-level inner update learning rate', 0.4),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, help_text, fallback in _cli_options:
        parser.add_argument(flag, type=value_type, help=help_text, default=fallback)

    options = parser.parse_args()
    main(options)