forked from dragen1860/MAML-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
miniimagenet_train.py
111 lines (86 loc) · 4.19 KB
/
miniimagenet_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch, os
import numpy as np
from MiniImagenet import MiniImagenet
import scipy.stats
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import random, sys, pickle
import argparse
from meta import Meta
def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean and confidence-interval half-width of *accs*.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t critical value with ``n - 1`` degrees of freedom,
    so the interval is ``[m - h, m + h]``.

    Args:
        accs: 1-D array or sequence of samples; converted with
            ``np.asarray`` so plain lists are accepted as well.
        confidence: Two-sided confidence level, e.g. ``0.95``.

    Returns:
        Tuple ``(m, h)`` of the mean and the half-width. With fewer than
        two samples the standard error is undefined, so ``h`` is not a
        meaningful number in that case.
    """
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Use the public ppf(): the private _ppf() is not part of SciPy's API
    # and bypasses argument validation.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h
def main():
    """Meta-train the MAML model on mini-ImageNet, evaluating periodically.

    Takes no parameters: hyper-parameters are read from the module-level
    ``args`` namespace (``n_way``, ``k_spt``, ``k_qry``, ``imgsz``,
    ``task_num``, ``epoch``). Training accuracy is printed every 30 steps
    and the mean fine-tuning accuracy over the test episodes every 500
    steps. Returns nothing.
    """
    # Fixed seeds for torch (CPU and all CUDA devices) and NumPy.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Layer specification handed to Meta: four conv/relu/bn/max_pool stages
    # followed by flatten and a linear layer with args.n_way outputs.
    # NOTE(review): the meaning of each parameter list is defined by the
    # project's learner, and the 32 * 5 * 5 linear input presumably matches
    # the default 84x84 image size - confirm before changing --imgsz.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    # Fall back to CPU instead of crashing when no CUDA device is present.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    maml = Meta(args, config).to(device)

    # Number of scalar elements across all parameters that require grad.
    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number; the same figure converts
    # args.epoch into the number of passes over the training set below.
    data_root = '/home/i/tmp/MAML-Pytorch/miniimagenet/'
    episodes_per_epoch = 10000
    mini = MiniImagenet(data_root, mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=episodes_per_epoch, resize=args.imgsz)
    mini_test = MiniImagenet(data_root, mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)

    for _ in range(args.epoch // episodes_per_epoch):
        # fetch meta_batchsz num of episode each time
        # Pinned host memory only helps host-to-GPU copies, so enable it
        # only when actually running on CUDA.
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=use_cuda)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=use_cuda)
                accs_all_test = []

                # Distinct names keep the test episode from shadowing the
                # current training batch.
                for t_x_spt, t_y_spt, t_x_qry, t_y_qry in db_test:
                    # The loader uses batch size 1; drop that leading dim
                    # before fine-tuning on the single episode.
                    t_x_spt = t_x_spt.squeeze(0).to(device)
                    t_y_spt = t_y_spt.squeeze(0).to(device)
                    t_x_qry = t_x_qry.squeeze(0).to(device)
                    t_y_qry = t_y_qry.squeeze(0).to(device)

                    accs_all_test.append(maml.finetunning(t_x_spt, t_y_spt, t_x_qry, t_y_qry))

                # [b, update_step+1]
                test_accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', test_accs)
if __name__ == '__main__':
    # Command-line options as (flag, type, help, default), registered in order.
    option_table = (
        ('--epoch', int, 'epoch number', 60000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 84),
        ('--imgc', int, 'imgc', 3),
        ('--task_num', int, 'meta batch size, namely task num', 4),
        ('--meta_lr', float, 'meta-level outer learning rate', 1e-3),
        ('--update_lr', float, 'task-level inner update learning rate', 0.01),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
    )

    argparser = argparse.ArgumentParser()
    for flag, value_type, help_text, default_value in option_table:
        argparser.add_argument(flag, type=value_type, help=help_text, default=default_value)

    # main() takes no parameters and reads this module-level namespace.
    args = argparser.parse_args()

    main()