diff --git a/.ameba.yml b/.ameba.yml
index 23ebf93..2815a78 100644
--- a/.ameba.yml
+++ b/.ameba.yml
@@ -1,16 +1,27 @@
 # This configuration file was generated by `ameba --gen-config`
-# on 2024-03-29 12:41:29 UTC using Ameba version 1.6.1.
+# on 2024-04-11 17:06:41 UTC using Ameba version 1.6.1.
 # The point is for the user to remove these configuration records
 # one by one as the reported problems are removed from the code base.
 
-# Problems found: 4
+# Problems found: 2
+# Run `ameba --only Lint/NotNil` for details
+Lint/NotNil:
+  Description: Identifies usage of `not_nil!` calls
+  Excluded:
+  - benchmark/benchmark.cr
+  Enabled: true
+  Severity: Warning
+
+# Problems found: 3
 # Run `ameba --only Naming/BlockParameterName` for details
 Naming/BlockParameterName:
   Description: Disallows non-descriptive block parameter names
   MinNameLength: 3
   AllowNamesEndingInNumbers: true
   Excluded:
+  - benchmark/benchmark.cr
   - src/kd_tree.cr
+  - spec/kd_tree_spec.cr
   AllowedNames:
   - _
   - e
diff --git a/README.md b/README.md
index 34c5ed2..347c101 100644
--- a/README.md
+++ b/README.md
@@ -70,12 +70,16 @@ kd_tree.nearest([1.0, 1.0], 2)
 Using a tree with 1 million points `[x, y] of Float64` on my i7-8550U CPU @ 1.80GHz:
 
 ```console
-build(init) ~10 seconds
-nearest point 00.000278579
-nearest point 5 00.000693038
-nearest point 50 00.007207470
-nearest point 255 00.134533902
-nearest point 999 08.510465131
+Benchmarking KD-Tree with 1 million points
+build(init): 3.41 seconds
+                        user     system      total        real
+nearest point   1  0.000019   0.000000   0.000019 (  0.000019)
+nearest point   5  0.000021   0.000000   0.000021 (  0.000021)
+nearest point  10  0.000025   0.000001   0.000026 (  0.000025)
+nearest point  50  0.000269   0.000002   0.000271 (  0.000272)
+nearest point 100  0.000809   0.000000   0.000809 (  0.000812)
+nearest point 255  0.005078   0.000000   0.005078 (  0.005087)
+nearest point 999  0.439598   0.000001   0.439599 (  0.440699)
 ```
 
 ## Contributing
diff --git a/benchmark/benchmark.cr b/benchmark/benchmark.cr
new file mode 100644
index 0000000..198be2e
--- /dev/null
+++ b/benchmark/benchmark.cr
@@ -0,0 +1,26 @@
+require "benchmark"
+require "../src/kd_tree"
+
+# Generate 1 million random points
+points = Array.new(1_000_000) { [rand * 100.0, rand * 100.0] }
+
+puts "Benchmarking KD-Tree with 1 million points"
+tree = nil
+build_time = Benchmark.measure {
+  tree = Kd::Tree(Float64).new(points)
+}
+
+puts "build(init): #{build_time.total.round(2)} seconds"
+
+tree = tree.not_nil!
+
+# Define a test point to find nearest neighbors for
+test_point = [50.0, 50.0]
+
+Benchmark.bm do |x|
+  [1, 5, 10, 50, 100, 255, 999].each do |n|
+    x.report("nearest point #{n.to_s.rjust(3, ' ')}") do
+      tree.not_nil!.nearest(test_point, n)
+    end
+  end
+end
diff --git a/spec/kd_tree_spec.cr b/spec/kd_tree_spec.cr
index c9c32e5..a6320b1 100644
--- a/spec/kd_tree_spec.cr
+++ b/spec/kd_tree_spec.cr
@@ -104,4 +104,35 @@ describe Kd::Tree do
       res.should eq([[2.0, 3.0, 0.0], [5.0, 4.0, 0.0]])
     end
   end
+
+  describe "#nearest" do
+    # https://github.com/geocrystal/kd_tree/issues/2
+    it "should equal naive implementation" do
+      ndim = 2
+      k = 3
+      distance = ->(m : Array(Float64), n : Array(Float64)) do
+        m.each_with_index.reduce(0) do |sum, (coord, index)|
+          sum += (coord - n[index]) ** 2
+          sum
+        end
+      end
+
+      10.times do
+        points = Array.new(10) do
+          Array.new(ndim) do
+            rand(-10.0..10.0)
+          end
+        end
+        kd_tree = Kd::Tree(Float64).new(points)
+        target = Array.new(ndim) do
+          rand(-11.0..11.0)
+        end
+        res = kd_tree.nearest(target, k)
+        sorted = points.sort_by do |p|
+          distance.call(p, target)
+        end
+        res.should eq sorted.first(k)
+      end
+    end
+  end
 end
diff --git a/src/kd_tree.cr b/src/kd_tree.cr
index ae71170..1f920ba 100644
--- a/src/kd_tree.cr
+++ b/src/kd_tree.cr
@@ -14,92 +14,80 @@ module Kd
       end
     end
 
-    getter root
-
-    @root : Node(T)?
+    getter root : Node(T)?
     @k : Int32
 
-    def initialize(points : Array(Array(T)), depth = 0)
-      @k = points.first.size # assumes all points have the same dimension
-      @root = build(points, depth)
+    # Builds the tree from *points*, which are assumed to all have the same
+    # dimension. An empty array yields an empty tree. NOTE: *points* is sorted
+    # in place while the tree is built.
+    def initialize(points : Array(Array(T)))
+      @k = points.first?.try(&.size) || 0
+      @root = build_tree(points, 0)
     end
 
-    def nearest(query : Array(T))
-      nearest(query, 1)
-    end
+    # Returns up to *n* points closest to *target*, ordered nearest first.
+    # Returns an empty array when *n* is less than 1.
+    def nearest(target : Array(T), n : Int32 = 1) : Array(Array(T))
+      return [] of Array(T) if n < 1
+
+      best_nodes = Array(Node(T)).new
 
-    def nearest(query : Array(T), n : Int32)
-      nearest!(@root, query, [] of Array(T), n)
+      find_n_nearest(@root, target, 0, best_nodes, n)
+
+      best_nodes.map(&.pivot)
     end
 
-    private def nearest!(
-      curr : Node?,
-      query : Array(T),
-      nearest : Array(Array(T)),
-      n = 1
+    # Walks the subtree rooted at *node*, keeping in *best_nodes* the (at most
+    # *n*) nodes closest to *target*, sorted by ascending squared distance.
+    private def find_n_nearest(
+      node : Node(T)?,
+      target : Array(T),
+      depth : Int32,
+      best_nodes : Array(Node(T)),
+      n : Int32
     )
-      return nearest if curr.nil?
-
-      # if the current node is better than any of the current nearest,
-      # then it becomes a current nearest
-      if nearest.size < n
-        nearest << curr.pivot
-      else
-        dist_curr_query = distance(curr.pivot, query)
-        ix = nearest.index { |b| dist_curr_query < distance(b, query) }
-        nearest[ix] = curr.pivot if ix
-      end
+      return unless node
 
-      # determine which branch contains the query along the split dimension
-      nearer, farther = if query[curr.split] <= curr.pivot[curr.split]
-                          [curr.left, curr.right]
-                        else
-                          [curr.right, curr.left]
-                        end
-
-      # search the nearer branch
-      nearest = nearest!(nearer, query, nearest, n)
-
-      # search the farther branch if the distance to the hyperplane is less
-      # than any nearest so far
-      dist_curr_query_spldim = distance(
-        [curr.pivot[curr.split]],
-        [query[curr.split]]
-      )
+      axis = depth % @k
+
+      next_node = target[axis] < node.pivot[axis] ? node.left : node.right
+      other_node = target[axis] < node.pivot[axis] ? node.right : node.left
+
+      find_n_nearest(next_node, target, depth + 1, best_nodes, n)
 
-      if nearest.find { |b| distance(b, query) >= dist_curr_query_spldim }
-        nearest = nearest!(farther, query, nearest, n)
+      if best_nodes.size < n || distance(target, node.pivot) < distance(target, best_nodes.last.pivot)
+        best_nodes << node
+        best_nodes.sort_by! { |nd| distance(target, nd.pivot) }
+        best_nodes.pop if best_nodes.size > n
       end
 
-      # else no need to search the entire farther branch i.e. prune!
-      nearest
+      # Only visit the far side when it could still hold a closer point
+      if other_node && (best_nodes.size < n || (target[axis] - node.pivot[axis]).abs**2 < distance(target, best_nodes.last.pivot))
+        find_n_nearest(other_node, target, depth + 1, best_nodes, n)
+      end
     end
 
     private def distance(m : Array(T), n : Array(T))
       # squared euclidean distance (to avoid expensive sqrt operation)
-      m.each_with_index.reduce(0) do |sum, (coord, index)|
-        sum += (coord - n[index]) ** 2
-        sum
+      m.each_with_index.sum do |coord, index|
+        (coord - n[index]) ** 2
       end
     end
 
-    private def build(points : Array(Array(T)), depth = 0)
+    private def build_tree(points : Array(Array(T)), depth : Int32) : Node(T)?
       return if points.empty?
 
-      # Select axis based on depth so that axis cycles through all valid values
       axis = depth % @k
-
-      # Sort point list and choose median as pivot element
-      points.sort! { |m, n| m[axis] <=> n[axis] }
-      pivot = points.size // 2
+      points.sort_by! { |point| point[axis] }
+      median = points.size // 2
 
       # Create node and construct subtrees
       Node(T).new(
-        points[pivot],
+        points[median],
         axis,
-        build(points[0...pivot], depth + 1),
-        build(points[pivot + 1..-1], depth + 1)
+        build_tree(points[0...median], depth + 1),
+        build_tree(points[median + 1..], depth + 1)
       )
     end
   end
 end