From 3c9f1f1278b69394273c3aeb6f9db7fbfc78350d Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Thu, 5 Dec 2024 00:01:14 -0800 Subject: [PATCH] Review comments --- rust/index/src/spann/utils.rs | 755 +++++++++++++++++++--------------- 1 file changed, 416 insertions(+), 339 deletions(-) diff --git a/rust/index/src/spann/utils.rs b/rust/index/src/spann/utils.rs index bd2035e8bcc..d00e89ddb35 100644 --- a/rust/index/src/spann/utils.rs +++ b/rust/index/src/spann/utils.rs @@ -1,16 +1,36 @@ use std::{cmp::min, collections::HashMap}; use chroma_distance::DistanceFunction; +use chroma_error::{ChromaError, ErrorCodes}; use rand::{seq::SliceRandom, Rng}; +use thiserror::Error; // TODO(Sanket): I don't understand why the reference implementation defined // max_distance this way. +// TODO(Sanket): Make these configurable. const MAX_DISTANCE: f32 = f32::MAX / 10.0; const NUM_ITERS_FOR_CENTER_INIT: usize = 3; const NUM_ITERS_FOR_MAIN_LOOP: usize = 100; const NUM_ITERS_NO_IMPROVEMENT: usize = 5; -struct KMeansAlgorithmInput<'referred_data> { +/// The input for the kmeans algorithm. +/// - indices: The indices of the embeddings that we want to cluster. +/// - embeddings: The entire list of embeddings. We only cluster a subset from this +/// list based on the indices. This is a flattened list, so, for example, +/// the first embedding is stored at 0..embedding_dimension, the second +/// at embedding_dimension..2*embedding_dimension, and so on. +/// - embedding_dimension: The dimension of the embeddings. +/// - k: The number of clusters. +/// - first: The start index in the indices array from where we start clustering. +/// - last: The end index in the indices array up to which we cluster. This index is excluded. +/// - num_samples: Each run of kmeans clusters only num_samples points. This is +/// done to speed up clustering without losing much accuracy. In the end, we cluster all +/// the points. +/// - distance_function: The distance function to use for clustering. +/// - initial_lambda: Lambda is a parameter used to penalize large clusters. This is used +/// to generate balanced clusters. The algorithm generates a lambda on the fly using this +/// initial_lambda as the starting point. +pub struct KMeansAlgorithmInput<'referred_data> { indices: Vec<u32>, embeddings: &'referred_data [f32], embedding_dimension: usize, @@ -23,6 +43,39 @@ struct KMeansAlgorithmInput<'referred_data> { initial_lambda: f32, } +impl<'referred_data> KMeansAlgorithmInput<'referred_data> { + #[allow(clippy::too_many_arguments)] + pub fn new( + indices: Vec<u32>, + embeddings: &'referred_data [f32], + embedding_dimension: usize, + k: usize, + first: usize, + last: usize, + num_samples: usize, + distance_function: DistanceFunction, + initial_lambda: f32, + ) -> Self { + KMeansAlgorithmInput { + indices, + embeddings, + embedding_dimension, + k, + first, + last, + num_samples, + distance_function, + initial_lambda, + } + } +} + +/// The output from kmeans. +/// - cluster_centers: The embeddings of the centers of the clusters. +/// - cluster_counts: The number of points in each cluster. +/// - cluster_labels: The mapping of each point to the cluster it belongs to. Clusters are +/// identified by unsigned integers starting from 0. These ids are also indexes in the +/// cluster_centers and cluster_counts arrays.
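To make the flattened embedding layout described above concrete, here is a small standalone sketch (not part of this patch); the helper name embedding_of and the literal values are purely illustrative.

fn embedding_of(embeddings: &[f32], point_id: usize, dim: usize) -> &[f32] {
    // Point `point_id` occupies the slice [point_id * dim, (point_id + 1) * dim).
    &embeddings[point_id * dim..(point_id + 1) * dim]
}

fn flattened_layout_example() {
    let dim = 2;
    // Three 2-dimensional points stored back to back: (0, 0), (1, 1), (2, 2).
    let embeddings = [0.0_f32, 0.0, 1.0, 1.0, 2.0, 2.0];
    assert_eq!(embedding_of(&embeddings, 1, dim), &[1.0, 1.0]);
}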
#[allow(dead_code)] #[derive(Debug)] pub struct KMeansAlgorithmOutput { @@ -31,10 +84,6 @@ pub struct KMeansAlgorithmOutput { cluster_labels: HashMap<u32, i32>, } -pub struct KMeansAlgorithm<'referred_data> { - input: KMeansAlgorithmInput<'referred_data>, -} - #[derive(Debug)] struct KMeansAssignForCenterInitOutput { cluster_counts: Vec<usize>, @@ -62,368 +111,395 @@ struct KMeansAssignFinishOutput { cluster_labels: HashMap<u32, i32>, } -impl<'referred_data> KMeansAlgorithm<'referred_data> { - #[allow(clippy::too_many_arguments)] - pub fn new( - indices: Vec<u32>, - embeddings: &'referred_data [f32], - embedding_dimension: usize, - k: usize, - first: usize, - last: usize, - num_samples: usize, - distance_function: DistanceFunction, - initial_lambda: f32, - ) -> Self { - KMeansAlgorithm { - input: KMeansAlgorithmInput { - indices, - embeddings, - embedding_dimension, - k, - first, - last, - num_samples, - distance_function, - initial_lambda, - }, +#[derive(Error, Debug)] +pub enum KMeansError { + #[error("There should be at least one cluster")] + MaxClusterNotFound, + #[error("Could not assign a point to a center")] + PointAssignmentFailed, +} + +impl ChromaError for KMeansError { + fn code(&self) -> ErrorCodes { + match self { + Self::MaxClusterNotFound => ErrorCodes::Internal, + Self::PointAssignmentFailed => ErrorCodes::Internal, } } +} - fn get_nearest_center( - &self, - centers: &[Vec<f32>], - idx: usize, - lambda: f32, - previous_counts: &[usize], - ) -> (i32, f32) { - let point_idx = self.input.indices[idx]; - let dim = self.input.embedding_dimension; - let start_idx = point_idx * dim as u32; - let end_idx = (point_idx + 1) * dim as u32; - let mut min_distance = MAX_DISTANCE; - let mut min_center: i32 = -1; - for center_idx in 0..self.input.k { - let distance = self.input.distance_function.distance( - &self.input.embeddings[start_idx as usize..end_idx as usize], - &centers[center_idx], - ) + lambda * previous_counts[center_idx] as f32; - if distance > -MAX_DISTANCE && distance < min_distance { - min_distance = distance; - min_center = center_idx as i32; - } +// For a given point, get the nearest center and the distance to it. +// lambda is a parameter used to penalize large clusters. +// previous_counts is the number of points in each cluster in the previous iteration. +fn get_nearest_center( + input: &KMeansAlgorithmInput, + centers: &[Vec<f32>], + idx: usize, + lambda: f32, + previous_counts: &[usize], +) -> Result<(i32, f32), KMeansError> { + let point_idx = input.indices[idx]; + let dim = input.embedding_dimension; + let start_idx = point_idx as usize * dim; + let end_idx = (point_idx + 1) as usize * dim; + let mut min_distance = MAX_DISTANCE; + let mut min_center: i32 = -1; + for center_idx in 0..input.k { + let distance = input + .distance_function + .distance(&input.embeddings[start_idx..end_idx], &centers[center_idx]) + + lambda * previous_counts[center_idx] as f32; + if distance > -MAX_DISTANCE && distance < min_distance { + min_distance = distance; + min_center = center_idx as i32; } - if min_center == -1 { - panic!("Invariant violation. Every point should be assigned to a center."); - } - (min_center, min_distance) } + if min_center == -1 { + return Err(KMeansError::PointAssignmentFailed); + } + Ok((min_center, min_distance)) +} - fn kmeansassign_for_centerinit(&self, centers: &[Vec<f32>]) -> KMeansAssignForCenterInitOutput { - // Assign to only a sample.
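The assignment rule in get_nearest_center above is cost = distance(point, center) + lambda * previous_counts[center]. Below is a small standalone sketch of that rule (not part of this patch); the function name penalized_nearest is hypothetical, squared Euclidean distance is hard-coded purely for illustration (the real code delegates to chroma_distance::DistanceFunction), and the MAX_DISTANCE guard and error handling are omitted.

fn penalized_nearest(
    point: &[f32],
    centers: &[Vec<f32>],
    previous_counts: &[usize],
    lambda: f32,
) -> Option<(usize, f32)> {
    let mut best: Option<(usize, f32)> = None;
    for (center_idx, center) in centers.iter().enumerate() {
        // Squared Euclidean distance, for illustration only.
        let distance: f32 = point
            .iter()
            .zip(center.iter())
            .map(|(p, c)| (p - c) * (p - c))
            .sum();
        // Large clusters from the previous iteration pay a penalty proportional to their size.
        let cost = distance + lambda * previous_counts[center_idx] as f32;
        if best.map_or(true, |(_, best_cost)| cost < best_cost) {
            best = Some((center_idx, cost));
        }
    }
    best
}

fn penalized_assignment_example() {
    let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
    // The point sits exactly between the two centers, but cluster 1 held 9 points
    // in the previous iteration, so with lambda > 0 the point is pulled to cluster 0.
    let assignment = penalized_nearest(&[5.0, 5.0], &centers, &[0, 9], 1.0);
    assert_eq!(assignment, Some((0, 50.0)));
}

This mirrors the test_kmeans_assign_for_main_loop case further down, where previous_counts of [0, 9] and a lambda of 1.0 steer the midpoint [5.0, 5.0] into the first cluster.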
- let batch_end = min(self.input.first + self.input.num_samples, self.input.last); - let mut cluster_counts = vec![0; self.input.k]; - let mut cluster_weighted_counts = vec![0.0; self.input.k]; - let mut cluster_farthest_distance = vec![-MAX_DISTANCE; self.input.k]; - let mut total_distance = 0.0; - // Assign all points the nearest center. - // TODO(Sanket): Normalize the points if needed for cosine similarity. - // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. - // Actual value of previous_counts does not matter since lambda is 0. - let previous_counts = vec![0; self.input.k]; - for idx in self.input.first..batch_end { - let (min_center, min_distance) = - self.get_nearest_center(centers, idx, 0.0, &previous_counts); - total_distance += min_distance; - cluster_counts[min_center as usize] += 1; - cluster_weighted_counts[min_center as usize] += min_distance; - if min_distance > cluster_farthest_distance[min_center as usize] { - cluster_farthest_distance[min_center as usize] = min_distance; - } - } - KMeansAssignForCenterInitOutput { - cluster_counts, - cluster_weighted_counts, - cluster_farthest_distance, - total_distance, +// Center init is a process of choosing the initial centers for the kmeans algorithm. +// This function assigns all the points to their respective nearest centers without any regularization +// i.e. without penalizing large clusters using a regularization parameter like lambda. +fn kmeansassign_for_centerinit( + input: &KMeansAlgorithmInput, + centers: &[Vec], +) -> Result { + // Assign to only a sample. + let batch_end = min(input.first + input.num_samples, input.last); + // Number of points in each cluster. + let mut cluster_counts = vec![0; input.k]; + // Weighted counts are the sum of distances of all points assigned to a cluster. + let mut cluster_weighted_counts = vec![0.0; input.k]; + // Distance of the farthest point from the cluster center. + let mut cluster_farthest_distance = vec![-MAX_DISTANCE; input.k]; + // Sum of distances of all points to their nearest centers. + let mut total_distance = 0.0; + // Assign all points their nearest centers. + // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. + // Actual value of previous_counts does not matter since lambda is 0. + // Passing a vector of 0s. + let previous_counts = vec![0; input.k]; + for idx in input.first..batch_end { + let (min_center, min_distance) = + get_nearest_center(input, centers, idx, /* lambda */ 0.0, &previous_counts)?; + total_distance += min_distance; + cluster_counts[min_center as usize] += 1; + cluster_weighted_counts[min_center as usize] += min_distance; + if min_distance > cluster_farthest_distance[min_center as usize] { + cluster_farthest_distance[min_center as usize] = min_distance; } } + Ok(KMeansAssignForCenterInitOutput { + cluster_counts, + cluster_weighted_counts, + cluster_farthest_distance, + total_distance, + }) +} - fn kmeansassign_for_main_loop( - &self, - centers: &[Vec], - previous_counts: &[usize], - lambda: f32, - ) -> KMeansAssignForMainLoopOutput { - let batch_end = min(self.input.last, self.input.first + self.input.num_samples); - let dim = self.input.embedding_dimension; - let mut cluster_counts = vec![0; self.input.k]; - let mut cluster_farthest_point_idx: Vec = vec![-1; self.input.k]; - let mut cluster_farthest_distance = vec![-MAX_DISTANCE; self.input.k]; - let mut cluster_new_centers = vec![vec![0.0; dim]; self.input.k]; - let mut total_distance = 0.0; - // Assign all points the nearest center. 
- // TODO(Sanket): Normalize the points if needed for cosine similarity. - // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. - for idx in self.input.first..batch_end { - let (min_center, min_distance) = - self.get_nearest_center(centers, idx, lambda, previous_counts); - total_distance += min_distance; - cluster_counts[min_center as usize] += 1; - let point_idx = self.input.indices[idx]; - if min_distance > cluster_farthest_distance[min_center as usize] { - cluster_farthest_point_idx[min_center as usize] = point_idx as i32; - cluster_farthest_distance[min_center as usize] = min_distance; - } - let start_idx = point_idx * dim as u32; - let end_idx = (point_idx + 1) * dim as u32; - self.input.embeddings[start_idx as usize..end_idx as usize] - .iter() - .enumerate() - .for_each(|(index, emb)| cluster_new_centers[min_center as usize][index] += *emb); - } - KMeansAssignForMainLoopOutput { - cluster_counts, - cluster_farthest_point_idx, - cluster_farthest_distance, - cluster_new_centers, - total_distance, +// This function assigns all the points to their respective nearest centers with regularization +// i.e. it penalizes large clusters using a regularization parameter like lambda. +fn kmeansassign_for_main_loop( + input: &KMeansAlgorithmInput, + centers: &[Vec], + previous_counts: &[usize], + lambda: f32, +) -> Result { + let batch_end = min(input.last, input.first + input.num_samples); + let dim = input.embedding_dimension; + // Number of points in each cluster. + let mut cluster_counts = vec![0; input.k]; + // Index of the farthest point from the cluster center. + let mut cluster_farthest_point_idx: Vec = vec![-1; input.k]; + // Distance of the farthest point from the cluster center. + let mut cluster_farthest_distance = vec![-MAX_DISTANCE; input.k]; + // New centers for each cluster. This is simply the sum of embeddings of all points + // that belong to the cluster. + let mut cluster_new_centers = vec![vec![0.0; dim]; input.k]; + // Sum of distances of all points to their nearest centers. + let mut total_distance = 0.0; + // Assign all points the nearest center. + // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. + for idx in input.first..batch_end { + let (min_center, min_distance) = + get_nearest_center(input, centers, idx, lambda, previous_counts)?; + total_distance += min_distance; + cluster_counts[min_center as usize] += 1; + let point_idx = input.indices[idx]; + if min_distance > cluster_farthest_distance[min_center as usize] { + cluster_farthest_point_idx[min_center as usize] = point_idx as i32; + cluster_farthest_distance[min_center as usize] = min_distance; } + let start_idx = point_idx * dim as u32; + let end_idx = (point_idx + 1) * dim as u32; + input.embeddings[start_idx as usize..end_idx as usize] + .iter() + .enumerate() + .for_each(|(index, emb)| cluster_new_centers[min_center as usize][index] += *emb); } + Ok(KMeansAssignForMainLoopOutput { + cluster_counts, + cluster_farthest_point_idx, + cluster_farthest_distance, + cluster_new_centers, + total_distance, + }) +} - fn kmeansassign_finish( - &self, - centers: &[Vec], - generate_labels: bool, - ) -> KMeansAssignFinishOutput { - // Assign all the points. 
- let batch_end = self.input.last; - let mut cluster_counts = vec![0; self.input.k]; - let mut cluster_nearest_point_idx: Vec = vec![-1; self.input.k]; - let mut cluster_nearest_distance = vec![MAX_DISTANCE; self.input.k]; - let mut cluster_labels; - if generate_labels { - cluster_labels = HashMap::with_capacity(batch_end - self.input.first); - } else { - cluster_labels = HashMap::new(); - } - // Assign all points the nearest center. - // TODO(Sanket): Normalize the points if needed for cosine similarity. - // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. - // The actual value of previous_counts does not matter since lambda is 0. - let previous_counts = vec![0; self.input.k]; - for idx in self.input.first..batch_end { - let (min_center, min_distance) = - self.get_nearest_center(centers, idx, 0.0, &previous_counts); - cluster_counts[min_center as usize] += 1; - let point_idx = self.input.indices[idx]; - if min_distance <= cluster_nearest_distance[min_center as usize] { - cluster_nearest_distance[min_center as usize] = min_distance; - cluster_nearest_point_idx[min_center as usize] = point_idx as i32; - } - if generate_labels { - cluster_labels.insert(point_idx, min_center); - } +// This is run in the end to assign all points to their nearest centers. +// It does not penalize the clusters via lambda. Also, it assigns ALL the +// points instead of just a sample. +// generate_labels is used to denote if this method is expected to also return +// the assignment labels of the points. +fn kmeansassign_finish( + input: &KMeansAlgorithmInput, + centers: &[Vec], + generate_labels: bool, +) -> Result { + // Assign ALL the points. + let batch_end = input.last; + // Number of points in each cluster. + let mut cluster_counts = vec![0; input.k]; + // Index and Distance of the nearest point from the cluster center. + let mut cluster_nearest_point_idx: Vec = vec![-1; input.k]; + let mut cluster_nearest_distance = vec![MAX_DISTANCE; input.k]; + // Point id -> label id mapping for the cluster assignment. + let mut cluster_labels; + if generate_labels { + cluster_labels = HashMap::with_capacity(batch_end - input.first); + } else { + cluster_labels = HashMap::new(); + } + // Assign all points the nearest center. + // TODO(Sanket): Scope for perf improvements here. Like Paralleization, SIMD, etc. + // The actual value of previous_counts does not matter since lambda is 0. Using a vector of 0s. 
+ let previous_counts = vec![0; input.k]; + for idx in input.first..batch_end { + let (min_center, min_distance) = + get_nearest_center(input, centers, idx, 0.0, &previous_counts)?; + cluster_counts[min_center as usize] += 1; + let point_idx = input.indices[idx]; + if min_distance <= cluster_nearest_distance[min_center as usize] { + cluster_nearest_distance[min_center as usize] = min_distance; + cluster_nearest_point_idx[min_center as usize] = point_idx as i32; } - KMeansAssignFinishOutput { - cluster_counts, - cluster_nearest_point_idx, - cluster_nearest_distance, - cluster_labels, + if generate_labels { + cluster_labels.insert(point_idx, min_center); } } + Ok(KMeansAssignFinishOutput { + cluster_counts, + cluster_nearest_point_idx, + cluster_nearest_distance, + cluster_labels, + }) +} - pub fn refine_lambda( - &self, - cluster_counts: &[usize], - cluster_weighted_counts: &[f32], - cluster_farthest_distance: &[f32], - ) -> f32 { - let batch_end = min(self.input.last, self.input.first + self.input.num_samples); - let dataset_size = batch_end - self.input.first; - let mut max_count = 0; - let mut max_cluster: i32 = -1; - // Find the cluster with the max count. - for (index, count) in cluster_counts.iter().enumerate() { - if *count > 0 && max_count < *count { - max_count = *count; - max_cluster = index as i32; - } - } - if max_cluster < 0 { - panic!("Invariant violation. There should be atleast one data point"); +fn refine_lambda( + input: &KMeansAlgorithmInput, + cluster_counts: &[usize], + cluster_weighted_counts: &[f32], + cluster_farthest_distance: &[f32], +) -> Result { + let batch_end = min(input.last, input.first + input.num_samples); + let dataset_size = batch_end - input.first; + let mut max_count = 0; + let mut max_cluster: i32 = -1; + // Find the cluster with the max count. + for (index, count) in cluster_counts.iter().enumerate() { + if *count > 0 && max_count < *count { + max_count = *count; + max_cluster = index as i32; } - let avg_distance = cluster_weighted_counts[max_cluster as usize] - / cluster_counts[max_cluster as usize] as f32; - let lambda = - (cluster_farthest_distance[max_cluster as usize] - avg_distance) / dataset_size as f32; - f32::max(0.0, lambda) } + if max_cluster < 0 { + return Err(KMeansError::MaxClusterNotFound); + } + let avg_distance = + cluster_weighted_counts[max_cluster as usize] / cluster_counts[max_cluster as usize] as f32; + let lambda = + (cluster_farthest_distance[max_cluster as usize] - avg_distance) / dataset_size as f32; + Ok(f32::max(0.0, lambda)) +} - pub fn init_centers(&self, num_iters: usize) -> (Vec>, Vec, f32) { - let batch_end = min(self.input.first + self.input.num_samples, self.input.last); - let mut min_dist = MAX_DISTANCE; - let embedding_dim = self.input.k; - let mut final_cluster_count = vec![0; embedding_dim]; - let mut final_centers = vec![vec![0.0; self.input.embedding_dimension]; embedding_dim]; - let mut lambda = 0.0; - // Randomly choose centers. 
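To illustrate the refine_lambda computation above with made-up numbers (a standalone sketch, not part of this patch): the largest cluster's average member distance is subtracted from its farthest member distance, and the difference is spread over the sampled batch, floored at zero.

fn refine_lambda_example() -> f32 {
    // Hypothetical numbers: the largest cluster has 100 members whose distances to its
    // center sum to 500.0, its farthest member sits at distance 25.0, and the sampled
    // batch contains 200 points in total.
    let dataset_size = 200_usize;
    let max_cluster_count = 100_usize;
    let max_cluster_weighted_count = 500.0_f32;
    let max_cluster_farthest_distance = 25.0_f32;
    // Average distance of the largest cluster's members to its center: 5.0.
    let avg_distance = max_cluster_weighted_count / max_cluster_count as f32;
    // (25.0 - 5.0) / 200 = 0.1, floored at zero.
    let lambda = (max_cluster_farthest_distance - avg_distance) / dataset_size as f32;
    f32::max(0.0, lambda)
}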
- for _ in 0..num_iters { - let mut centers = vec![vec![0.0; self.input.embedding_dimension]; self.input.k]; - for center in centers.iter_mut() { - let random_center = rand::thread_rng().gen_range(self.input.first..batch_end); - center.copy_from_slice( - &self.input.embeddings[self.input.indices[random_center] as usize - * self.input.embedding_dimension - ..((self.input.indices[random_center] + 1) as usize) - * self.input.embedding_dimension], - ); - } - let kmeans_assign = self.kmeansassign_for_centerinit(&centers); - if kmeans_assign.total_distance < min_dist { - min_dist = kmeans_assign.total_distance; - final_cluster_count = kmeans_assign.cluster_counts; - final_centers = centers; - lambda = self.refine_lambda( - &final_cluster_count, - &kmeans_assign.cluster_weighted_counts, - &kmeans_assign.cluster_farthest_distance, - ); - } +// This function initializes the centers for the kmeans algorithm. +// It runs the assignment step multiple times with random centers +// and chooses the centers that give the minimum total distance. +// It also computes lambda from the chosen centers. +#[allow(clippy::type_complexity)] +fn init_centers( + input: &KMeansAlgorithmInput, + num_iters: usize, +) -> Result<(Vec<Vec<f32>>, Vec<usize>, f32), KMeansError> { + let batch_end = min(input.first + input.num_samples, input.last); + let mut min_dist = MAX_DISTANCE; + let embedding_dim = input.k; + let mut final_cluster_count = vec![0; embedding_dim]; + let mut final_centers = vec![vec![0.0; input.embedding_dimension]; embedding_dim]; + let mut lambda = 0.0; + // Randomly choose centers. + for _ in 0..num_iters { + let mut centers = vec![vec![0.0; input.embedding_dimension]; input.k]; + for center in centers.iter_mut() { + let random_center = rand::thread_rng().gen_range(input.first..batch_end); + center.copy_from_slice( + &input.embeddings[input.indices[random_center] as usize * input.embedding_dimension + ..((input.indices[random_center] + 1) as usize) * input.embedding_dimension], + ); + } + let kmeans_assign = kmeansassign_for_centerinit(input, &centers)?; + if kmeans_assign.total_distance < min_dist { + min_dist = kmeans_assign.total_distance; + final_cluster_count = kmeans_assign.cluster_counts; + final_centers = centers; + lambda = refine_lambda( + input, + &final_cluster_count, + &kmeans_assign.cluster_weighted_counts, + &kmeans_assign.cluster_farthest_distance, + )?; } - (final_centers, final_cluster_count, lambda) } + Ok((final_centers, final_cluster_count, lambda)) +} - fn refine_centers( - &self, - kmeansassign_output: &mut KMeansAssignForMainLoopOutput, - previous_centers: &[Vec<f32>], - ) -> f32 { - let mut max_count = 0; - let mut max_cluster_idx: i32 = -1; - #[allow(clippy::needless_range_loop)] - for cluster_idx in 0..self.input.k { - let start = kmeansassign_output.cluster_farthest_point_idx[cluster_idx] as usize - * self.input.embedding_dimension; - let end = (kmeansassign_output.cluster_farthest_point_idx[cluster_idx] + 1) as usize - * self.input.embedding_dimension; - if kmeansassign_output.cluster_counts[cluster_idx] > 0 - && kmeansassign_output.cluster_counts[cluster_idx] > max_count - && self.input.distance_function.distance( - &previous_centers[cluster_idx], - &self.input.embeddings[start..end], - ) > 1e-6 - { - max_count = kmeansassign_output.cluster_counts[cluster_idx]; - max_cluster_idx = cluster_idx as i32; - } +// This function refines the centers of the clusters. +// It calculates the new centers by averaging the embeddings of all points +// assigned to a cluster.
+fn refine_centers( + input: &KMeansAlgorithmInput, + kmeansassign_output: &mut KMeansAssignForMainLoopOutput, + previous_centers: &[Vec], +) -> f32 { + let mut max_count = 0; + let mut max_cluster_idx: i32 = -1; + #[allow(clippy::needless_range_loop)] + for cluster_idx in 0..input.k { + if kmeansassign_output.cluster_counts[cluster_idx] > 0 + && kmeansassign_output.cluster_counts[cluster_idx] > max_count + && input.distance_function.distance( + &previous_centers[cluster_idx], + &input.embeddings[kmeansassign_output.cluster_farthest_point_idx[cluster_idx] + as usize + * input.embedding_dimension + ..(kmeansassign_output.cluster_farthest_point_idx[cluster_idx] + 1) as usize + * input.embedding_dimension], + ) > 1e-6 + { + max_count = kmeansassign_output.cluster_counts[cluster_idx]; + max_cluster_idx = cluster_idx as i32; } + } - // Refine centers. - let mut diff = 0.0; - #[allow(clippy::needless_range_loop)] - for cluster_idx in 0..self.input.k { - let count = kmeansassign_output.cluster_counts[cluster_idx]; - if count > 0 { - kmeansassign_output.cluster_new_centers[cluster_idx] - .iter_mut() - .for_each(|x| { - *x /= count as f32; - }); - } else if max_cluster_idx == -1 { - kmeansassign_output.cluster_new_centers[cluster_idx] - .copy_from_slice(&previous_centers[cluster_idx]); - } else { - // copy the farthest point embedding to the center. - let start = kmeansassign_output.cluster_farthest_point_idx[max_cluster_idx as usize] - as usize - * self.input.embedding_dimension; - let end = (kmeansassign_output.cluster_farthest_point_idx[max_cluster_idx as usize] - + 1) as usize - * self.input.embedding_dimension; - kmeansassign_output.cluster_new_centers[cluster_idx] - .copy_from_slice(&self.input.embeddings[start..end]); - } - diff += self.input.distance_function.distance( - &previous_centers[cluster_idx], - &kmeansassign_output.cluster_new_centers[cluster_idx], - ); + // Refine centers. + let mut diff = 0.0; + #[allow(clippy::needless_range_loop)] + for cluster_idx in 0..input.k { + let count = kmeansassign_output.cluster_counts[cluster_idx]; + if count > 0 { + kmeansassign_output.cluster_new_centers[cluster_idx] + .iter_mut() + .for_each(|x| { + *x /= count as f32; + }); + } else if max_cluster_idx == -1 { + kmeansassign_output.cluster_new_centers[cluster_idx] + .copy_from_slice(&previous_centers[cluster_idx]); + } else { + // copy the farthest point embedding to the center. + let start = kmeansassign_output.cluster_farthest_point_idx[max_cluster_idx as usize] + as usize + * input.embedding_dimension; + let end = (kmeansassign_output.cluster_farthest_point_idx[max_cluster_idx as usize] + 1) + as usize + * input.embedding_dimension; + kmeansassign_output.cluster_new_centers[cluster_idx] + .copy_from_slice(&input.embeddings[start..end]); } - diff + diff += input.distance_function.distance( + &previous_centers[cluster_idx], + &kmeansassign_output.cluster_new_centers[cluster_idx], + ); } + diff +} - pub fn cluster(&mut self) -> KMeansAlgorithmOutput { - let (initial_centers, initial_counts, adjusted_lambda) = - self.init_centers(NUM_ITERS_FOR_CENTER_INIT); - let end = min(self.input.last, self.input.first + self.input.num_samples); - let baseline_lambda = - 1.0 * 1.0 / self.input.initial_lambda / (end - self.input.first) as f32; - // Initialize. 
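As a small aside on the refinement step above (a standalone sketch, not part of this patch): kmeansassign_for_main_loop accumulates a component-wise sum of member embeddings per cluster, and refine_centers turns that sum into the new center by dividing by the member count. The helper name average_center is hypothetical.

fn average_center(summed_members: &[f32], member_count: usize) -> Vec<f32> {
    // New center = component-wise sum of member embeddings divided by the member count.
    summed_members
        .iter()
        .map(|component| component / member_count as f32)
        .collect()
}

fn refine_centers_example() {
    // Two 2-dimensional members, (1.0, 1.0) and (3.0, 5.0), assigned to the same cluster:
    // their accumulated sum is (4.0, 6.0), so the refined center is (2.0, 3.0).
    assert_eq!(average_center(&[4.0, 6.0], 2), vec![2.0, 3.0]);
}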
- let mut current_centers = initial_centers; - let mut current_counts = initial_counts; - let mut min_dist = MAX_DISTANCE; - let mut no_improvement = 0; - let mut previous_centers = vec![]; - #[allow(unused_assignments)] - let mut previous_counts = vec![]; - for _ in 0..NUM_ITERS_FOR_MAIN_LOOP { - // Prepare for the next iteration. - previous_centers = current_centers; - previous_counts = current_counts; - self.input.indices[self.input.first..self.input.last].shuffle(&mut rand::thread_rng()); - let mut kmeans_assign = self.kmeansassign_for_main_loop( - &previous_centers, - &previous_counts, - f32::min(adjusted_lambda, baseline_lambda), - ); - if kmeans_assign.total_distance < min_dist { - min_dist = kmeans_assign.total_distance; - no_improvement = 0; - } else { - no_improvement += 1; - } - let curr_diff = self.refine_centers(&mut kmeans_assign, &previous_centers); - // Prepare for the next iteration. - current_centers = kmeans_assign.cluster_new_centers; - current_counts = kmeans_assign.cluster_counts; - if curr_diff < 1e-3 || no_improvement >= NUM_ITERS_NO_IMPROVEMENT { - break; - } +pub fn cluster(input: &mut KMeansAlgorithmInput) -> Result { + let (initial_centers, initial_counts, adjusted_lambda) = + init_centers(input, NUM_ITERS_FOR_CENTER_INIT)?; + let end = min(input.last, input.first + input.num_samples); + let baseline_lambda = 1.0 * 1.0 / input.initial_lambda / (end - input.first) as f32; + // Initialize. + let mut current_centers = initial_centers; + let mut current_counts = initial_counts; + let mut min_dist = MAX_DISTANCE; + let mut no_improvement = 0; + let mut previous_centers = vec![]; + #[allow(unused_assignments)] + let mut previous_counts = vec![]; + for _ in 0..NUM_ITERS_FOR_MAIN_LOOP { + // Prepare for the next iteration. + previous_centers = current_centers; + previous_counts = current_counts; + input.indices[input.first..input.last].shuffle(&mut rand::thread_rng()); + let mut kmeans_assign = kmeansassign_for_main_loop( + input, + &previous_centers, + &previous_counts, + f32::min(adjusted_lambda, baseline_lambda), + )?; + if kmeans_assign.total_distance < min_dist { + min_dist = kmeans_assign.total_distance; + no_improvement = 0; + } else { + no_improvement += 1; } - // Assign points to the refined center one last time and get nearest points of each cluster. - let kmeans_assign = - self.kmeansassign_finish(&previous_centers, /* generate_labels */ false); - #[allow(clippy::needless_range_loop)] - for center_ids in 0..self.input.k { - if kmeans_assign.cluster_nearest_point_idx[center_ids] >= 0 { - let start_emb_idx = kmeans_assign.cluster_nearest_point_idx[center_ids] as usize - * self.input.embedding_dimension; - let end_emb_idx = (kmeans_assign.cluster_nearest_point_idx[center_ids] as usize - + 1) - * self.input.embedding_dimension; - previous_centers[center_ids] - .copy_from_slice(&self.input.embeddings[start_emb_idx..end_emb_idx]); - } + let curr_diff = refine_centers(input, &mut kmeans_assign, &previous_centers); + // Prepare for the next iteration. + current_centers = kmeans_assign.cluster_new_centers; + current_counts = kmeans_assign.cluster_counts; + if curr_diff < 1e-3 || no_improvement >= NUM_ITERS_NO_IMPROVEMENT { + break; } - // Finally assign points to these nearest points in the cluster. - // Previous counts does not matter since lambda is 0. 
- let kmeans_assign = - self.kmeansassign_finish(&previous_centers, /* generate_labels */ true); - previous_counts = kmeans_assign.cluster_counts; - - KMeansAlgorithmOutput { - cluster_centers: previous_centers, - cluster_counts: previous_counts, - cluster_labels: kmeans_assign.cluster_labels, + } + // Assign points to the refined center one last time and get nearest points of each cluster. + let kmeans_assign = + kmeansassign_finish(input, &previous_centers, /* generate_labels */ false)?; + #[allow(clippy::needless_range_loop)] + for center_ids in 0..input.k { + if kmeans_assign.cluster_nearest_point_idx[center_ids] >= 0 { + let start_emb_idx = kmeans_assign.cluster_nearest_point_idx[center_ids] as usize + * input.embedding_dimension; + let end_emb_idx = (kmeans_assign.cluster_nearest_point_idx[center_ids] as usize + 1) + * input.embedding_dimension; + previous_centers[center_ids] + .copy_from_slice(&input.embeddings[start_emb_idx..end_emb_idx]); } } + // Finally assign points to these nearest points in the cluster. + // Previous counts does not matter since lambda is 0. + let kmeans_assign = + kmeansassign_finish(input, &previous_centers, /* generate_labels */ true)?; + previous_counts = kmeans_assign.cluster_counts; + + Ok(KMeansAlgorithmOutput { + cluster_centers: previous_centers, + cluster_counts: previous_counts, + cluster_labels: kmeans_assign.cluster_labels, + }) } #[cfg(test)] mod tests { use std::collections::HashMap; - use super::KMeansAlgorithm; + use crate::spann::utils::{ + cluster, kmeansassign_finish, kmeansassign_for_centerinit, kmeansassign_for_main_loop, + KMeansAlgorithmInput, + }; #[test] fn test_kmeans_assign_for_center_init() { @@ -434,7 +510,7 @@ mod tests { 11.0, 11.0, 12.0, 12.0, ]; let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let kmeans_algo = KMeansAlgorithm::new( + let kmeans_input = KMeansAlgorithmInput::new( indices, &embeddings, dim, @@ -446,7 +522,7 @@ mod tests { 100.0, ); let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]]; - let res = kmeans_algo.kmeansassign_for_centerinit(&centers); + let res = kmeansassign_for_centerinit(&kmeans_input, &centers).expect("Failed to assign"); assert_eq!(res.cluster_counts, vec![4, 4]); assert_eq!(res.total_distance, 8.0); assert_eq!(res.cluster_weighted_counts, vec![4.0, 4.0]); @@ -462,7 +538,7 @@ mod tests { 10.0, 11.0, 11.0, 11.0, 12.0, 12.0, 13.0, 13.0, ]; let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - let kmeans_algo = KMeansAlgorithm::new( + let kmeans_input = KMeansAlgorithmInput::new( indices, &embeddings, dim, @@ -477,7 +553,7 @@ mod tests { // Penalize [10.0, 10.0] so that [5.0, 5.0] gets assigned to [0.0, 0.0] let previous_counts = vec![0, 9]; let lambda = 1.0; - let res = kmeans_algo.kmeansassign_for_main_loop(&centers, &previous_counts, lambda); + let res = kmeansassign_for_main_loop(&kmeans_input, &centers, &previous_counts, lambda) + .expect("Failed to assign"); assert_eq!(res.cluster_counts, vec![5, 4]); assert_eq!(res.total_distance, 94.0); assert_eq!(res.cluster_farthest_distance, vec![50.0, 11.0]); @@ -496,7 +573,7 @@ mod tests { 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, ]; let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let kmeans_algo = KMeansAlgorithm::new( + let kmeans_algo = KMeansAlgorithmInput::new( indices, &embeddings, dim, @@ -508,12 +585,12 @@ mod tests { 100.0, ); let centers = vec![vec![0.0, 0.0], vec![10.0, 10.0]]; - let mut res = kmeans_algo.kmeansassign_finish(&centers, false); + let mut res = kmeansassign_finish(&kmeans_algo, 
&centers, false).expect("Failed to assign"); assert_eq!(res.cluster_counts, vec![4, 4]); assert_eq!(res.cluster_nearest_distance, vec![0.0, 0.0]); assert_eq!(res.cluster_nearest_point_idx, vec![0, 4]); assert_eq!(res.cluster_labels.len(), 0); - res = kmeans_algo.kmeansassign_finish(&centers, true); + res = kmeansassign_finish(&kmeans_algo, &centers, true).expect("Failed to assign"); let mut labels = HashMap::new(); labels.insert(0, 0); labels.insert(1, 0); @@ -537,7 +614,7 @@ mod tests { 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0, ]; let indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let mut kmeans_algo = KMeansAlgorithm::new( + let mut kmeans_algo = KMeansAlgorithmInput::new( indices, &embeddings, dim, @@ -548,7 +625,7 @@ mod tests { chroma_distance::DistanceFunction::Euclidean, 100.0, ); - let res = kmeans_algo.cluster(); + let res = cluster(&mut kmeans_algo).expect("Failed to cluster"); assert_eq!(res.cluster_counts, vec![4, 4]); let mut labels = HashMap::new(); labels.insert(0, res.cluster_labels[&0]);
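Finally, an end-to-end sketch of driving the new free-function API, modeled on test_kmeans_clustering above (not part of this patch); it assumes it runs inside this module, where cluster, KMeansAlgorithmInput, and the private output fields are visible, and the concrete numbers are illustrative.

fn run_kmeans_example() -> Result<(), KMeansError> {
    let dim = 2;
    // Eight 2-dimensional points: four near the origin and four near (10, 10).
    let embeddings: Vec<f32> = vec![
        0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, 11.0, 11.0,
    ];
    let indices: Vec<u32> = (0..8).collect();
    let mut input = KMeansAlgorithmInput::new(
        indices,
        &embeddings,
        dim,
        /* k */ 2,
        /* first */ 0,
        /* last */ 8,
        /* num_samples */ 1000,
        chroma_distance::DistanceFunction::Euclidean,
        /* initial_lambda */ 100.0,
    );
    let output = cluster(&mut input)?;
    // With this data the points split into two clusters of four, as asserted in the test above.
    assert_eq!(output.cluster_counts, vec![4, 4]);
    Ok(())
}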