forked from yanx27/EverybodyDanceNow_reproduce_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_target.py
107 lines (90 loc) · 3.26 KB
/
make_target.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# make_target.py — preprocess a target-person video for pix2pix training:
# extract frames, run OpenPose on each one, and save image / pose-label /
# head-crop triplets under ./data/target/.
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from pathlib import Path
import os
import warnings
# Silence library deprecation chatter during the long batch run.
warnings.filterwarnings('ignore')
# Restrict CUDA to the first GPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# --- Extract frames from the target video --------------------------------
# Reads ./data/target/mv.mp4 and dumps every frame to ./data/target/images
# as zero-padded PNGs (00000.png, 00001.png, ...). Skipped when the image
# directory already holds >= 100 frames (assumed: extraction already done).
openpose_dir = Path('./src/PoseEstimation/')
save_dir = Path('./data/target/')
save_dir.mkdir(exist_ok=True)
img_dir = save_dir.joinpath('images')
img_dir.mkdir(exist_ok=True)
if len(os.listdir(str(img_dir))) < 100:
    cap = cv2.VideoCapture(str(save_dir.joinpath('mv.mp4')))
    i = 0
    try:
        while cap.isOpened():
            flag, frame = cap.read()
            if not flag:  # end of stream (or decode failure)
                break
            cv2.imwrite(str(img_dir.joinpath('{:05}.png'.format(i))), frame)
            if i % 100 == 0:
                print('Has generated %d pictures' % i)  # fixed typo "picetures"
            i += 1
    finally:
        cap.release()  # bug fix: capture handle was never released
# --- Load the pretrained OpenPose (rtpose VGG-19) pose estimator ---------
# The PoseEstimation package lives outside the import path, so extend
# sys.path before importing from it.
import sys
sys.path.append(str(openpose_dir))
sys.path.append('./src/utils')
# openpose
from network.rtpose_vgg import get_model
from evaluate.coco_eval import get_multiplier, get_outputs
# utils
from openpose_utils import remove_noise, get_pose
weight_name = './src/PoseEstimation/network/weight/pose_model.pth'
print('load model...')
model = get_model('vgg19')
model.load_state_dict(torch.load(weight_name))
# NOTE(review): requires a CUDA device; torch.load with no map_location
# would also need one if the checkpoint stores CUDA tensors.
model = torch.nn.DataParallel(model).cuda()
model.float()
model.eval()  # inference only: freezes dropout / batch-norm behavior
# (removed a stray no-op `pass` statement)
# Directory scaffolding for the pix2pix training set:
#   data/target/images          raw extracted frames
#   data/target/train/train_img  cropped+resized source images
#   data/target/train/train_label pose label images
#   data/target/train/head_img   head crops (for face enhancement)
save_dir = Path('./data/target/')
save_dir.mkdir(exist_ok=True)
img_dir = save_dir / 'images'
img_dir.mkdir(exist_ok=True)
# make label images for pix2pix
train_dir = save_dir / 'train'
train_img_dir = train_dir / 'train_img'
train_label_dir = train_dir / 'train_label'
train_head_dir = train_dir / 'head_img'
for directory in (train_dir, train_img_dir, train_label_dir, train_head_dir):
    directory.mkdir(exist_ok=True)
# --- Generate pix2pix training pairs (image / pose-label / head crop) ----
pose_cords = []  # per-frame head (x, y); saved to pose.npy at the end
for idx in tqdm(range(len(os.listdir(str(img_dir))))):
    img_path = img_dir.joinpath('{:05}.png'.format(idx))
    img = cv2.imread(str(img_path))
    # Center-crop to a square, then resize to the 512x512 network input.
    shape_dst = np.min(img.shape[:2])
    oh = (img.shape[0] - shape_dst) // 2
    ow = (img.shape[1] - shape_dst) // 2
    img = img[oh:oh + shape_dst, ow:ow + shape_dst]
    img = cv2.resize(img, (512, 512))
    multiplier = get_multiplier(img)
    with torch.no_grad():  # inference only — no autograd graph
        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
    # Denoise every joint channel; the last heatmap channel is background.
    r_heatmap = np.array([remove_noise(ht)
                          for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
    heatmap[:, :, :-1] = r_heatmap
    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    label, cord = get_pose(param, heatmap, paf)
    index = 13   # keypoint index used as the head anchor — TODO confirm mapping
    crop_size = 25
    try:
        head_cord = cord[index]
    except (IndexError, KeyError, TypeError):  # bug fix: was a bare except
        # No head point detected in this frame: reuse the previous frame's.
        # NOTE(review): still raises if the very first frame has no head.
        head_cord = pose_cords[-1]
    pose_cords.append(head_cord)
    head = img[int(head_cord[1] - crop_size): int(head_cord[1] + crop_size),
               int(head_cord[0] - crop_size): int(head_cord[0] + crop_size), :]
    plt.imshow(head)
    plt.savefig(str(train_head_dir.joinpath('pose_{}.jpg'.format(idx))))
    plt.clf()
    cv2.imwrite(str(train_img_dir.joinpath('{:05}.png'.format(idx))), img)
    cv2.imwrite(str(train_label_dir.joinpath('{:05}.png'.format(idx))), label)
# bug fix: np.int was removed in NumPy 1.24; the builtin int is equivalent.
pose_cords = np.array(pose_cords, dtype=int)
np.save(str(save_dir.joinpath('pose.npy')), pose_cords)
torch.cuda.empty_cache()  # release cached GPU memory after the batch run