-
Notifications
You must be signed in to change notification settings - Fork 522
/
KdTreePointQuery.java
122 lines (108 loc) · 3.62 KB
/
KdTreePointQuery.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package structures;
import java.util.Random;
/**
 * Static 2-D k-d tree over integer points supporting nearest-neighbour queries by squared
 * Euclidean distance.
 *
 * <p>The tree is stored implicitly in the coordinate arrays. For an index range {@code [low, high)}
 * the element at {@code mid = (low + high) >>> 1} is the splitting node. Its left subtree occupies
 * {@code [low, mid)} and its right subtree occupies {@code (mid, high)}. Split axes alternate
 * between x and y, starting with x at the root.
 *
 * <p>The constructor keeps references to the caller's arrays (no copy) and reorders them in place.
 * Indices returned by {@link #findNearestNeighbour(int, int)} refer to those reordered arrays.
 *
 * <p>Queries write their result into instance fields ({@link #bestDist}, {@link #bestNode}), so a
 * single instance must not be queried from several threads at once.
 *
 * <p>Squared distances are computed in {@code long}. They are exact while
 * {@code dx*dx + dy*dy <= Long.MAX_VALUE}, which holds, for example, when every coordinate lies
 * strictly between {@code -2^30} and {@code 2^30}.
 */
public class KdTreePointQuery {
    /** x-coordinates of the points; reordered in place by the constructor. */
    int[] x;
    /** y-coordinates of the points; reordered in place by the constructor. */
    int[] y;

    /**
     * Builds the tree over the points {@code (x[i], y[i])}.
     *
     * @param x x-coordinates; stored by reference and reordered in place
     * @param y y-coordinates; stored by reference and reordered in place
     * @throws NullPointerException if either array is null
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public KdTreePointQuery(int[] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "x and y must have the same length: " + x.length + " != " + y.length);
        }
        this.x = x;
        this.y = y;
        build(0, x.length, true);
    }

    /**
     * Recursively arranges {@code [low, high)} into implicit k-d tree order. The median by the
     * current axis is placed at the middle index, then both halves are processed with the other
     * axis.
     *
     * @param divX {@code true} to split this range by x, {@code false} to split by y
     */
    void build(int low, int high, boolean divX) {
        if (high - low <= 1) {
            return;
        }
        int mid = (low + high) >>> 1;
        nth_element(low, high, mid, divX);
        build(low, mid, !divX);
        build(mid + 1, high, !divX);
    }

    /** Shared source of random pivot indices for quickselect (seeded with 1). */
    static final Random rnd = new Random(1);

    /**
     * Quickselect over {@code [low, high)} by the current axis. Afterwards index {@code n} holds the
     * element that would be there if the range were sorted. Elements before it are {@code <=} that
     * value and elements after it are {@code >=}.
     *
     * <p>See http://www.cplusplus.com/reference/algorithm/nth_element
     *
     * @param n target index; expected to satisfy {@code low <= n < high}
     * @param divX {@code true} to compare x-coordinates, {@code false} to compare y-coordinates
     */
    void nth_element(int low, int high, int n, boolean divX) {
        while (true) {
            int k = partition(low, high, low + rnd.nextInt(high - low), divX);
            if (n < k) {
                high = k;
            } else if (n > k) {
                low = k + 1;
            } else {
                return;
            }
        }
    }

    /**
     * Partitions {@code [fromInclusive, toExclusive)} around the value at {@code separatorIndex}.
     * The pivot value ends up at the returned index. Elements before that index are {@code <=} the
     * pivot and elements after it are {@code >=} the pivot. Equal keys are swapped across both
     * sides, which keeps ranges with many duplicates balanced.
     *
     * @return final index of the pivot; for a one-element range, that element's index
     */
    int partition(int fromInclusive, int toExclusive, int separatorIndex, boolean divX) {
        int i = fromInclusive;
        int j = toExclusive - 1;
        if (i >= j) {
            return j;
        }
        int separator = divX ? x[separatorIndex] : y[separatorIndex];
        // Park the pivot at the front; it is moved to its final slot after the scan.
        swap(i++, separatorIndex);
        while (i <= j) {
            while (i <= j && (divX ? x[i] : y[i]) < separator) {
                ++i;
            }
            while (i <= j && (divX ? x[j] : y[j]) > separator) {
                --j;
            }
            if (i >= j) {
                break;
            }
            swap(i++, j--);
        }
        // x/y[j] is <= separator here, so exchanging it with the parked pivot keeps the invariant.
        swap(j, fromInclusive);
        return j;
    }

    /** Swaps points {@code i} and {@code j} (both coordinates). */
    void swap(int i, int j) {
        int t = x[i];
        x[i] = x[j];
        x[j] = t;
        t = y[i];
        y[i] = y[j];
        y[j] = t;
    }

    /**
     * Squared distance from the most recent query point to {@link #bestNode}.
     * It is {@code Long.MAX_VALUE} if the tree is empty.
     */
    long bestDist;
    /** Index of the point found by the most recent query, or {@code -1} if the tree is empty. */
    int bestNode;

    /**
     * Finds the stored point closest to {@code (px, py)} by squared Euclidean distance.
     *
     * <p>After the call, {@link #bestDist} holds the squared distance to the returned point. On ties,
     * the first equally-close point visited by the search is kept.
     *
     * @return index of the nearest point in the (reordered) {@code x}/{@code y} arrays, or
     *     {@code -1} if the tree contains no points
     */
    public int findNearestNeighbour(int px, int py) {
        bestDist = Long.MAX_VALUE;
        // Reset so an empty tree reports "not found" instead of a stale or invalid index.
        bestNode = -1;
        findNearestNeighbour(0, x.length, px, py, true);
        return bestNode;
    }

    /**
     * Searches the subtree stored in {@code [low, high)}, whose root splits by x when
     * {@code divX} is true and by y otherwise. It updates {@link #bestDist} and {@link #bestNode}
     * when a closer point is found.
     */
    void findNearestNeighbour(int low, int high, int px, int py, boolean divX) {
        if (high <= low) {
            return;
        }
        int mid = (low + high) >>> 1;
        // Widen before subtracting: the difference of two ints can overflow int.
        long dx = (long) px - x[mid];
        long dy = (long) py - y[mid];
        long dist = dx * dx + dy * dy;
        if (bestDist > dist) {
            bestDist = dist;
            bestNode = mid;
        }
        long delta = divX ? dx : dy;
        long delta2 = delta * delta;
        // Visit the half on the query's side of the splitting line first. The other half can only
        // contain a closer point if the splitting line itself is closer than the best found so far.
        if (delta <= 0) {
            findNearestNeighbour(low, mid, px, py, !divX);
            if (delta2 < bestDist) {
                findNearestNeighbour(mid + 1, high, px, py, !divX);
            }
        } else {
            findNearestNeighbour(mid + 1, high, px, py, !divX);
            if (delta2 < bestDist) {
                findNearestNeighbour(low, mid, px, py, !divX);
            }
        }
    }

    /** Squared Euclidean distance between two points, with coordinate differences taken in long. */
    private static long squaredDistance(int ax, int ay, int bx, int by) {
        long dx = (long) ax - bx;
        long dy = (long) ay - by;
        return dx * dx + dy * dy;
    }

    /** Randomized self-test: compares query results against a brute-force scan on small inputs. */
    public static void main(String[] args) {
        for (int step = 0; step < 100_000; step++) {
            int qx = rnd.nextInt(100) - 50;
            int qy = rnd.nextInt(100) - 50;
            int n = rnd.nextInt(100) + 1;
            int[] x = new int[n];
            int[] y = new int[n];
            long minDist = Long.MAX_VALUE;
            for (int i = 0; i < n; i++) {
                x[i] = rnd.nextInt(100) - 50;
                y[i] = rnd.nextInt(100) - 50;
                minDist = Math.min(minDist, squaredDistance(x[i], y[i], qx, qy));
            }
            KdTreePointQuery kdTree = new KdTreePointQuery(x, y);
            int index = kdTree.findNearestNeighbour(qx, qy);
            if (minDist != kdTree.bestDist
                    || squaredDistance(x[index], y[index], qx, qy) != minDist) {
                throw new RuntimeException("step " + step + ": expected squared distance " + minDist
                        + ", got " + kdTree.bestDist + " at index " + index);
            }
        }
    }
}