diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 9e477d53b2e4..92532b8c27d6 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -19,10 +19,11 @@ use crate::{ hnsw_provider::{ HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexRef, }, + spann::utils::cluster, Index, IndexUuid, }; -use super::utils::{cluster, KMeansAlgorithmInput, KMeansError}; +use super::utils::{rng_query, KMeansAlgorithmInput, KMeansError}; pub struct VersionsMapInner { pub versions_map: HashMap, @@ -363,68 +364,16 @@ impl SpannIndexWriter { &self, query: &[f32], ) -> Result<(Vec, Vec, Vec>), SpannIndexWriterError> { - let mut nearby_ids: Vec = vec![]; - let mut nearby_distances: Vec = vec![]; - let mut embeddings: Vec> = vec![]; - { - let read_guard = self.hnsw_index.inner.read(); - let allowed_ids = vec![]; - let disallowed_ids = vec![]; - let (ids, distances) = read_guard - .query( - query, - NUM_CENTROIDS_TO_SEARCH as usize, - &allowed_ids, - &disallowed_ids, - ) - .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)?; - // Get the embeddings also for distance computation. - // Normalization is idempotent and since we write normalized embeddings - // to the hnsw index, we'll get the same embeddings after denormalization. - for (id, distance) in ids.iter().zip(distances.iter()) { - if *distance <= (1_f32 + QUERY_EPSILON) * distances[0] { - nearby_ids.push(*id); - nearby_distances.push(*distance); - } - } - // Get the embeddings also for distance computation. - for id in nearby_ids.iter() { - let emb = read_guard - .get(*id) - .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)? - .ok_or(SpannIndexWriterError::HnswIndexSearchError)?; - embeddings.push(emb); - } - } - // Apply the RNG rule to prune. - let mut res_ids = vec![]; - let mut res_distances = vec![]; - let mut res_embeddings: Vec> = vec![]; - // Embeddings that were obtained are already normalized. - for (id, (distance, embedding)) in nearby_ids - .iter() - .zip(nearby_distances.iter().zip(embeddings)) - { - let mut rng_accepted = true; - for nbr_embedding in res_embeddings.iter() { - // Embeddings are already normalized so no need to normalize again. - let dist = self - .distance_function - .distance(&embedding[..], &nbr_embedding[..]); - if RNG_FACTOR * dist <= *distance { - rng_accepted = false; - break; - } - } - if !rng_accepted { - continue; - } - res_ids.push(*id); - res_distances.push(*distance); - res_embeddings.push(embedding); - } - - Ok((res_ids, res_distances, res_embeddings)) + rng_query( + query, + self.hnsw_index.clone(), + NUM_CENTROIDS_TO_SEARCH as usize, + QUERY_EPSILON, + RNG_FACTOR, + self.distance_function.clone(), + ) + .await + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError) } async fn is_outdated( diff --git a/rust/index/src/spann/utils.rs b/rust/index/src/spann/utils.rs index 821f97117a95..0b0822e82b73 100644 --- a/rust/index/src/spann/utils.rs +++ b/rust/index/src/spann/utils.rs @@ -5,6 +5,8 @@ use chroma_error::{ChromaError, ErrorCodes}; use rand::{seq::SliceRandom, Rng}; use thiserror::Error; +use crate::{hnsw_provider::HnswIndexRef, Index}; + // TODO(Sanket): I don't understand why the reference implementation defined // max_distance this way. // TODO(Sanket): Make these configurable. @@ -501,6 +503,82 @@ pub fn cluster(input: &mut KMeansAlgorithmInput) -> Result ErrorCodes { + match self { + Self::HnswSearchError => ErrorCodes::Internal, + } + } +} + +// Assumes that query is already normalized. +pub async fn rng_query( + normalized_query: &[f32], + hnsw_index: HnswIndexRef, + k: usize, + rng_epsilon: f32, + rng_factor: f32, + distance_function: DistanceFunction, +) -> Result<(Vec, Vec, Vec>), RngQueryError> { + let mut nearby_ids: Vec = vec![]; + let mut nearby_distances: Vec = vec![]; + let mut embeddings: Vec> = vec![]; + { + let read_guard = hnsw_index.inner.read(); + let allowed_ids = vec![]; + let disallowed_ids = vec![]; + let (ids, distances) = read_guard + .query(normalized_query, k, &allowed_ids, &disallowed_ids) + .map_err(|_| RngQueryError::HnswSearchError)?; + for (id, distance) in ids.iter().zip(distances.iter()) { + if *distance <= (1_f32 + rng_epsilon) * distances[0] { + nearby_ids.push(*id); + nearby_distances.push(*distance); + } + } + // Get the embeddings also for distance computation. + for id in nearby_ids.iter() { + let emb = read_guard + .get(*id) + .map_err(|_| RngQueryError::HnswSearchError)? + .ok_or(RngQueryError::HnswSearchError)?; + embeddings.push(emb); + } + } + // Apply the RNG rule to prune. + let mut res_ids = vec![]; + let mut res_distances = vec![]; + let mut res_embeddings: Vec> = vec![]; + // Embeddings that were obtained are already normalized. + for (id, (distance, embedding)) in nearby_ids + .iter() + .zip(nearby_distances.iter().zip(embeddings)) + { + let mut rng_accepted = true; + for nbr_embedding in res_embeddings.iter() { + let dist = distance_function.distance(&embedding[..], &nbr_embedding[..]); + if rng_factor * dist <= *distance { + rng_accepted = false; + break; + } + } + if !rng_accepted { + continue; + } + res_ids.push(*id); + res_distances.push(*distance); + res_embeddings.push(embedding); + } + + Ok((res_ids, res_distances, res_embeddings)) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index cf781624f651..930b3a630017 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -9,6 +9,7 @@ pub(super) mod partition; pub(super) mod pull_log; pub(super) mod record_segment_prefetch; pub(super) mod register; +pub(super) mod spann_centers_search; pub(super) mod write_segments; // Required for benchmark diff --git a/rust/worker/src/execution/operators/spann_centers_search.rs b/rust/worker/src/execution/operators/spann_centers_search.rs new file mode 100644 index 000000000000..778464744a18 --- /dev/null +++ b/rust/worker/src/execution/operators/spann_centers_search.rs @@ -0,0 +1,83 @@ +use chroma_distance::DistanceFunction; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_index::spann::utils::rng_query; +use thiserror::Error; +use tonic::async_trait; + +use crate::{ + execution::operator::Operator, + segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext}, +}; + +#[derive(Debug)] +pub struct SpannCentersSearchInput { + reader_context: SpannSegmentReaderContext, + // Assumes that query is already normalized in case of cosine. + query: Vec, + k: usize, + rng_epsilon: f32, + rng_factor: f32, + distance_function: DistanceFunction, +} + +#[derive(Debug)] +pub struct SpannCentersSearchOutput { + center_ids: Vec, +} + +#[derive(Error, Debug)] +pub enum SpannCentersSearchError { + #[error("Error creating spann segment reader")] + SpannSegmentReaderCreationError, + #[error("Error querying RNG")] + RngQueryError, +} + +impl ChromaError for SpannCentersSearchError { + fn code(&self) -> ErrorCodes { + match self { + Self::SpannSegmentReaderCreationError => ErrorCodes::Internal, + Self::RngQueryError => ErrorCodes::Internal, + } + } +} + +#[derive(Debug)] +pub struct SpannCentersSearchOperator {} + +impl SpannCentersSearchOperator { + pub fn new() -> Box { + Box::new(SpannCentersSearchOperator {}) + } +} + +#[async_trait] +impl Operator for SpannCentersSearchOperator { + type Error = SpannCentersSearchError; + + async fn run( + &self, + input: &SpannCentersSearchInput, + ) -> Result { + let spann_reader = SpannSegmentReader::from_segment( + &input.reader_context.segment, + &input.reader_context.blockfile_provider, + &input.reader_context.hnsw_provider, + input.reader_context.dimension, + ) + .await + .map_err(|_| SpannCentersSearchError::SpannSegmentReaderCreationError)?; + // RNG Query. + let res = rng_query( + &input.query, + spann_reader.index_reader.hnsw_index.clone(), + input.k, + input.rng_epsilon, + input.rng_factor, + input.distance_function.clone(), + ) + .await + .map_err(|_| SpannCentersSearchError::RngQueryError)?; + Ok(SpannCentersSearchOutput { center_ids: res.0 }) + } +} diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 7b0739494864..472235d414c1 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -344,10 +344,18 @@ impl ChromaError for SpannSegmentReaderError { } } +#[derive(Debug)] +pub struct SpannSegmentReaderContext { + pub segment: Segment, + pub blockfile_provider: BlockfileProvider, + pub hnsw_provider: HnswIndexProvider, + pub dimension: usize, +} + #[derive(Clone)] #[allow(dead_code)] pub(crate) struct SpannSegmentReader<'me> { - index_reader: SpannIndexReader<'me>, + pub index_reader: SpannIndexReader<'me>, id: SegmentUuid, }