-
Notifications
You must be signed in to change notification settings - Fork 6
/
color_similarity_inference.py
55 lines (49 loc) · 2 KB
/
color_similarity_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import glob
import os
import subprocess

import h5py
import numpy as np

import knn_cnn_features
import kmeans_gpu
def extract_vid2img_from_vid(vid_path, clusters=10, frames_root='data/tmp'):
    """Build a colour "barcode" array for a video from its pre-extracted frames.

    Frames are read from ``<frames_root>/<video name>/*.jpg``, where
    ``<video name>`` is the file name of *vid_path* up to its first dot.
    Each frame goes through ``kmeans_gpu.run`` and the second value it
    returns (the per-frame bar image) is stacked along axis 0.

    Args:
        vid_path: Path of the video; only its base name is used.
        clusters: Number of k-means colour clusters per frame (default 10).
        frames_root: Directory holding one sub-directory of frames per video.

    Returns:
        The stacked bar images as a NumPy array, or ``None`` when no frames
        are found.
    """
    vid_name = os.path.basename(vid_path).split('.')[0]
    # Escape the video name so characters like '[' are matched literally.
    pattern = os.path.join(frames_root, glob.escape(vid_name), '*.jpg')
    # glob returns files in arbitrary (filesystem) order; sort so the bars are
    # stacked deterministically. NOTE(review): this is a lexicographic sort —
    # confirm frame file names are zero-padded so it matches temporal order.
    frame_paths = sorted(glob.glob(pattern))

    bars = [kmeans_gpu.run(frame, clusters=clusters)[1] for frame in frame_paths]
    if not bars:
        return None
    # Concatenate once rather than growing the array per frame (avoids O(n^2) copying).
    return np.concatenate(bars, axis=0)
def load_color_data_ucf(path='color_UCF_vid2img.h5'):
    """Load the precomputed UCF101 colour-barcode features and their labels.

    Args:
        path: HDF5 file containing the ``color_vid2imgs`` and ``vid_labels``
            datasets (default: ``color_UCF_vid2img.h5`` in the working directory).

    Returns:
        ``(color_vid2imgs, color_labels)``: the ``color_vid2imgs`` dataset as a
        float32 array, and the ``vid_labels`` dataset as a NumPy array of ``str``.
    """
    # The context manager closes the file even if a dataset is missing or a read fails.
    with h5py.File(path, 'r') as h5f:
        color_vid2imgs = h5f['color_vid2imgs'][()].astype('float32', copy=False)
        raw_labels = h5f['vid_labels'][()]
    # h5py may return labels as bytes; decode those and leave str values unchanged.
    color_labels = np.array([
        lbl.decode() if isinstance(lbl, bytes) else str(lbl) for lbl in raw_labels
    ])
    return color_vid2imgs, color_labels
def get_ordered_unique(listed):
    """Return the distinct items of *listed*, preserving first-occurrence order.

    Items must be hashable; any item equal to one already seen is dropped.

    >>> get_ordered_unique(['b', 'a', 'b', 'c', 'a'])
    ['b', 'a', 'c']
    """
    # Dict keys keep insertion order and ignore repeated keys, so they act as
    # an ordered set that retains the first occurrence of each item.
    first_seen = dict.fromkeys(listed)
    return list(first_seen)
def similar_color_ucf_video(vid_path, k=10, dist=False, verbose=False, newVid=False,
                            vid2imgs=None, labels=None):
    """Return the *k* database videos whose colour barcode is nearest to *vid_path*'s.

    Args:
        vid_path: Path of the query video. Its base name without extension is
            the label looked up in the database when *newVid* is false.
        k: Number of neighbours requested from ``knn_cnn_features.run_knn_features``.
        dist: If true, also return the neighbour distances.
        verbose: If true, print the neighbour labels.
        newVid: If true, build the query barcode from extracted frames via
            ``extract_vid2img_from_vid`` instead of looking it up in the database.
        vid2imgs: Feature database to search; defaults to the module-level
            ``color_vid2imgs``.
        labels: Labels aligned row-for-row with *vid2imgs*; defaults to the
            module-level ``color_labels``.

    Returns:
        A list of neighbour labels (``str``), or ``(distances, labels)`` when
        *dist* is true.

    Raises:
        FileNotFoundError: *newVid* is true but no extracted frames were found.
        ValueError: *newVid* is false and the video's label is not in *labels*.
    """
    if vid2imgs is None:
        vid2imgs = color_vid2imgs
    if labels is None:
        labels = color_labels
    # Elementwise comparison and fancy indexing below need an ndarray.
    labels = np.asarray(labels)

    if newVid:
        barcode = extract_vid2img_from_vid(vid_path)
        if barcode is None:
            raise FileNotFoundError(f'no extracted frames found for {vid_path!r}')
        # One query row: the flattened barcode.
        query = barcode.flatten()[np.newaxis, :].astype('float32')
    else:
        vid_name = os.path.splitext(os.path.basename(vid_path))[0]
        matches = np.flatnonzero(labels == vid_name)
        if matches.size == 0:
            # An empty query would otherwise fail opaquely inside the kNN search.
            raise ValueError(f'{vid_name!r} is not in the colour database')
        query = vid2imgs[matches]

    distances, feature_indices = knn_cnn_features.run_knn_features(
        vid2imgs, test_vectors=query, flat=True, k=k, dist=True)
    # Row 0 holds the neighbours of the first query vector.
    neighbour_labels = labels[feature_indices][0]
    if verbose:
        print(neighbour_labels)
    names = [str(lbl) for lbl in neighbour_labels]
    if dist:
        return list(distances[0]), names
    return names
# Load the colour-barcode database once, at import time. similar_color_ucf_video()
# reads these two module-level names when no database is passed to it.
# NOTE: importing this module therefore opens color_UCF_vid2img.h5 (a relative
# path, so it is resolved against the current working directory).
color_vid2imgs, color_labels = load_color_data_ucf()
# Benchmark kept for reference: average wall time of one query over 5 runs.
# import time
# start = time.time()
# for i in range(5):
# similar_color_ucf_video('data/UCF101/v_ApplyEyeMakeup_g01_c01.mp4', verbose=True, newVid=True)
# print((time.time()-start)/5)
# Recorded result (hardware not stated):
# 2.293074941635132 seconds (0.9 seconds if not new video)