Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Dec 4, 2024
1 parent 9750ee6 commit d3bf83a
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions rust/index/src/spann/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct KMeansAlgorithmInput<'referred_data> {
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct KMeansAlgorithmOutput {
cluster_centers: Vec<Vec<f32>>,
cluster_counts: Vec<usize>,
Expand Down Expand Up @@ -367,9 +368,9 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> {
let mut previous_counts = vec![];
for _ in 0..NUM_ITERS_FOR_MAIN_LOOP {
// Prepare for the next iteration.
let previous_centers = current_centers;
let previous_counts = current_counts;
self.input.indices.shuffle(&mut rand::thread_rng());
previous_centers = current_centers;
previous_counts = current_counts;
self.input.indices[self.input.first..self.input.last].shuffle(&mut rand::thread_rng());
let mut kmeans_assign = self.kmeansassign_for_main_loop(
&previous_centers,
&previous_counts,
Expand Down Expand Up @@ -429,17 +430,18 @@ mod tests {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0,
-1.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0,
11.0, 11.0, 12.0, 12.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
1,
10,
8,
1000,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
Expand All @@ -456,18 +458,18 @@ mod tests {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0,
11.0, 11.0,
-1.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 5.0, 5.0, 10.0, 10.0, 11.0, 10.0,
10.0, 11.0, 11.0, 11.0, 12.0, 12.0, 13.0, 13.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
let kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
1,
12,
9,
1000,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
Expand All @@ -479,7 +481,7 @@ mod tests {
assert_eq!(res.cluster_counts, vec![5, 4]);
assert_eq!(res.total_distance, 94.0);
assert_eq!(res.cluster_farthest_distance, vec![50.0, 11.0]);
assert_eq!(res.cluster_farthest_point_idx, vec![4, 8]);
assert_eq!(res.cluster_farthest_point_idx, vec![5, 9]);
assert_eq!(
res.cluster_new_centers,
vec![vec![7.0, 7.0], vec![42.0, 42.0]]
Expand All @@ -501,7 +503,7 @@ mod tests {
2,
0,
8,
1000,
4,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
Expand All @@ -526,4 +528,38 @@ mod tests {
assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]);
assert_eq!(res.cluster_labels, labels);
}

#[test]
fn test_kmeans_clustering() {
// 2D embeddings.
let dim = 2;
let embeddings = [
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0,
];
let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
let mut kmeans_algo = KMeansAlgorithm::new(
indices,
&embeddings,
dim,
2,
0,
8,
4,
chroma_distance::DistanceFunction::Euclidean,
100.0,
);
let res = kmeans_algo.cluster();
assert_eq!(res.cluster_counts, vec![4, 4]);
let mut labels = HashMap::new();
labels.insert(0, res.cluster_labels[&0]);
labels.insert(1, res.cluster_labels[&0]);
labels.insert(2, res.cluster_labels[&0]);
labels.insert(3, res.cluster_labels[&0]);
labels.insert(4, res.cluster_labels[&4]);
labels.insert(5, res.cluster_labels[&4]);
labels.insert(6, res.cluster_labels[&4]);
labels.insert(7, res.cluster_labels[&4]);
assert_eq!(res.cluster_labels, labels);
println!("{:?}", res);
}
}

0 comments on commit d3bf83a

Please sign in to comment.