Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Dec 4, 2024
1 parent d7395f6 commit 6262faa
Showing 1 changed file with 121 additions and 2 deletions.
123 changes: 121 additions & 2 deletions rust/index/src/spann/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct KMeansAlgorithm<'referred_data> {
input: KMeansAlgorithmInput<'referred_data>,
}

#[derive(Debug)]
struct KMeansAssignForCenterInitOutput {
cluster_counts: Vec<usize>,
cluster_weighted_counts: Vec<f32>,
Expand All @@ -42,6 +43,7 @@ struct KMeansAssignForCenterInitOutput {
}

#[allow(dead_code)]
#[derive(Debug)]
struct KMeansAssignForMainLoopOutput {
cluster_counts: Vec<usize>,
cluster_farthest_point_idx: Vec<i32>,
Expand All @@ -51,6 +53,7 @@ struct KMeansAssignForMainLoopOutput {
}

#[allow(dead_code)]
#[derive(Debug)]
struct KMeansAssignFinishOutput {
cluster_counts: Vec<usize>,
cluster_nearest_point_idx: Vec<i32>,
Expand Down Expand Up @@ -125,8 +128,11 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> {
// Assign all points the nearest center.
// TODO(Sanket): Normalize the points if needed for cosine similarity.
// TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc.
// Actual value of previous_counts does not matter since lambda is 0.
let previous_counts = vec![0; self.input.k];
for idx in self.input.first..batch_end {
let (min_center, min_distance) = self.get_nearest_center(centers, idx, 0.0, &[]);
let (min_center, min_distance) =
self.get_nearest_center(centers, idx, 0.0, &previous_counts);
total_distance += min_distance;
cluster_counts[min_center as usize] += 1;
cluster_weighted_counts[min_center as usize] += min_distance;
Expand Down Expand Up @@ -203,8 +209,11 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> {
// Assign all points the nearest center.
// TODO(Sanket): Normalize the points if needed for cosine similarity.
// TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc.
// The actual value of previous_counts does not matter since lambda is 0.
let previous_counts = vec![0; self.input.k];
for idx in self.input.first..batch_end {
let (min_center, min_distance) = self.get_nearest_center(centers, idx, 0.0, &[]);
let (min_center, min_distance) =
self.get_nearest_center(centers, idx, 0.0, &previous_counts);
cluster_counts[min_center as usize] += 1;
let point_idx = self.input.indices[idx];
if min_distance <= cluster_nearest_distance[min_center as usize] {
Expand Down Expand Up @@ -408,3 +417,113 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> {
}
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use super::KMeansAlgorithm;

#[test]
fn test_kmeans_assign_for_center_init() {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
let kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
8,
1000,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let res = kmeans_algo.kmeansassign_for_centerinit(&centers);
assert_eq!(res.cluster_counts, vec![4, 4]);
assert_eq!(res.total_distance, 8.0);
assert_eq!(res.cluster_weighted_counts, vec![4.0, 4.0]);
assert_eq!(res.cluster_farthest_distance, vec![2.0, 2.0]);
}

#[test]
fn test_kmeans_assign_for_main_loop() {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0,
11.0, 11.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
let kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
9,
1000,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
// Penalize [10.0, 10.0] so that [5.0, 5.0] gets assigned to [0.0, 0.0]
let previous_counts = vec![0, 9];
let lambda = 1.0;
let res = kmeans_algo.kmeansassign_for_main_loop(&centers, &previous_counts, lambda);
assert_eq!(res.cluster_counts, vec![5, 4]);
assert_eq!(res.total_distance, 94.0);
assert_eq!(res.cluster_farthest_distance, vec![50.0, 11.0]);
assert_eq!(res.cluster_farthest_point_idx, vec![4, 8]);
assert_eq!(
res.cluster_new_centers,
vec![vec![7.0, 7.0], vec![42.0, 42.0]]
);
}

#[test]
fn test_kmeans_assign_finish() {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
let kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
8,
1000,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let mut res = kmeans_algo.kmeansassign_finish(&centers, false);
assert_eq!(res.cluster_counts, vec![4, 4]);
assert_eq!(res.cluster_nearest_distance, vec![0.0, 0.0]);
assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]);
assert_eq!(res.cluster_labels.len(), 0);
res = kmeans_algo.kmeansassign_finish(&centers, true);
let mut labels = HashMap::new();
labels.insert(0, 0);
labels.insert(1, 0);
labels.insert(2, 0);
labels.insert(3, 0);
labels.insert(4, 1);
labels.insert(5, 1);
labels.insert(6, 1);
labels.insert(7, 1);
assert_eq!(res.cluster_counts, vec![4, 4]);
assert_eq!(res.cluster_nearest_distance, vec![0.0, 0.0]);
assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]);
assert_eq!(res.cluster_labels, labels);
}
}

0 comments on commit 6262faa

Please sign in to comment.