Skip to content

Commit

Permalink
Merge pull request #38 from bjorn-martinsson/convex_hull
Browse files Browse the repository at this point in the history
Convex hull trick
  • Loading branch information
cheran-senthil authored Jun 12, 2020
2 parents 1676dff + 74c7e07 commit f4f7a3b
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyrival/data_structures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .BitArray import BitArray
from .CFraction import CFrac2Frac, CFraction
from .convex_hull_trick import convex_hull_trick, max_query
from .DisjointSetUnion import DisjointSetUnion, UnionFind
from .FenwickTree import FenwickTree
from .Fraction import Fraction, limit_denominator
Expand All @@ -22,6 +23,7 @@
"BitArray",
"CFrac2Frac",
"CFraction",
"convex_hull_trick",
"DisjointSetUnion",
"UnionFind",
"FenwickTree",
Expand All @@ -35,6 +37,7 @@
"LinkedList",
"Node",
"create",
"max_query",
"minimum",
"setter",
"RangeQuery",
Expand Down
51 changes: 51 additions & 0 deletions pyrival/data_structures/convex_hull_trick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import division

def convex_hull_trick(K, M, integer = True):
"""
Given lines on the form y = K[i] * x + M[i] this function returns intervals,
such that on each interval the convex hull is made up of a single line.
Input:
K: list of the slopes
M: list of the constants (value at x = 0)
interger: boolean for turning on / off integer mode. Integer mode is exact, it
works by effectively flooring the seperators of the intervals.
Return:
hull_i: on interval j, line i = hull_i[j] is >= all other lines
hull_x: interval j and j + 1 is separated by x = hull_x[j], (hull_x[j] is the last x in interval j)
"""
if integer:
intersect = lambda i,j: (M[j] - M[i]) // (K[i] - K[j])
else:
intersect = lambda i,j: (M[j] - M[i]) / (K[i] - K[j])

# assert len(K) == len(M)

hull_i = []
hull_x = []
order = sorted(range(len(K)), key = K.__getitem__)
for i in order:
while True:
if not hull_i:
hull_i.append(i)
break
elif K[hull_i[-1]] == K[i]:
if M[hull_i[-1]] >= M[i]:
break
hull_i.pop()
if hull_x: hull_x.pop()
else:
x = intersect(i, hull_i[-1])
if hull_x and x <= hull_x[-1]:
hull_i.pop()
hull_x.pop()
else:
hull_i.append(i)
hull_x.append(x)
break
return hull_i, hull_x

from bisect import bisect_left
def max_query(x, K, M, hull_i, hull_x):
""" Find maximum value at x in O(log n) time """
i = hull_i[bisect_left(hull_x, x)]
return K[i] * x + M[i]
50 changes: 50 additions & 0 deletions tests/data_structures/test_convex_hull_trick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from pyrival.data_structures import convex_hull_trick, max_query

def brute(K, M, X):
assert(len(K) == len(M))

out = []
for x in X:
y = K[0] * x + M[0]
for i in range(1, len(K)):
y = max(y, K[i] * x + M[i])
out.append(y)
return out


def test_convex_line_hull_integral(t = 50000):
import random
random.seed(1337)

for _ in range(t):
n = random.randint(1, 20)
K = [random.randint(-10, 10) for _ in range(n)]
M = [random.randint(-10, 10) for _ in range(n)]
X = list(range(-10, 10 + 1))

brute_ans = brute(K, M, X)

hull_i, hull_x = convex_hull_trick(K, M)
assert(len(hull_i) - 1 == len(hull_x))

ans = [max_query(x, K, M, hull_i, hull_x) for x in X]
assert(ans == brute_ans)

def test_convex_line_hull_float(t = 50000):
import random
random.seed(1337)

for _ in range(t):
n = random.randint(1, 20)
K = [random.randint(-10, 10) for _ in range(n)]
M = [random.randint(-10, 10) for _ in range(n)]
X = [random.uniform(-10, 10) for _ in range(21)]

brute_ans = brute(K, M, X)

hull_i, hull_x = convex_hull_trick(K, M, integer = False)
assert(len(hull_i) - 1 == len(hull_x))

ans = [max_query(x, K, M, hull_i, hull_x) for x in X]
assert(len(ans) == len(brute_ans))
assert(all(abs(x - y) <= 1e-9 for x,y in zip(ans, brute_ans)))

0 comments on commit f4f7a3b

Please sign in to comment.