Skip to content

Commit

Permalink
Merge pull request #5 from geocrystal/fix/nearest
Browse files Browse the repository at this point in the history
fix nearest points
  • Loading branch information
mamantoha authored Apr 11, 2024
2 parents 461289e + dc4c1f1 commit 84d485a
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 66 deletions.
15 changes: 13 additions & 2 deletions .ameba.yml
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
# This configuration file was generated by `ameba --gen-config`
# on 2024-03-29 12:41:29 UTC using Ameba version 1.6.1.
# on 2024-04-11 17:06:41 UTC using Ameba version 1.6.1.
# The point is for the user to remove these configuration records
# one by one as the reported problems are removed from the code base.

# Problems found: 4
# Problems found: 2
# Run `ameba --only Lint/NotNil` for details
Lint/NotNil:
Description: Identifies usage of `not_nil!` calls
Excluded:
- benchmark/benchmark.cr
Enabled: true
Severity: Warning

# Problems found: 3
# Run `ameba --only Naming/BlockParameterName` for details
Naming/BlockParameterName:
Description: Disallows non-descriptive block parameter names
MinNameLength: 3
AllowNamesEndingInNumbers: true
Excluded:
- benchmark/benchmark.cr
- src/kd_tree.cr
- spec/kd_tree_spec.cr
AllowedNames:
- _
- e
Expand Down
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,16 @@ kd_tree.nearest([1.0, 1.0], 2)
Using a tree with 1 million points `[x, y] of Float64` on my i7-8550U CPU @ 1.80GHz:

```console
build(init) ~10 seconds
nearest point 00.000278579
nearest point 5 00.000693038
nearest point 50 00.007207470
nearest point 255 00.134533902
nearest point 999 08.510465131
Benchmarking KD-Tree with 1 million points
build(init): 3.41 seconds
user system total real
nearest point 1 0.000019 0.000000 0.000019 ( 0.000019)
nearest point 5 0.000021 0.000000 0.000021 ( 0.000021)
nearest point 10 0.000025 0.000001 0.000026 ( 0.000025)
nearest point 50 0.000269 0.000002 0.000271 ( 0.000272)
nearest point 100 0.000809 0.000000 0.000809 ( 0.000812)
nearest point 255 0.005078 0.000000 0.005078 ( 0.005087)
nearest point 999 0.439598 0.000001 0.439599 ( 0.440699)
```

## Contributing
Expand Down
26 changes: 26 additions & 0 deletions benchmark/benchmark.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
require "benchmark"
require "../src/kd_tree"

# Benchmark script: builds a KD-tree from one million random 2-D points and
# times `nearest` queries for several neighbour counts.

# Generate 1 million random points
points = Array.new(1_000_000) { [rand * 100.0, rand * 100.0] }

puts "Benchmarking KD-Tree with 1 million points"

# `Benchmark.measure` returns the timings, not the block's value, so the tree
# is handed out of the block through a variable declared beforehand.
built_tree = nil
build_time = Benchmark.measure do
  built_tree = Kd::Tree(Float64).new(points)
end

# NOTE: `total` is the CPU time reported by `Benchmark.measure`, not the
# wall-clock time (`real`).
puts "build(init): #{build_time.total.round(2)} seconds"

# Single-assignment, non-nil binding: the report blocks below can use the tree
# directly, so no `not_nil!` call ends up inside the timed section.
tree = built_tree || raise "KD-Tree was not built"

# Define a test point to find nearest neighbors for
test_point = [50.0, 50.0]

Benchmark.bm do |job|
  [1, 5, 10, 50, 100, 255, 999].each do |count|
    job.report("nearest point #{count.to_s.rjust(3, ' ')}") do
      tree.nearest(test_point, count)
    end
  end
end
31 changes: 31 additions & 0 deletions spec/kd_tree_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,35 @@ describe Kd::Tree do
res.should eq([[2.0, 3.0, 0.0], [5.0, 4.0, 0.0]])
end
end

describe "#nearest" do
  # https://github.com/geocrystal/kd_tree/issues/2
  it "should equal naive implementation" do
    ndim = 2
    k = 3

    # Reference metric: squared euclidean distance. Ordering by it is the same
    # as ordering by the true distance, so no sqrt is needed.
    distance = ->(m : Array(Float64), n : Array(Float64)) do
      m.each_with_index.sum do |coord, index|
        (coord - n[index]) ** 2
      end
    end

    # Repeat with fresh random data so a single lucky layout cannot hide a bug.
    10.times do
      points = Array.new(10) do
        Array.new(ndim) do
          rand(-10.0..10.0)
        end
      end
      kd_tree = Kd::Tree(Float64).new(points)

      # The target range is slightly wider than the point range so queries
      # outside the bounding box of the points are exercised too.
      target = Array.new(ndim) do
        rand(-11.0..11.0)
      end

      res = kd_tree.nearest(target, k)

      # Brute force: the k points with the smallest distance to the target.
      expected = points.sort_by { |point| distance.call(point, target) }.first(k)

      # Check the count explicitly: a subset-only check would also pass when
      # `nearest` returns too few points (or none at all).
      res.size.should eq(k)

      # Compare as sorted lists so the assertion does not depend on the order
      # in which the tree returns the neighbours.
      res.sort.should eq(expected.sort)
    end
  end
end
end
96 changes: 38 additions & 58 deletions src/kd_tree.cr
Original file line number Diff line number Diff line change
Expand Up @@ -14,91 +14,71 @@ module Kd
end
end

getter root

@root : Node(T)?
getter root : Node(T)?
@k : Int32

def initialize(points : Array(Array(T)), depth = 0)
def initialize(points : Array(Array(T)))
@k = points.first.size # assumes all points have the same dimension
@root = build(points, depth)
@root = build_tree(points, 0)
end

def nearest(query : Array(T))
nearest(query, 1)
end
def nearest(target : Array(T), n : Int32 = 1) : Array(Array(T))
return [] of Array(T) if n < 1

best_nodes = Array(Node(T)).new

def nearest(query : Array(T), n : Int32)
nearest!(@root, query, [] of Array(T), n)
find_n_nearest(@root, target, 0, best_nodes, n)

best_nodes.map(&.pivot)
end

private def nearest!(
curr : Node?,
query : Array(T),
nearest : Array(Array(T)),
n = 1
private def find_n_nearest(
node : Node(T)?,
target : Array(T),
depth : Int32,
best_nodes : Array(Node(T)),
n : Int32
)
return nearest if curr.nil?

# if the current node is better than any of the current nearest,
# then it becomes a current nearest
if nearest.size < n
nearest << curr.pivot
else
dist_curr_query = distance(curr.pivot, query)
ix = nearest.index { |b| dist_curr_query < distance(b, query) }
nearest[ix] = curr.pivot if ix
end
return unless node

# determine which branch contains the query along the split dimension
nearer, farther = if query[curr.split] <= curr.pivot[curr.split]
[curr.left, curr.right]
else
[curr.right, curr.left]
end

# search the nearer branch
nearest = nearest!(nearer, query, nearest, n)

# search the farther branch if the distance to the hyperplane is less
# than any nearest so far
dist_curr_query_spldim = distance(
[curr.pivot[curr.split]],
[query[curr.split]]
)
axis = depth % @k

next_node = target[axis] < node.pivot[axis] ? node.left : node.right
other_node = target[axis] < node.pivot[axis] ? node.right : node.left

find_n_nearest(next_node, target, depth + 1, best_nodes, n)

if nearest.find { |b| distance(b, query) >= dist_curr_query_spldim }
nearest = nearest!(farther, query, nearest, n)
if best_nodes.size < n || distance(target, node.pivot) < distance(target, best_nodes.last.pivot)
best_nodes << node
best_nodes.sort_by! { |nd| distance(target, nd.pivot) }
best_nodes.pop if best_nodes.size > n
end

# else no need to search the entire farther branch i.e. prune!
nearest
if other_node && (best_nodes.size < n || (target[axis] - node.pivot[axis]).abs**2 < distance(target, best_nodes.last.pivot))
find_n_nearest(other_node, target, depth + 1, best_nodes, n)
end
end

private def distance(m : Array(T), n : Array(T))
# squared euclidean distance (to avoid expensive sqrt operation)
m.each_with_index.reduce(0) do |sum, (coord, index)|
sum += (coord - n[index]) ** 2
sum
m.each_with_index.sum do |coord, index|
(coord - n[index]) ** 2
end
end

private def build(points : Array(Array(T)), depth = 0)
private def build_tree(points : Array(Array(T)), depth : Int32) : Node(T)?
return if points.empty?

# Select axis based on depth so that axis cycles through all valid values
axis = depth % @k

# Sort point list and choose median as pivot element
points.sort! { |m, n| m[axis] <=> n[axis] }
pivot = points.size // 2
points.sort_by! { |point| point[axis] }
median = points.size // 2

# Create node and construct subtrees
Node(T).new(
points[pivot],
points[median],
axis,
build(points[0...pivot], depth + 1),
build(points[pivot + 1..-1], depth + 1)
build_tree(points[0...median], depth + 1),
build_tree(points[median + 1..], depth + 1)
)
end
end
Expand Down

0 comments on commit 84d485a

Please sign in to comment.