diff --git a/rust/blockstore/src/types/writer.rs b/rust/blockstore/src/types/writer.rs
index db85d60e675c..6862c17f20eb 100644
--- a/rust/blockstore/src/types/writer.rs
+++ b/rust/blockstore/src/types/writer.rs
@@ -80,6 +80,22 @@ impl BlockfileWriter {
         }
     }
 
+    pub async fn get_clone<
+        K: Key + Into<KeyWrapper> + ArrowWriteableKey,
+        V: Value + Writeable + ArrowWriteableValue,
+    >(
+        &self,
+        prefix: &str,
+        key: K,
+    ) -> Result<Option<V>, Box<dyn ChromaError>> {
+        match self {
+            BlockfileWriter::MemoryBlockfileWriter(_) => todo!(),
+            BlockfileWriter::ArrowBlockfileWriter(writer) => {
+                writer.get_clone::<K, V>(prefix, key).await
+            }
+        }
+    }
+
     pub fn id(&self) -> uuid::Uuid {
         match self {
             BlockfileWriter::MemoryBlockfileWriter(writer) => writer.id(),
diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs
index ce1f3165f5e8..3b57cda44534 100644
--- a/rust/index/src/spann/types.rs
+++ b/rust/index/src/spann/types.rs
@@ -1,4 +1,7 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{
+    collections::HashMap,
+    sync::{atomic::AtomicU32, Arc},
+};
 
 use arrow::error;
 use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions};
@@ -8,24 +11,33 @@ use chroma_types::CollectionUuid;
 use chroma_types::SpannPostingList;
 use parking_lot::RwLock;
 use thiserror::Error;
+use tokio::sync::Mutex;
 use uuid::Uuid;
 
 use crate::{
     hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef},
-    IndexUuid,
+    utils::normalize,
+    Index, IndexUuid,
 };
 
-// TODO(Sanket): Add locking structures as necessary.
+struct VersionsMapInner {
+    versions_map: HashMap<u32, u32>,
+}
+
 pub struct SpannIndexWriter {
     // HNSW index and its provider for centroid search.
     hnsw_index: HnswIndexRef,
     hnsw_provider: HnswIndexProvider,
     // Posting list of the centroids.
-    // The blockfile also contains next id for the head.
-    posting_list_writer: BlockfileWriter,
+    // TODO(Sanket): For now the lock is very coarse grained. But this should
+    // be changed in the future.
+    posting_list_writer: Arc<Mutex<BlockfileWriter>>,
+    next_head_id: Arc<AtomicU32>,
     // Version number of each point.
     // TODO(Sanket): Finer grained locking for this map in future.
-    versions_map: Arc<RwLock<HashMap<u32, u32>>>,
+    versions_map: Arc<RwLock<VersionsMapInner>>,
+    distance_function: DistanceFunction,
+    dimensionality: usize,
 }
 
 #[derive(Error, Debug)]
@@ -36,6 +48,18 @@ pub enum SpannIndexWriterConstructionError {
     BlockfileReaderConstructionError,
     #[error("Blockfile writer construction error")]
     BlockfileWriterConstructionError,
+    #[error("Error resizing hnsw index")]
+    HnswIndexResizeError,
+    #[error("Error adding to hnsw index")]
+    HnswIndexAddError,
+    #[error("Error searching from hnsw")]
+    HnswIndexSearchError,
+    #[error("Error adding to posting list")]
+    PostingListAddError,
+    #[error("Error searching for posting list")]
+    PostingListSearchError,
+    #[error("Expected data not found")]
+    ExpectedDataNotFound,
 }
 
 impl ChromaError for SpannIndexWriterConstructionError {
@@ -44,22 +68,41 @@
     fn code(&self) -> ErrorCodes {
         match self {
             Self::HnswIndexConstructionError => ErrorCodes::Internal,
             Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
             Self::BlockfileWriterConstructionError => ErrorCodes::Internal,
+            Self::HnswIndexResizeError => ErrorCodes::Internal,
+            Self::HnswIndexAddError => ErrorCodes::Internal,
+            Self::PostingListAddError => ErrorCodes::Internal,
+            Self::HnswIndexSearchError => ErrorCodes::Internal,
+            Self::PostingListSearchError => ErrorCodes::Internal,
+            Self::ExpectedDataNotFound => ErrorCodes::Internal,
         }
     }
 }
 
+const MAX_HEAD_OFFSET_ID: &str = "max_head_offset_id";
+
+// TODO(Sanket): Make this configurable.
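+// NUM_CENTROIDS_TO_SEARCH bounds the hnsw candidate set handed to the RNG
+// rule, RNG_FACTOR scales the pruning test in rng_query (a candidate is
+// rejected when RNG_FACTOR * dist(candidate, accepted) <= dist(query,
+// candidate)), and SPLIT_THRESHOLD is the posting list size at which a head
+// should be split (the split itself is still a TODO in append below).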
+const NUM_CENTROIDS_TO_SEARCH: u32 = 64;
+const RNG_FACTOR: f32 = 1.0;
+const SPLIT_THRESHOLD: usize = 100;
+
 impl SpannIndexWriter {
     pub fn new(
         hnsw_index: HnswIndexRef,
         hnsw_provider: HnswIndexProvider,
         posting_list_writer: BlockfileWriter,
-        versions_map: HashMap<u32, u32>,
+        next_head_id: u32,
+        versions_map: VersionsMapInner,
+        distance_function: DistanceFunction,
+        dimensionality: usize,
     ) -> Self {
         SpannIndexWriter {
             hnsw_index,
             hnsw_provider,
-            posting_list_writer,
+            posting_list_writer: Arc::new(Mutex::new(posting_list_writer)),
+            next_head_id: Arc::new(AtomicU32::new(next_head_id)),
             versions_map: Arc::new(RwLock::new(versions_map)),
+            distance_function,
+            dimensionality,
         }
     }
 
@@ -105,7 +148,7 @@ impl SpannIndexWriter {
     async fn load_versions_map(
         blockfile_id: &Uuid,
         blockfile_provider: &BlockfileProvider,
-    ) -> Result<HashMap<u32, u32>, SpannIndexWriterConstructionError> {
+    ) -> Result<VersionsMapInner, SpannIndexWriterConstructionError> {
         // Create a reader for the blockfile. Load all the data into the versions map.
         let mut versions_map = HashMap::new();
         let reader = match blockfile_provider.read::<u32, u32>(blockfile_id).await {
@@ -119,7 +162,7 @@
         versions_data.iter().for_each(|(_, key, value)| {
             versions_map.insert(*key, *value);
         });
-        Ok(versions_map)
+        Ok(VersionsMapInner { versions_map })
     }
 
     async fn fork_postings_list(
@@ -158,6 +201,7 @@
         hnsw_id: Option<&IndexUuid>,
         versions_map_id: Option<&Uuid>,
         posting_list_id: Option<&Uuid>,
+        max_head_id_bf_id: Option<&Uuid>,
         hnsw_params: Option<HnswIndexParams>,
         collection_id: &CollectionUuid,
         distance_function: DistanceFunction,
@@ -171,7 +215,7 @@
                 hnsw_provider,
                 hnsw_id,
                 collection_id,
-                distance_function,
+                distance_function.clone(),
                 dimensionality,
             )
             .await?
@@ -180,7 +224,7 @@
             Self::create_hnsw_index(
                 hnsw_provider,
                 collection_id,
-                distance_function,
+                distance_function.clone(),
                 dimensionality,
                 hnsw_params.unwrap(), // Safe since caller should always provide this.
             )
@@ -192,7 +236,9 @@
             Some(versions_map_id) => {
                 Self::load_versions_map(versions_map_id, blockfile_provider).await?
             }
-            None => HashMap::new(),
+            None => VersionsMapInner {
+                versions_map: HashMap::new(),
+            },
         };
         // Fork the posting list writer.
         let posting_list_writer = match posting_list_id {
@@ -201,18 +247,209 @@
             }
             None => Self::create_posting_list(blockfile_provider).await?,
         };
+
+        let max_head_id = match max_head_id_bf_id {
+            Some(max_head_id_bf_id) => {
+                let reader = blockfile_provider
+                    .open::<&str, u32>(max_head_id_bf_id)
+                    .await;
+                match reader {
+                    Ok(reader) => reader.get("", MAX_HEAD_OFFSET_ID).await.map_err(|_| {
+                        SpannIndexWriterConstructionError::BlockfileReaderConstructionError
+                    })?,
+                    Err(_) => 0,
+                }
+            }
+            None => 0,
+        };
         Ok(Self::new(
             hnsw_index,
             hnsw_provider.clone(),
             posting_list_writer,
+            1 + max_head_id,
             versions_map,
+            distance_function,
+            dimensionality,
         ))
     }
 
-    pub fn add_versions_map(&self, id: u32) {
+    fn add_versions_map(&self, id: u32) -> u32 {
         // 0 means deleted. Version counting starts from 1.
-        self.versions_map.write().insert(id, 1);
+        let mut write_lock = self.versions_map.write();
+        write_lock.versions_map.insert(id, 1);
+        *write_lock.versions_map.get(&id).unwrap()
     }
 
-    pub async fn add_new_record_to_postings_list(&self, id: u32, embeddings: &[f32]) {}
+    async fn rng_query(
+        &self,
+        query: &[f32],
+    ) -> Result<(Vec<usize>, Vec<f32>), SpannIndexWriterConstructionError> {
+        let mut normalized_query = query.to_vec();
+        // Normalize the query in case of cosine.
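+        // The head embeddings stored in the hnsw index are already normalized
+        // for cosine (see the comment on res_embeddings below), so the query
+        // must be normalized too for the query-to-head and head-to-head
+        // distances below to be comparable.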
+        if self.distance_function == DistanceFunction::Cosine {
+            normalized_query = normalize(query)
+        }
+        let ids;
+        let distances;
+        let mut embeddings: Vec<Vec<f32>> = vec![];
+        {
+            let read_guard = self.hnsw_index.inner.read();
+            let allowed_ids = vec![];
+            let disallowed_ids = vec![];
+            (ids, distances) = read_guard
+                .query(
+                    &normalized_query,
+                    NUM_CENTROIDS_TO_SEARCH as usize,
+                    &allowed_ids,
+                    &disallowed_ids,
+                )
+                .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?;
+            // Also fetch the embeddings for distance computation.
+            for id in ids.iter() {
+                let emb = read_guard
+                    .get(*id)
+                    .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?
+                    .ok_or(SpannIndexWriterConstructionError::HnswIndexSearchError)?;
+                embeddings.push(emb);
+            }
+        }
+        // Apply the RNG rule to prune.
+        let mut res_ids = vec![];
+        let mut res_distances = vec![];
+        let mut res_embeddings: Vec<&Vec<f32>> = vec![];
+        // Embeddings that were obtained are already normalized.
+        for (id, (distance, embedding)) in ids.iter().zip(distances.iter().zip(embeddings.iter())) {
+            let mut rng_accepted = true;
+            for nbr_embedding in res_embeddings.iter() {
+                let dist = self
+                    .distance_function
+                    .distance(&embedding[..], &nbr_embedding[..]);
+                if RNG_FACTOR * dist <= *distance {
+                    rng_accepted = false;
+                    break;
+                }
+            }
+            if !rng_accepted {
+                continue;
+            }
+            res_ids.push(*id);
+            res_distances.push(*distance);
+            res_embeddings.push(embedding);
+        }
+
+        Ok((res_ids, res_distances))
+    }
+
+    async fn append(
+        &self,
+        head_id: u32,
+        id: u32,
+        version: u32,
+        embedding: &[f32],
+    ) -> Result<(), SpannIndexWriterConstructionError> {
+        {
+            let write_guard = self.posting_list_writer.lock().await;
+            // TODO(Sanket): Check if the head is deleted; this can happen if another
+            // thread deletes it concurrently.
+            let current_pl = write_guard
+                .get_clone::<u32, &SpannPostingList<'_>>("", head_id)
+                .await
+                .map_err(|_| SpannIndexWriterConstructionError::PostingListSearchError)?
+                .ok_or(SpannIndexWriterConstructionError::PostingListSearchError)?;
+            // Clean up this posting list and append the new point to it.
+            // TODO(Sanket): There is an order in which we are acquiring locks here. Need
+            // to ensure the same order in the other places as well.
+            let mut updated_doc_offset_ids = vec![];
+            let mut updated_versions = vec![];
+            let mut updated_embeddings = vec![];
+            {
+                let version_map_guard = self.versions_map.read();
+                for (index, doc_version) in current_pl.1.iter().enumerate() {
+                    let current_version = version_map_guard
+                        .versions_map
+                        .get(&current_pl.0[index])
+                        .ok_or(SpannIndexWriterConstructionError::ExpectedDataNotFound)?;
+                    // Disregard if either deleted or on an older version.
+                    if *current_version == 0 || doc_version < current_version {
+                        continue;
+                    }
+                    updated_doc_offset_ids.push(current_pl.0[index]);
+                    updated_versions.push(*doc_version);
+                    // Slice from index * dimensionality up to the next dimensionality chunk.
+                    updated_embeddings.push(
+                        &current_pl.2[index * self.dimensionality
+                            ..index * self.dimensionality + self.dimensionality],
+                    );
+                }
+            }
+            // Add the new point.
+            updated_doc_offset_ids.push(id);
+            updated_versions.push(version);
+            updated_embeddings.push(embedding);
+            // TODO(Sanket): Trigger a split and reassign if the size exceeds threshold.
+            // Write the PL back to the blockfile and release the lock.
+            let posting_list = SpannPostingList {
+                doc_offset_ids: &updated_doc_offset_ids,
+                doc_versions: &updated_versions,
+                doc_embeddings: &updated_embeddings.concat(),
+            };
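+            // Rewriting the whole list on every append doubles as garbage
+            // collection: deleted (version == 0) and stale entries were dropped
+            // above, and the set() below replaces the head's entire list while
+            // the posting list lock is still held.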
+            write_guard
+                .set("", head_id, &posting_list)
+                .await
+                .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?;
+        }
+        Ok(())
+    }
+
+    async fn add_postings_list(
+        &self,
+        id: u32,
+        version: u32,
+        embeddings: &[f32],
+    ) -> Result<(), SpannIndexWriterConstructionError> {
+        let (ids, distances) = self.rng_query(embeddings).await?;
+        // Create a centroid with just this point.
+        if ids.is_empty() {
+            let next_id = self
+                .next_head_id
+                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+            // First add to the posting list, then to hnsw. This order is important
+            // to ensure that if and when the center is discoverable, it also exists
+            // in the posting list. Otherwise, it will be a dangling center.
+            {
+                let posting_list = SpannPostingList {
+                    doc_offset_ids: &[id],
+                    doc_versions: &[version],
+                    doc_embeddings: embeddings,
+                };
+                let write_guard = self.posting_list_writer.lock().await;
+                write_guard
+                    .set("", next_id, &posting_list)
+                    .await
+                    .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?;
+            }
+            // Next, add to hnsw.
+            // This shouldn't exceed the capacity since this will happen only for the
+            // first few points, so no need to check and increase the capacity.
+            {
+                let write_guard = self.hnsw_index.inner.write();
+                write_guard
+                    .add(next_id as usize, embeddings)
+                    .map_err(|_| SpannIndexWriterConstructionError::HnswIndexAddError)?;
+            }
+            return Ok(());
+        }
+        // Otherwise, add the point to the posting list of each of these heads.
+        for head_id in ids.iter() {
+            self.append(*head_id as u32, id, version, embeddings)
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn add(&self, id: u32, embeddings: &[f32]) {
+        let version = self.add_versions_map(id);
+        // TODO: Propagate the error to the caller instead of panicking.
+        self.add_postings_list(id, version, embeddings)
+            .await
+            .expect("Failed to add the point to the posting lists");
+    }
 }
diff --git a/rust/index/src/utils.rs b/rust/index/src/utils.rs
index 04ffc7099bc2..0d9b3fe9542a 100644
--- a/rust/index/src/utils.rs
+++ b/rust/index/src/utils.rs
@@ -16,6 +16,8 @@ pub(super) fn generate_random_data(n: usize, d: usize) -> Vec<f32> {
     data
 }
 
+const EPS: f32 = 1e-30;
+
 pub fn merge_sorted_vecs_disjunction<T: PartialOrd + Clone>(a: &[T], b: &[T]) -> Vec<T> {
     let mut result = Vec::with_capacity(a.len() + b.len());
     let mut a_idx = 0;
@@ -77,6 +79,15 @@ pub fn merge_sorted_vecs_conjunction<T: PartialOrd + Clone>(a: &[T], b: &[T]) -> Vec<T> {
     result
 }
 
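+/// Scales `vector` to unit L2 norm, e.g. normalize(&[3.0, 4.0]) ≈ [0.6, 0.8].
+/// EPS guards the all-zeros input: the reciprocal norm becomes 1/EPS, and the
+/// zero entries map back to zeros instead of dividing 0 by 0.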
+pub fn normalize(vector: &[f32]) -> Vec<f32> {
+    let mut norm = 0.0;
+    for x in vector {
+        norm += x * x;
+    }
+    let norm = 1.0 / (norm.sqrt() + EPS);
+    vector.iter().map(|x| x * norm).collect()
+}
+
 #[cfg(test)]
 mod tests {
     #[test]
diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs
index a549a2968dae..ae3f81ea06c0 100644
--- a/rust/worker/src/execution/operators/brute_force_knn.rs
+++ b/rust/worker/src/execution/operators/brute_force_knn.rs
@@ -1,5 +1,4 @@
 use crate::execution::operator::Operator;
-use crate::execution::operators::normalize_vectors::normalize;
 use crate::segment::record_segment::RecordSegmentReader;
 use crate::segment::LogMaterializer;
 use crate::segment::LogMaterializerError;
@@ -7,6 +6,7 @@ use async_trait::async_trait;
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
+use chroma_index::utils::normalize;
 use chroma_types::Chunk;
 use chroma_types::{LogRecord, MaterializedLogOperation, Segment};
 use std::cmp::Ordering;
diff --git a/rust/worker/src/execution/operators/normalize_vectors.rs b/rust/worker/src/execution/operators/normalize_vectors.rs
index 631d787ff0c0..f40b58a18ebc 100644
--- a/rust/worker/src/execution/operators/normalize_vectors.rs
+++ b/rust/worker/src/execution/operators/normalize_vectors.rs
@@ -1,7 +1,6 @@
 use crate::execution::operator::Operator;
 use async_trait::async_trait;
-
-const EPS: f32 = 1e-30;
+use chroma_index::utils::normalize;
 
 #[derive(Debug)]
 pub struct NormalizeVectorOperator {}
@@ -14,15 +13,6 @@ pub struct NormalizeVectorOperatorOutput {
     pub _normalized_vectors: Vec<Vec<f32>>,
 }
 
-pub fn normalize(vector: &[f32]) -> Vec<f32> {
-    let mut norm = 0.0;
-    for x in vector {
-        norm += x * x;
-    }
-    let norm = 1.0 / (norm.sqrt() + EPS);
-    vector.iter().map(|x| x * norm).collect()
-}
-
 #[async_trait]
 impl Operator<NormalizeVectorOperatorInput, NormalizeVectorOperatorOutput>
     for NormalizeVectorOperator
diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs
index 7cc7c9558153..1cf412c1ea02 100644
--- a/rust/worker/src/execution/orchestration/hnsw.rs
+++ b/rust/worker/src/execution/orchestration/hnsw.rs
@@ -17,7 +17,6 @@ use crate::execution::operators::merge_knn_results::{
     MergeKnnBruteForceResultInput, MergeKnnResultsOperator, MergeKnnResultsOperatorInput,
     MergeKnnResultsOperatorOutput,
 };
-use crate::execution::operators::normalize_vectors::normalize;
 use crate::execution::operators::pull_log::PullLogsOutput;
 use crate::execution::operators::record_segment_prefetch::{
     Keys, OffsetIdToDataKeys, OffsetIdToUserIdKeys, RecordSegmentPrefetchIoInput,
@@ -40,6 +39,7 @@ use chroma_blockstore::provider::BlockfileProvider;
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::hnsw_provider::HnswIndexProvider;
+use chroma_index::utils::normalize;
 use chroma_index::IndexConfig;
 use chroma_types::{Chunk, Collection, CollectionUuid, LogRecord, Segment, VectorQueryResult};
 use std::collections::HashMap;
diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs
index e018e3e85461..faf592d10a2b 100644
--- a/rust/worker/src/segment/spann_segment.rs
+++ b/rust/worker/src/segment/spann_segment.rs
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-use arrow::error;
+use arrow::{compute::max, error};
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::IndexUuid;
@@ -8,6 +8,7 @@ use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
 use chroma_types::SegmentUuid;
 use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType};
 use thiserror::Error;
+use tonic::async_trait;
 use uuid::Uuid;
 
 use super::{
@@ -19,6 +20,7 @@ use super::{
 const HNSW_PATH: &str = "hnsw_path";
 const VERSION_MAP_PATH: &str = "version_map_path";
 const POSTING_LIST_PATH: &str = "posting_list_path";
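+// Blockfile that persists the largest head offset id handed out so far, so a
+// reopened writer can seed next_head_id past the existing heads (see
+// SpannIndexWriter::from_id).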
+const MAX_HEAD_ID_BF_PATH: &str = "max_head_id_path";
 
 pub(crate) struct SpannSegmentWriter {
     index: SpannIndexWriter,
@@ -39,8 +41,12 @@ pub enum SpannSegmentWriterError {
     VersionMapInvalidFilePath,
     #[error("Postings list invalid file path")]
     PostingListInvalidFilePath,
+    #[error("Max head id invalid file path")]
+    MaxHeadIdInvalidFilePath,
     #[error("Spann index creation error")]
     SpannIndexWriterConstructionError,
+    #[error("Not implemented")]
+    NotImplemented,
 }
 
 impl ChromaError for SpannSegmentWriterError {
@@ -53,6 +59,8 @@ impl ChromaError for SpannSegmentWriterError {
             Self::VersionMapInvalidFilePath => ErrorCodes::Internal,
             Self::PostingListInvalidFilePath => ErrorCodes::Internal,
             Self::SpannIndexWriterConstructionError => ErrorCodes::Internal,
+            Self::MaxHeadIdInvalidFilePath => ErrorCodes::Internal,
+            Self::NotImplemented => ErrorCodes::Internal,
         }
     }
 }
@@ -128,11 +136,30 @@ impl SpannSegmentWriter {
             None => None,
         };
 
+        let max_head_id_bf_id = match segment.file_path.get(MAX_HEAD_ID_BF_PATH) {
+            Some(max_head_id_bf_path) => match max_head_id_bf_path.first() {
+                Some(max_head_id_bf_id) => {
+                    let max_head_id_bf_uuid = match Uuid::parse_str(max_head_id_bf_id) {
+                        Ok(uuid) => uuid,
+                        Err(_) => {
+                            return Err(SpannSegmentWriterError::IndexIdParsingError);
+                        }
+                    };
+                    Some(max_head_id_bf_uuid)
+                }
+                None => {
+                    return Err(SpannSegmentWriterError::MaxHeadIdInvalidFilePath);
+                }
+            },
+            None => None,
+        };
+
         let index_writer = match SpannIndexWriter::from_id(
             hnsw_provider,
             hnsw_id.as_ref(),
             versions_map_id.as_ref(),
             posting_list_id.as_ref(),
+            max_head_id_bf_id.as_ref(),
             hnsw_params,
             &segment.collection,
             distance_function,
@@ -154,10 +181,9 @@ impl SpannSegmentWriter {
     }
 
     async fn add(&self, record: &MaterializedLogRecord<'_>) {
-        // Initialize the record with a version.
-        self.index.add_new_record_to_versions_map(record.offset_id);
         self.index
-            .add_new_record_to_postings_list(record.offset_id, record.merged_embeddings());
+            .add(record.offset_id, record.merged_embeddings())
+            .await;
     }
 }
 
@@ -181,6 +207,14 @@ impl<'a> SegmentWriter<'a> for SpannSegmentWriter {
     }
 
     async fn commit(self) -> Result<impl SegmentFlusher, Box<dyn ChromaError>> {
+        // TODO: Implement commit.
+        Ok(self)
+    }
+}
+
+#[async_trait]
+impl SegmentFlusher for SpannSegmentWriter {
+    async fn flush(self) -> Result<HashMap<String, Vec<String>>, Box<dyn ChromaError>> {
         todo!()
     }
 }