diff --git a/rust/index/src/spann/utils.rs b/rust/index/src/spann/utils.rs index d7400566b65..bd2035e8bcc 100644 --- a/rust/index/src/spann/utils.rs +++ b/rust/index/src/spann/utils.rs @@ -24,6 +24,7 @@ struct KMeansAlgorithmInput<'referred_data> { } #[allow(dead_code)] +#[derive(Debug)] pub struct KMeansAlgorithmOutput { cluster_centers: Vec>, cluster_counts: Vec, @@ -367,9 +368,9 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> { let mut previous_counts = vec![]; for _ in 0..NUM_ITERS_FOR_MAIN_LOOP { // Prepare for the next iteration. - let previous_centers = current_centers; - let previous_counts = current_counts; - self.input.indices.shuffle(&mut rand::thread_rng()); + previous_centers = current_centers; + previous_counts = current_counts; + self.input.indices[self.input.first..self.input.last].shuffle(&mut rand::thread_rng()); let mut kmeans_assign = self.kmeansassign_for_main_loop( &previous_centers, &previous_counts, @@ -429,17 +430,18 @@ mod tests { // 2D embeddings. let dim = 2; let embeddings = [ - 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, + -1.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, + 11.0, 11.0, 12.0, 12.0, ]; - let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let kmeans_algo = KMeansAlgorithm::new( indices, &embeddings, dim, 2, - 0, + 1, + 10, 8, - 1000, chroma_distance::DistanceFunction::Euclidean, 100.0, ); @@ -456,18 +458,18 @@ mod tests { // 2D embeddings. let dim = 2; let embeddings = [ - 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, - 11.0, 11.0, + -1.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0, + 10.0, 11.0, 11.0, 11.0, 12.0, 12.0, 13.0, 13.0, ]; - let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let kmeans_algo = KMeansAlgorithm::new( indices, &embeddings, dim, 2, - 0, + 1, + 12, 9, - 1000, chroma_distance::DistanceFunction::Euclidean, 100.0, ); @@ -479,7 +481,7 @@ mod tests { assert_eq!(res.cluster_counts, vec![5, 4]); assert_eq!(res.total_distance, 94.0); assert_eq!(res.cluster_farthest_distance, vec![50.0, 11.0]); - assert_eq!(res.cluster_farthest_point_idx, vec![4, 8]); + assert_eq!(res.cluster_farthest_point_idx, vec![5, 9]); assert_eq!( res.cluster_new_centers, vec![vec![7.0, 7.0], vec![42.0, 42.0]] @@ -501,7 +503,7 @@ mod tests { 2, 0, 8, - 1000, + 4, chroma_distance::DistanceFunction::Euclidean, 100.0, ); @@ -526,4 +528,38 @@ mod tests { assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]); assert_eq!(res.cluster_labels, labels); } + + #[test] + fn test_kmeans_clustering() { + // 2D embeddings. + let dim = 2; + let embeddings = [ + 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, + ]; + let indices: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let mut kmeans_algo = KMeansAlgorithm::new( + indices, + &embeddings, + dim, + 2, + 0, + 8, + 4, + chroma_distance::DistanceFunction::Euclidean, + 100.0, + ); + let res = kmeans_algo.cluster(); + assert_eq!(res.cluster_counts, vec![4, 4]); + let mut labels = HashMap::new(); + labels.insert(0, res.cluster_labels[&0]); + labels.insert(1, res.cluster_labels[&0]); + labels.insert(2, res.cluster_labels[&0]); + labels.insert(3, res.cluster_labels[&0]); + labels.insert(4, res.cluster_labels[&4]); + labels.insert(5, res.cluster_labels[&4]); + labels.insert(6, res.cluster_labels[&4]); + labels.insert(7, res.cluster_labels[&4]); + assert_eq!(res.cluster_labels, labels); + println!("{:?}", res); + } }