forked from deepinsight/insightface
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rec_builder.py
135 lines (128 loc) · 4.83 KB
/
rec_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import pickle
import numpy as np
import os
import os.path as osp
import glob
import argparse
import cv2
import time
import datetime
import pickle
import sklearn
import mxnet as mx
from utils.utils_config import get_config
from dataset import FaceDataset, Rt26dof
class RecBuilder():
def __init__(self, path, image_size=(112, 112), is_train=True):
self.path = path
self.image_size = image_size
self.widx = 0
self.wlabel = 0
self.max_label = -1
#assert not osp.exists(path), '%s exists' % path
if is_train:
rec_file = osp.join(path, 'train.rec')
idx_file = osp.join(path, 'train.idx')
else:
rec_file = osp.join(path, 'val.rec')
idx_file = osp.join(path, 'val.idx')
#assert not osp.exists(rec_file), '%s exists' % rec_file
if not osp.exists(path):
os.makedirs(path)
self.writer = mx.recordio.MXIndexedRecordIO(idx_file,
rec_file,
'w')
self.meta = []
def add(self, imgs):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
assert len(imgs) > 0
label = self.wlabel
for img in imgs:
idx = self.widx
image_meta = {'image_index': idx, 'image_classes': [label]}
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = label
self.wlabel += 1
return label
def add_image(self, img, label):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
idx = self.widx
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=100,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.widx += 1
def close(self):
print('stat:', self.widx, self.wlabel)
if __name__ == "__main__":
#cfg = get_config('configs/s1.py')
cfg = get_config('configs/s2.py')
cfg.task = 0
cfg.input_size = 512
for is_train in [True, False]:
dataset = FaceDataset(cfg, is_train=is_train, local_rank=0)
dataset.transform = None
writer = RecBuilder(cfg.cache_dir, is_train=is_train)
#writer = RecBuilder("temp", is_train=is_train)
print('total:', len(dataset))
#meta = np.zeros( (len(dataset), 3), dtype=np.float32 )
meta = []
subset_name = 'train' if is_train else 'val'
meta_path = osp.join(cfg.cache_dir, '%s.meta'%subset_name)
eye_missing = 0
for idx in range(len(dataset)):
#img_local, img_global, label_verts, label_Rt, tform = dataset[idx]
#img_local, label_verts, label_Rt, tform = dataset[idx]
data = dataset[idx]
img_local = data['img_local']
label_verts = data['verts']
label_Rt = data['rt']
tform = data['tform']
label_verts = label_verts.numpy()
label_Rt = label_Rt.numpy()
tform = tform.numpy()
label_6dof = Rt26dof(label_Rt, True)
pose = label_6dof[:3]
#print(image.shape, label_verts.shape, label_6dof.shape)
#print(image.__class__, label_verts.__class__)
img_local = img_local[:,:,::-1]
#img_global = img_global[:,:,::-1]
#image = np.concatenate( (img_local, img_global), axis=1 )
image = img_local
label = list(label_verts.flatten()) + list(label_Rt.flatten()) + list(tform.flatten())
expect_len = 1220*3+16+6
if 'eye_world_left' in data:
if idx==0:
print('find eye')
eyel = data['eye_world_left'].numpy()
eyer = data['eye_world_right'].numpy()
label += list(eyel.flatten())
label += list(eyer.flatten())
expect_len += 481*6
else:
eye_missing += 1
continue
meta.append(pose)
assert len(label)==expect_len
writer.add_image(image, label)
if idx%100==0:
print('processing:', idx, image.shape, len(label))
if idx<10:
cv2.imwrite("temp/%d.jpg"%idx, image)
writer.close()
meta = np.array(meta, dtype=np.float32)
np.save(meta_path, meta)
print('Eye missing:', eye_missing, is_train)