From 8b900fde1f7124e24da8959463ac217bba810b81 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Mon, 18 Nov 2024 20:16:16 -0800 Subject: [PATCH] Add tests --- rust/index/src/spann/utils.rs | 123 +++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 2 deletions(-) diff --git a/rust/index/src/spann/utils.rs b/rust/index/src/spann/utils.rs index d92249affe1..d7400566b65 100644 --- a/rust/index/src/spann/utils.rs +++ b/rust/index/src/spann/utils.rs @@ -34,6 +34,7 @@ pub struct KMeansAlgorithm<'referred_data> { input: KMeansAlgorithmInput<'referred_data>, } +#[derive(Debug)] struct KMeansAssignForCenterInitOutput { cluster_counts: Vec, cluster_weighted_counts: Vec, @@ -42,6 +43,7 @@ struct KMeansAssignForCenterInitOutput { } #[allow(dead_code)] +#[derive(Debug)] struct KMeansAssignForMainLoopOutput { cluster_counts: Vec, cluster_farthest_point_idx: Vec, @@ -51,6 +53,7 @@ struct KMeansAssignForMainLoopOutput { } #[allow(dead_code)] +#[derive(Debug)] struct KMeansAssignFinishOutput { cluster_counts: Vec, cluster_nearest_point_idx: Vec, @@ -125,8 +128,11 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> { // Assign all points the nearest center. // TODO(Sanket): Normalize the points if needed for cosine similarity. // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. + // Actual value of previous_counts does not matter since lambda is 0. + let previous_counts = vec![0; self.input.k]; for idx in self.input.first..batch_end { - let (min_center, min_distance) = self.get_nearest_center(centers, idx, 0.0, &[]); + let (min_center, min_distance) = + self.get_nearest_center(centers, idx, 0.0, &previous_counts); total_distance += min_distance; cluster_counts[min_center as usize] += 1; cluster_weighted_counts[min_center as usize] += min_distance; @@ -203,8 +209,11 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> { // Assign all points the nearest center. // TODO(Sanket): Normalize the points if needed for cosine similarity. 
// TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. + // The actual value of previous_counts does not matter since lambda is 0. + let previous_counts = vec![0; self.input.k]; for idx in self.input.first..batch_end { - let (min_center, min_distance) = self.get_nearest_center(centers, idx, 0.0, &[]); + let (min_center, min_distance) = + self.get_nearest_center(centers, idx, 0.0, &previous_counts); cluster_counts[min_center as usize] += 1; let point_idx = self.input.indices[idx]; if min_distance <= cluster_nearest_distance[min_center as usize] { @@ -408,3 +417,113 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> { } } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::KMeansAlgorithm; + + #[test] + fn test_kmeans_assign_for_center_init() { + // 2D embeddings. + let dim = 2; + let embeddings = [ + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, + ]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let kmeans_algo = KMeansAlgorithm::new( + indices, + &embeddings, + dim, + 2, + 0, + 8, + 1000, + chroma_distance::DistanceFunction::Euclidean, + 100.0, + ); + let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]]; + let res = kmeans_algo.kmeansassign_for_centerinit(&centers); + assert_eq!(res.cluster_counts, vec![4, 4]); + assert_eq!(res.total_distance, 8.0); + assert_eq!(res.cluster_weighted_counts, vec![4.0, 4.0]); + assert_eq!(res.cluster_farthest_distance, vec![2.0, 2.0]); + } + + #[test] + fn test_kmeans_assign_for_main_loop() { + // 2D embeddings. 
+ let dim = 2; + let embeddings = [ + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, + 11.0, 11.0, + ]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; + let kmeans_algo = KMeansAlgorithm::new( + indices, + &embeddings, + dim, + 2, + 0, + 9, + 1000, + chroma_distance::DistanceFunction::Euclidean, + 100.0, + ); + let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]]; + // Penalize [10.0, 10.0] so that [5.0, 5.0] gets assigned to [0.0, 0.0] + let previous_counts = vec![0, 9]; + let lambda = 1.0; + let res = kmeans_algo.kmeansassign_for_main_loop(&centers, &previous_counts, lambda); + assert_eq!(res.cluster_counts, vec![5, 4]); + assert_eq!(res.total_distance, 94.0); + assert_eq!(res.cluster_farthest_distance, vec![50.0, 11.0]); + assert_eq!(res.cluster_farthest_point_idx, vec![4, 8]); + assert_eq!( + res.cluster_new_centers, + vec![vec![7.0, 7.0], vec![42.0, 42.0]] + ); + } + + #[test] + fn test_kmeans_assign_finish() { + // 2D embeddings. + let dim = 2; + let embeddings = [ + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, + ]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let kmeans_algo = KMeansAlgorithm::new( + indices, + &embeddings, + dim, + 2, + 0, + 8, + 1000, + chroma_distance::DistanceFunction::Euclidean, + 100.0, + ); + let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]]; + let mut res = kmeans_algo.kmeansassign_finish(&centers, false); + assert_eq!(res.cluster_counts, vec![4, 4]); + assert_eq!(res.cluster_nearest_distance, vec![0.0, 0.0]); + assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]); + assert_eq!(res.cluster_labels.len(), 0); + res = kmeans_algo.kmeansassign_finish(&centers, true); + let mut labels = HashMap::new(); + labels.insert(0, 0); + labels.insert(1, 0); + labels.insert(2, 0); + labels.insert(3, 0); + labels.insert(4, 1); + labels.insert(5, 1); + labels.insert(6, 1); + labels.insert(7, 1); + assert_eq!(res.cluster_counts, vec![4, 4]); + 
assert_eq!(res.cluster_nearest_distance, vec![0.0, 0.0]); + assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]); + assert_eq!(res.cluster_labels, labels); + } +}