# make_source.py (from yanx27/EverybodyDanceNow_reproduce_pytorch)
'''Extract frames from the source video'''
import os
from pathlib import Path

import cv2
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = "0"
torch.multiprocessing.set_sharing_strategy('file_system')
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)
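# Note: CUDA_VISIBLE_DEVICES takes effect here because PyTorch initializes CUDA
# lazily; the variable is read on the first CUDA call (torch.cuda.set_device
# above), not at import time. The 'file_system' sharing strategy avoids
# "too many open files" errors when tensors are shared across worker processes.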
save_dir = Path('./data/source/')
save_dir.mkdir(parents=True, exist_ok=True)  # parents=True also creates ./data if missing
img_dir = save_dir.joinpath('images')
img_dir.mkdir(exist_ok=True)
# Skip extraction if frames from a previous run are already on disk.
if len(os.listdir(str(img_dir))) < 100:
    cap = cv2.VideoCapture(str(save_dir.joinpath('mv.mp4')))
    i = 0
    while cap.isOpened():
        flag, frame = cap.read()
        if not flag or i >= 1000:  # stop at the end of the video or after 1000 frames
            break
        cv2.imwrite(str(img_dir.joinpath('{:05}.png'.format(i))), frame)
        if i % 100 == 0:
            print('Extracted %d frames' % i)
        i += 1
    cap.release()
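# Note (assumption, not part of the original script): the source clip is
# expected at ./data/source/mv.mp4. One way to prepare it, e.g. trimming a
# longer video to its first 60 seconds with ffmpeg:
#
#   ffmpeg -i input.mp4 -t 60 ./data/source/mv.mp4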
'''Pose estimation (OpenPose)'''
import sys

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Make the bundled OpenPose implementation and helpers importable.
openpose_dir = Path('./src/PoseEstimation/')
sys.path.append(str(openpose_dir))
sys.path.append('./src/utils')

# openpose
from evaluate.coco_eval import get_multiplier, get_outputs
from network.rtpose_vgg import get_model

# utils
from openpose_utils import remove_noise, get_pose

# Load the pretrained VGG19-based pose model and move it to the GPU.
weight_name = './src/PoseEstimation/network/weight/pose_model.pth'
model = get_model('vgg19')
model.load_state_dict(torch.load(weight_name))
model = torch.nn.DataParallel(model).cuda()
model.float()
model.eval()
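# Sketch (assumption, for debugging only): to load the checkpoint on a CPU-only
# machine, remap the storages and skip the DataParallel/.cuda() wrapping above:
#
#   state_dict = torch.load(weight_name, map_location='cpu')
#   model.load_state_dict(state_dict)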
'''Make label images for pix2pix'''
test_img_dir = save_dir.joinpath('test_img')
test_img_dir.mkdir(exist_ok=True)
test_label_dir = save_dir.joinpath('test_label_ori')
test_label_dir.mkdir(exist_ok=True)
test_head_dir = save_dir.joinpath('test_head_ori')
test_head_dir.mkdir(exist_ok=True)

pose_cords = []  # one head coordinate per frame, saved to pose_source.npy below
for idx in tqdm(range(len(os.listdir(str(img_dir))))):
    img_path = img_dir.joinpath('{:05}.png'.format(idx))
    img = cv2.imread(str(img_path))

    # Center-crop to a square, then resize to the 512x512 input used below.
    shape_dst = np.min(img.shape[:2])
    oh = (img.shape[0] - shape_dst) // 2
    ow = (img.shape[1] - shape_dst) // 2
    img = img[oh:oh + shape_dst, ow:ow + shape_dst]
    img = cv2.resize(img, (512, 512))

    multiplier = get_multiplier(img)
    with torch.no_grad():
        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')

    # Denoise every joint heatmap channel; the last (background) channel is kept as-is.
    r_heatmap = np.array([remove_noise(ht)
                          for ht in heatmap.transpose(2, 0, 1)[:-1]]) \
        .transpose(1, 2, 0)
    heatmap[:, :, :-1] = r_heatmap

    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    label, cord = get_pose(param, heatmap, paf)

    index = 13      # keypoint index used as the head position
    crop_size = 25  # half-size of the head crop, in pixels
    try:
        head_cord = cord[index]
    except IndexError:
        head_cord = pose_cords[-1]  # no head keypoint in this frame: reuse the previous one
    pose_cords.append(head_cord)

    # Save a debug crop of the head region (img is BGR, so the preview's colors are swapped).
    head = img[int(head_cord[1] - crop_size): int(head_cord[1] + crop_size),
               int(head_cord[0] - crop_size): int(head_cord[0] + crop_size), :]
    plt.imshow(head)
    plt.savefig(str(test_head_dir.joinpath('pose_{}.jpg'.format(idx))))
    plt.clf()

    cv2.imwrite(str(test_img_dir.joinpath('{:05}.png'.format(idx))), img)
    cv2.imwrite(str(test_label_dir.joinpath('{:05}.png'.format(idx))), label)

    # Periodically checkpoint the head coordinates so a crash doesn't lose them.
    if idx % 100 == 0 and idx != 0:
        pose_cords_arr = np.array(pose_cords, dtype=int)
        np.save(str(save_dir.joinpath('pose_source.npy')), pose_cords_arr)

pose_cords_arr = np.array(pose_cords, dtype=int)
np.save(str(save_dir.joinpath('pose_source.npy')), pose_cords_arr)
torch.cuda.empty_cache()
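# --- Optional sanity check (a sketch, not part of the original pipeline) ---
# Confirms that every cropped frame got a matching label image and that the
# saved head coordinates line up with the frame count.
coords = np.load(str(save_dir.joinpath('pose_source.npy')))
n_imgs = len(os.listdir(str(test_img_dir)))
n_labels = len(os.listdir(str(test_label_dir)))
print('frames: %d, labels: %d, head coords: %d' % (n_imgs, n_labels, len(coords)))
assert n_imgs == n_labels == len(coords), 'output counts should match'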