-
Notifications
You must be signed in to change notification settings - Fork 15
/
sorter.py
121 lines (94 loc) · 3.74 KB
/
sorter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from pq import *
from multiprocessing import cpu_count
import numba as nb
import math
import tqdm
@nb.jit
def arg_sort(distances):
    """
    Return the indices of the smallest entries of 'distances', smallest first.

    At most 131072 indices are returned, and never more than len(distances) - 1:
    np.argpartition needs kth < len(distances), so for short inputs the single
    largest entry is always left out. parallel_sort sizes its rank rows with the
    same min(131072, N - 1) formula, so the two must stay in sync.

    :param distances: 1-D array of scores, smaller means closer
    :return: index array of length min(131072, len(distances) - 1)
    """
    k = min(131072, len(distances) - 1)
    # Partial partition first (linear time), then fully sort only the k survivors.
    candidates = np.argpartition(distances, k)[:k]
    order = np.argsort(distances[candidates])
    return candidates[order]
@nb.jit
def product_arg_sort(q, compressed):
    """
    Rank the rows of 'compressed' by inner product with 'q', largest product first.

    arg_sort orders ascending, so the query is negated to turn "largest inner
    product" into "smallest score".

    :param q: query vector
    :param compressed: items, one row per item
    :return: ranked item indices (see arg_sort for the length)
    """
    negated_query = -q
    scores = np.dot(compressed, negated_query)
    return arg_sort(scores)
@nb.jit
def angular_arg_sort(q, compressed, norms_sqr):
    """
    Rank the rows of 'compressed' by cosine similarity to 'q', most similar first.

    :param q: query vector
    :param compressed: items, one row per item
    :param norms_sqr: squared L2 norm of every row of 'compressed'
        (parallel_sort computes it as np.linalg.norm(compressed, axis=1) ** 2)
    :return: ranked item indices (see arg_sort for the length)
    """
    norm_q = np.linalg.norm(q)
    # Cosine similarity divides by the plain norms; norms_sqr holds squared norms,
    # so take the square root (dividing by the squared norm is not cosine).
    cosine = np.dot(compressed, q) / (norm_q * np.sqrt(norms_sqr))
    # arg_sort ranks ascending; negate so the most similar items come first,
    # mirroring the negation in product_arg_sort.
    return arg_sort(-cosine)
@nb.jit
def euclidean_arg_sort(q, compressed):
    """
    Rank the rows of 'compressed' by Euclidean distance to 'q', nearest first.

    :param q: query vector
    :param compressed: items, one row per item
    :return: ranked item indices (see arg_sort for the length)
    """
    residual = q - compressed
    return arg_sort(np.linalg.norm(residual, axis=1))
@nb.jit
def sign_arg_sort(q, compressed):
    """
    Rank the rows of 'compressed' by how many coordinates disagree in sign with 'q'.

    A coordinate "agrees" when both values are > 0 or both are <= 0; the score is
    the count of disagreeing coordinates, fewest first.

    :param q: query vector
    :param compressed: items, one row per item
    :return: ranked item indices (see arg_sort for the length)
    """
    n_items = len(compressed)
    q_positive = q > 0  # loop-invariant, computed once
    mismatches = np.empty(n_items, dtype=np.int32)
    for row in range(n_items):
        mismatches[row] = np.count_nonzero(q_positive != (compressed[row] > 0))
    return arg_sort(mismatches)
@nb.jit
def euclidean_norm_arg_sort(q, compressed, norms_sqr):
    """
    Rank the rows of 'compressed' by squared Euclidean distance to 'q', nearest first.

    Uses ||x - q||^2 = ||x||^2 - 2 x.q + ||q||^2. The ||q||^2 term is identical
    for every item, so it is dropped without changing the ranking.

    :param q: query vector
    :param compressed: items, one row per item
    :param norms_sqr: squared L2 norm of every row of 'compressed'
    :return: ranked item indices (see arg_sort for the length)
    """
    inner = np.dot(compressed, q)
    return arg_sort(norms_sqr - 2.0 * inner)
# Deliberately not decorated with @nb.jit: this body uses tqdm, string
# comparisons and an optional None argument, none of which numba's nopython
# mode can type, so jit could only ever run it in object mode (and numba >= 0.59
# no longer falls back to object mode). The per-query kernels it calls are jitted.
def parallel_sort(metric, compressed, Q, X, norms_sqr=None):
    """
    For each query in 'Q', rank the items in 'compressed' by their distance to it,
    where the distance is determined by 'metric'.

    :param metric: 'product' -> product_arg_sort, 'angular' -> angular_arg_sort,
        'euclid_norm' -> euclidean_norm_arg_sort, 'sign' -> sign_arg_sort;
        any other value falls back to euclidean_arg_sort
    :param compressed: compressed items, same dimension as the original data, shape (N, D)
    :param Q: queries, shape (len(Q), D)
    :param X: not used here; kept for interface compatibility
    :param norms_sqr: optional squared norms of the rows of 'compressed', used by the
        'angular' and 'euclid_norm' metrics; computed here when None
    :return: int32 array of shape (len(Q), min(131072, N - 1)); row i holds the item
        indices ranked for Q[i] (the cap matches arg_sort)
    """
    rank = np.empty((Q.shape[0], min(131072, compressed.shape[0] - 1)), dtype=np.int32)
    # Outside a jitted function nb.prange behaves like range; tqdm adds a progress bar.
    p_range = tqdm.tqdm(nb.prange(Q.shape[0]))
    if metric == 'product':
        for i in p_range:
            rank[i, :] = product_arg_sort(Q[i], compressed)
    elif metric == 'angular':
        if norms_sqr is None:
            norms_sqr = np.linalg.norm(compressed, axis=1) ** 2
        for i in p_range:
            rank[i, :] = angular_arg_sort(Q[i], compressed, norms_sqr)
    elif metric == 'euclid_norm':
        if norms_sqr is None:
            norms_sqr = np.linalg.norm(compressed, axis=1) ** 2
        for i in p_range:
            rank[i, :] = euclidean_norm_arg_sort(Q[i], compressed, norms_sqr)
    elif metric == 'sign':
        # sign_arg_sort existed but was never dispatched, so 'sign' used to
        # silently fall through to euclidean distance.
        for i in p_range:
            rank[i, :] = sign_arg_sort(Q[i], compressed)
    else:
        for i in p_range:
            rank[i, :] = euclidean_arg_sort(Q[i], compressed)
    return rank
@nb.jit
def true_positives(topK, Q, G, T):
    """
    Count, per query, the ground-truth neighbours found among the first T ranked items.

    np.intersect1d works on distinct values, so each neighbour is counted once.

    :param topK: ranked item indices per query (e.g. the output of parallel_sort)
    :param Q: queries; only len(Q) is used
    :param G: ground-truth neighbour indices per query
    :param T: number of leading ranked items to inspect per query
    :return: float array of length len(Q) holding the hit counts
    """
    n_queries = len(Q)
    hits = np.empty(shape=(n_queries))
    for qi in nb.prange(n_queries):
        hits[qi] = len(np.intersect1d(G[qi], topK[qi][:T]))
    return hits
class Sorter(object):
    """
    Rank compressed items for a set of queries and score the ranking against ground truth.

    :param compressed: compressed items, one row per item
    :param Q: queries, one row per query
    :param X: stored on the instance and forwarded to parallel_sort
    :param metric: distance name forwarded to parallel_sort
    :param norms_sqr: optional precomputed squared item norms forwarded to parallel_sort
    """

    def __init__(self, compressed, Q, X, metric, norms_sqr=None):
        self.Q = Q
        self.X = X
        # One row of ranked item indices per query.
        self.topK = parallel_sort(metric, compressed, Q, X, norms_sqr=norms_sqr)

    def recall(self, G, T):
        """
        Return (effective T, mean recall@T over all queries).

        The effective T is T capped at the number of ranked items per query.
        """
        effective_t = min(T, len(self.topK[0]))
        mean_recall = self.sum_recall(G, T) / len(self.Q)
        return effective_t, mean_recall

    def sum_recall(self, G, T):
        """
        Return the sum over queries of |G[i] & topK[i][:T]| / len(G[0]).

        len(G[0]) is used as K for every query, so G is presumably a rectangular
        array of ground-truth ids (one row per query) -- TODO confirm with callers.
        """
        n_queries = len(self.Q)
        assert n_queries == len(self.topK), "number of query not equals"
        assert len(self.topK) <= len(G), "number of queries should not exceed the number of queries in ground truth"
        hits = true_positives(self.topK, self.Q, G, T)
        return np.sum(hits) / len(G[0])  # TP / K
class BatchSorter(object):
    """
    Compute mean recall@T for several values of T, ranking the queries in batches.

    Each batch builds a Sorter, whose parallel_sort allocates a
    (batch rows, min(131072, N - 1)) int32 rank array. Batching presumably exists
    to bound that allocation for large query sets.

    :param compressed: compressed items, one row per item
    :param Q: queries, 2-D array (sliced as Q[a:b, :])
    :param X: forwarded to Sorter
    :param G: ground-truth neighbour ids, 2-D array aligned row-by-row with Q
    :param Ts: iterable of T values to evaluate
    :param metric: distance name forwarded to Sorter / parallel_sort
    :param batch_size: number of queries ranked per batch
    :param norms_sqr: optional precomputed squared item norms
    """

    def __init__(self, compressed, Q, X, G, Ts, metric, batch_size, norms_sqr=None):
        self.Q = Q
        self.X = X
        self.recalls = np.zeros(shape=(len(Ts)))
        # Squared item norms do not depend on the query batch. Compute them once
        # (same formula parallel_sort uses) instead of once per batch.
        if norms_sqr is None and metric in ('angular', 'euclid_norm'):
            norms_sqr = np.linalg.norm(compressed, axis=1) ** 2
        for start in range(0, len(Q), batch_size):
            stop = start + batch_size
            q = Q[start:stop, :]
            g = G[start:stop, :]
            sorter = Sorter(compressed, q, X, metric=metric, norms_sqr=norms_sqr)
            self.recalls += [sorter.sum_recall(g, t) for t in Ts]
        # sum_recall returns per-batch sums; turn the totals into means.
        self.recalls = self.recalls / len(self.Q)

    def recall(self):
        """Return the mean recall for each T in Ts, in the same order."""
        return self.recalls