-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from bjorn-martinsson/convex_hull
Convex hull trick
- Loading branch information
Showing
3 changed files
with
104 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from __future__ import division | ||
|
||
def convex_hull_trick(K, M, integer = True): | ||
""" | ||
Given lines on the form y = K[i] * x + M[i] this function returns intervals, | ||
such that on each interval the convex hull is made up of a single line. | ||
Input: | ||
K: list of the slopes | ||
M: list of the constants (value at x = 0) | ||
interger: boolean for turning on / off integer mode. Integer mode is exact, it | ||
works by effectively flooring the seperators of the intervals. | ||
Return: | ||
hull_i: on interval j, line i = hull_i[j] is >= all other lines | ||
hull_x: interval j and j + 1 is separated by x = hull_x[j], (hull_x[j] is the last x in interval j) | ||
""" | ||
if integer: | ||
intersect = lambda i,j: (M[j] - M[i]) // (K[i] - K[j]) | ||
else: | ||
intersect = lambda i,j: (M[j] - M[i]) / (K[i] - K[j]) | ||
|
||
# assert len(K) == len(M) | ||
|
||
hull_i = [] | ||
hull_x = [] | ||
order = sorted(range(len(K)), key = K.__getitem__) | ||
for i in order: | ||
while True: | ||
if not hull_i: | ||
hull_i.append(i) | ||
break | ||
elif K[hull_i[-1]] == K[i]: | ||
if M[hull_i[-1]] >= M[i]: | ||
break | ||
hull_i.pop() | ||
if hull_x: hull_x.pop() | ||
else: | ||
x = intersect(i, hull_i[-1]) | ||
if hull_x and x <= hull_x[-1]: | ||
hull_i.pop() | ||
hull_x.pop() | ||
else: | ||
hull_i.append(i) | ||
hull_x.append(x) | ||
break | ||
return hull_i, hull_x | ||
|
||
from bisect import bisect_left | ||
def max_query(x, K, M, hull_i, hull_x): | ||
""" Find maximum value at x in O(log n) time """ | ||
i = hull_i[bisect_left(hull_x, x)] | ||
return K[i] * x + M[i] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from pyrival.data_structures import convex_hull_trick, max_query | ||
|
||
def brute(K, M, X): | ||
assert(len(K) == len(M)) | ||
|
||
out = [] | ||
for x in X: | ||
y = K[0] * x + M[0] | ||
for i in range(1, len(K)): | ||
y = max(y, K[i] * x + M[i]) | ||
out.append(y) | ||
return out | ||
|
||
|
||
def test_convex_line_hull_integral(t = 50000): | ||
import random | ||
random.seed(1337) | ||
|
||
for _ in range(t): | ||
n = random.randint(1, 20) | ||
K = [random.randint(-10, 10) for _ in range(n)] | ||
M = [random.randint(-10, 10) for _ in range(n)] | ||
X = list(range(-10, 10 + 1)) | ||
|
||
brute_ans = brute(K, M, X) | ||
|
||
hull_i, hull_x = convex_hull_trick(K, M) | ||
assert(len(hull_i) - 1 == len(hull_x)) | ||
|
||
ans = [max_query(x, K, M, hull_i, hull_x) for x in X] | ||
assert(ans == brute_ans) | ||
|
||
def test_convex_line_hull_float(t = 50000): | ||
import random | ||
random.seed(1337) | ||
|
||
for _ in range(t): | ||
n = random.randint(1, 20) | ||
K = [random.randint(-10, 10) for _ in range(n)] | ||
M = [random.randint(-10, 10) for _ in range(n)] | ||
X = [random.uniform(-10, 10) for _ in range(21)] | ||
|
||
brute_ans = brute(K, M, X) | ||
|
||
hull_i, hull_x = convex_hull_trick(K, M, integer = False) | ||
assert(len(hull_i) - 1 == len(hull_x)) | ||
|
||
ans = [max_query(x, K, M, hull_i, hull_x) for x in X] | ||
assert(len(ans) == len(brute_ans)) | ||
assert(all(abs(x - y) <= 1e-9 for x,y in zip(ans, brute_ans))) |