"""
Only for mean teacher
"""
from __future__ import division

import cv2
import numpy as np
from os.path import join

import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize

import config
import constants
from utils.imutils import crop, flip_img, flip_pose, flip_kp, transform, rot_aa


class H36M(Dataset):
    def __init__(self, options, dataset):
        super(H36M, self).__init__()
        self.dataset = dataset
        self.options = options
        self.img_dir = config.H36M_ROOT
        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
        self.data = np.load(config.DATASET_FILES[0][dataset])
        # load attributes
        self.imgname = self.data['imgname']
        self.scale = self.data['scale']
        self.center = self.data['center']
        # np.float was removed in NumPy 1.24; astype(np.float) was an alias for float64
        self.pose = self.data['pose'].astype(np.float64)
        self.betas = self.data['shape'].astype(np.float64)
        self.pose_3d = self.data['S']
        # prepend 25 all-zero OpenPose keypoints in front of the ground-truth 2D joints
        keypoints_gt = self.data['part']
        keypoints_openpose = np.zeros((keypoints_gt.shape[0], 25, 3))
        print(keypoints_gt.shape, keypoints_openpose.shape)
        self.keypoints = np.concatenate([keypoints_openpose, keypoints_gt], axis=1)
        self.length = self.scale.shape[0]

    def augm_params(self, istrain):
        """Sample augmentation parameters."""
        flip = 0           # flipping
        pn = np.ones(3)    # per-channel pixel noise
        rot = 0            # rotation
        sc = 1             # scaling
        if istrain:
            # flip with probability 1/2
            if np.random.uniform() <= 0.5:
                flip = 1
            # scale factor drawn from N(1, scale_factor^2),
            # truncated to [1 - scale_factor, 1 + scale_factor]
            sc = min(1 + self.options.scale_factor,
                     max(1 - self.options.scale_factor,
                         np.random.randn() * self.options.scale_factor + 1))
        return flip, pn, rot, sc

    def read_image(self, imgname):
        # OpenCV loads BGR; reverse the channel axis to get RGB
        img = cv2.imread(imgname)[:, :, ::-1].copy().astype(np.float32)
        return img

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn, is_train):
        """Crop, optionally flip, and convert the image to CHW in [0, 1]."""
        rgb_img = crop(rgb_img.copy(), center, scale, [constants.IMG_RES, constants.IMG_RES], rot=rot)
        if is_train:
            if flip:
                rgb_img = flip_img(rgb_img)
        # HWC -> CHW, [0, 255] -> [0, 1]
        rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
        return rgb_img

    def j2d_processing(self, kp, center, scale, r, f, is_train):
        """Process gt 2D keypoints and apply all augmentation transforms."""
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [constants.IMG_RES, constants.IMG_RES], rot=r)
        # convert to normalized coordinates in [-1, 1]
        kp[:, :-1] = 2. * kp[:, :-1] / constants.IMG_RES - 1.
        # flip the x coordinates
        if is_train and f:
            kp = flip_kp(kp)
        kp = kp.astype('float32')
        return kp

    def j3d_processing(self, S, r, f, is_train):
        """Process gt 3D keypoints and apply all augmentation transforms."""
        # in-plane rotation (identity when r == 0)
        rot_mat = np.eye(3)
        if not r == 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        # flip the x coordinates
        if is_train and f:
            S = flip_kp(S)
        S = S.astype('float32')
        return S

    def pose_processing(self, pose, r, f, is_train):
        """Process SMPL theta parameters and apply all augmentation transforms."""
        if is_train:
            # rotation of the global orientation (first 3 pose parameters)
            pose[:3] = rot_aa(pose[:3], r)
            # flip the pose parameters
            if f:
                pose = flip_pose(pose)
        # (72,), float
        pose = pose.astype('float32')
        return pose

    def process_sample(self, image, pose, beta, keypoints, S, center, scale, flip, pn, rot, sc, is_train):
        # labeled 2D keypoints
        kp2d = torch.from_numpy(self.j2d_processing(keypoints, center, sc * scale, rot, flip, is_train=is_train)).float()
        # cropped, normalized image
        img = self.rgb_processing(image, center, sc * scale, rot, flip, pn, is_train=is_train)
        img = torch.from_numpy(img).float()
        img = self.normalize_img(img)
        # SMPL parameters and 3D keypoints
        pose = torch.from_numpy(self.pose_processing(pose, rot, flip, is_train=is_train)).float()
        betas = torch.from_numpy(beta).float()
        S = torch.from_numpy(self.j3d_processing(S, rot, flip, is_train=is_train)).float()
        return kp2d, img, pose, betas, S

    def __getitem__(self, index):
        item = {}
        kp2d_all, img_all, pose_all, betas_all, pose_3d_all = [], [], [], [], []
        scale = self.scale[index].copy()
        center = self.center[index].copy()
        keypoints = self.keypoints[index].copy()
        pose = self.pose[index].copy()
        betas = self.betas[index].copy()
        imgname = join(self.img_dir, self.imgname[index])
        img = self.read_image(imgname).copy()
        S = self.pose_3d[index].copy()
        # original image, no augmentation
        flip, pn, rot, sc = 0, np.ones(3), 0, 1
        kp2d_i, img_i, pose_i, betas_i, S_i = self.process_sample(
            img.copy(), pose, betas, keypoints, S, center, scale, flip, pn, rot, sc, is_train=False)
        item['oriimg'] = img_i
        item['oripose'] = pose_i
        item['oribeta'] = betas_i
        item['orikeypoints'] = kp2d_i
        item['oripose_3d'] = S_i
        kp2d_all.append(kp2d_i)
        img_all.append(img_i)
        pose_all.append(pose_i)
        betas_all.append(betas_i)
        pose_3d_all.append(S_i)
        # stack the (single) view so downstream code sees a leading view dimension
        item['keypoints'] = torch.stack(kp2d_all)
        item['img'] = torch.stack(img_all)
        item['pose'] = torch.stack(pose_all)
        item['betas'] = torch.stack(betas_all)
        item['pose_3d'] = torch.stack(pose_3d_all)
        return item

    def __len__(self):
        return len(self.imgname)
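

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes that
# config.H36M_ROOT and config.DATASET_FILES[0] point at a preprocessed H36M
# .npz archive, and it fakes the `options` object with a SimpleNamespace
# carrying the one attribute this loader reads (scale_factor). The dataset
# key 'h36m' and the scale_factor value are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    options = SimpleNamespace(scale_factor=0.25)  # hypothetical value
    dataset = H36M(options, 'h36m')               # dataset key is an assumption

    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    # img stacks as (batch, views, 3, IMG_RES, IMG_RES); here views == 1
    print(batch['img'].shape, batch['pose'].shape, batch['keypoints'].shape)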