From f7b45536f6886c7e2d0ca2580ea373aa7f580c5a Mon Sep 17 00:00:00 2001
From: Sanket Kedia
Date: Fri, 8 Nov 2024 18:26:24 -0800
Subject: [PATCH] Implement append

---
 rust/blockstore/src/arrow/types.rs            |   2 +-
 rust/blockstore/src/types.rs                  |  16 ++
 rust/index/src/spann/types.rs                 | 272 ++++++++++++++++--
 rust/index/src/utils.rs                       |  11 +
 .../execution/operators/brute_force_knn.rs    |   2 +-
 .../execution/operators/normalize_vectors.rs  |  12 +-
 .../src/execution/orchestration/hnsw.rs       |   2 +-
 rust/worker/src/segment/spann_segment.rs      |  42 ++-
 8 files changed, 325 insertions(+), 34 deletions(-)
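Notes:

The write path added here: `SpannSegmentWriter::add` hands each materialized record to
`SpannIndexWriter::add`, which first registers the point in the versions map (version
counting starts at 1; 0 marks a deleted point) and then routes it to centroid heads
selected by an RNG-pruned HNSW query (`rng_query` / `add_postings_list` below). Posting
lists are guarded by a single coarse-grained mutex for now, and head offset ids are
handed out from an AtomicU32 seeded from the max-head-id blockfile. A minimal sketch of
the intended call pattern, assuming a writer already constructed via
`SpannIndexWriter::from_id` (the offset id and embedding are made up):

    // Hypothetical driver; setup and error handling are elided.
    async fn index_point(writer: &SpannIndexWriter) {
        let embedding = [0.1_f32, 0.2, 0.3, 0.4];
        // Registers version 1 for offset id 42, then appends the point to the
        // posting list(s) of its RNG-selected nearest centroid heads.
        writer.add(42, &embedding).await;
    }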
diff --git a/rust/blockstore/src/arrow/types.rs b/rust/blockstore/src/arrow/types.rs
index 1f5703436b6..c0160db394e 100644
--- a/rust/blockstore/src/arrow/types.rs
+++ b/rust/blockstore/src/arrow/types.rs
@@ -17,7 +17,7 @@ pub trait ArrowWriteableKey: Key + Default {
     ) -> BlockKeyArrowBuilder;
 }
 
-pub(crate) trait ArrowWriteableValue: Value {
+pub trait ArrowWriteableValue: Value {
     type ReadableValue<'referred_data>: ArrowReadableValue<'referred_data>;
     type OwnedReadableValue;
 
diff --git a/rust/blockstore/src/types.rs b/rust/blockstore/src/types.rs
index d52b30ee9d1..6f5824afd80 100644
--- a/rust/blockstore/src/types.rs
+++ b/rust/blockstore/src/types.rs
@@ -194,6 +194,22 @@ impl BlockfileWriter {
         }
     }
 
+    pub async fn get_clone<
+        K: Key + Into<KeyWrapper> + ArrowWriteableKey,
+        V: Value + Writeable + ArrowWriteableValue,
+    >(
+        &self,
+        prefix: &str,
+        key: K,
+    ) -> Result<Option<V::OwnedReadableValue>, Box<dyn ChromaError>> {
+        match self {
+            BlockfileWriter::MemoryBlockfileWriter(_) => todo!(),
+            BlockfileWriter::ArrowBlockfileWriter(writer) => {
+                writer.get_clone::<K, V>(prefix, key).await
+            }
+        }
+    }
+
     pub fn id(&self) -> uuid::Uuid {
         match self {
             BlockfileWriter::MemoryBlockfileWriter(writer) => writer.id(),
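The new `get_clone` is the read side of a read-modify-write cycle on a blockfile that
is still being written. A minimal sketch of that cycle for a SPANN posting list,
assuming (as in `append` below) that the owned form of `&SpannPostingList<'_>` is the
`(doc_offset_ids, doc_versions, flattened doc_embeddings)` tuple and that the head
already exists; `append_to_head` is a hypothetical helper, not part of this patch:

    use chroma_blockstore::BlockfileWriter;
    use chroma_error::ChromaError;
    use chroma_types::SpannPostingList;

    async fn append_to_head(
        writer: &BlockfileWriter,
        head_id: u32,
        id: u32,
        version: u32,
        embedding: &[f32],
    ) -> Result<(), Box<dyn ChromaError>> {
        // Clone the current posting list out of the blockfile.
        let (mut ids, mut versions, mut embeddings) = writer
            .get_clone::<u32, &SpannPostingList<'_>>("", head_id)
            .await?
            .expect("head must exist");
        // Extend the parallel arrays with the new point.
        ids.push(id);
        versions.push(version);
        embeddings.extend_from_slice(embedding);
        // Write the extended list back under the same key.
        let pl = SpannPostingList {
            doc_offset_ids: &ids,
            doc_versions: &versions,
            doc_embeddings: &embeddings,
        };
        writer.set("", head_id, &pl).await
    }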
diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs
index 453b3596a52..113b5326626 100644
--- a/rust/index/src/spann/types.rs
+++ b/rust/index/src/spann/types.rs
@@ -1,4 +1,7 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{
+    collections::HashMap,
+    sync::{atomic::AtomicU32, Arc},
+};
 
 use arrow::error;
 use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter};
@@ -7,21 +10,33 @@ use chroma_error::{ChromaError, ErrorCodes};
 use chroma_types::SpannPostingList;
 use parking_lot::RwLock;
 use thiserror::Error;
+use tokio::sync::Mutex;
 use uuid::Uuid;
 
-use crate::hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef};
+use crate::{
+    hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef},
+    utils::normalize,
+    Index,
+};
+
+struct VersionsMapInner {
+    versions_map: HashMap<u32, u32>,
+}
 
-// TODO(Sanket): Add locking structures as necessary.
 pub struct SpannIndexWriter {
     // HNSW index and its provider for centroid search.
     hnsw_index: HnswIndexRef,
     hnsw_provider: HnswIndexProvider,
     // Posting list of the centroids.
-    // The blockfile also contains next id for the head.
-    posting_list_writer: BlockfileWriter,
+    // TODO(Sanket): For now the lock is very coarse grained. But this should
+    // be changed in the future.
+    posting_list_writer: Arc<Mutex<BlockfileWriter>>,
+    next_head_id: Arc<AtomicU32>,
     // Version number of each point.
     // TODO(Sanket): Finer grained locking for this map in future.
-    versions_map: Arc<RwLock<HashMap<u32, u32>>>,
+    versions_map: Arc<RwLock<VersionsMapInner>>,
+    distance_function: DistanceFunction,
+    dimensionality: usize,
 }
 
 #[derive(Error, Debug)]
@@ -32,6 +47,18 @@ pub enum SpannIndexWriterConstructionError {
     BlockfileReaderConstructionError,
     #[error("Blockfile writer construction error")]
     BlockfileWriterConstructionError,
+    #[error("Error resizing hnsw index")]
+    HnswIndexResizeError,
+    #[error("Error adding to hnsw index")]
+    HnswIndexAddError,
+    #[error("Error searching from hnsw")]
+    HnswIndexSearchError,
+    #[error("Error adding to posting list")]
+    PostingListAddError,
+    #[error("Error searching for posting list")]
+    PostingListSearchError,
+    #[error("Expected data not found")]
+    ExpectedDataNotFound,
 }
 
 impl ChromaError for SpannIndexWriterConstructionError {
@@ -40,22 +67,41 @@ impl ChromaError for SpannIndexWriterConstructionError {
             Self::HnswIndexConstructionError => ErrorCodes::Internal,
             Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
             Self::BlockfileWriterConstructionError => ErrorCodes::Internal,
+            Self::HnswIndexResizeError => ErrorCodes::Internal,
+            Self::HnswIndexAddError => ErrorCodes::Internal,
+            Self::PostingListAddError => ErrorCodes::Internal,
+            Self::HnswIndexSearchError => ErrorCodes::Internal,
+            Self::PostingListSearchError => ErrorCodes::Internal,
+            Self::ExpectedDataNotFound => ErrorCodes::Internal,
         }
     }
 }
 
+const MAX_HEAD_OFFSET_ID: &str = "max_head_offset_id";
+
+// TODO(Sanket): Make this configurable.
+const NUM_CENTROIDS_TO_SEARCH: u32 = 64;
+const RNG_FACTOR: f32 = 1.0;
+const SPLIT_THRESHOLD: usize = 100;
+
 impl SpannIndexWriter {
     pub fn new(
         hnsw_index: HnswIndexRef,
         hnsw_provider: HnswIndexProvider,
         posting_list_writer: BlockfileWriter,
-        versions_map: HashMap<u32, u32>,
+        next_head_id: u32,
+        versions_map: VersionsMapInner,
+        distance_function: DistanceFunction,
+        dimensionality: usize,
     ) -> Self {
         SpannIndexWriter {
             hnsw_index,
             hnsw_provider,
-            posting_list_writer,
+            posting_list_writer: Arc::new(Mutex::new(posting_list_writer)),
+            next_head_id: Arc::new(AtomicU32::new(next_head_id)),
             versions_map: Arc::new(RwLock::new(versions_map)),
+            distance_function,
+            dimensionality,
         }
     }
 
@@ -101,7 +147,7 @@ impl SpannIndexWriter {
     async fn load_versions_map(
         blockfile_id: &Uuid,
         blockfile_provider: &BlockfileProvider,
-    ) -> Result<HashMap<u32, u32>, SpannIndexWriterConstructionError> {
+    ) -> Result<VersionsMapInner, SpannIndexWriterConstructionError> {
         // Create a reader for the blockfile. Load all the data into the versions map.
         let mut versions_map = HashMap::new();
         let reader = match blockfile_provider.open::<u32, u32>(blockfile_id).await {
@@ -115,7 +161,7 @@ impl SpannIndexWriter {
         versions_data.iter().for_each(|(_, key, value)| {
             versions_map.insert(*key, *value);
         });
-        Ok(versions_map)
+        Ok(VersionsMapInner { versions_map })
     }
 
     async fn fork_postings_list(
@@ -146,6 +192,7 @@ impl SpannIndexWriter {
         hnsw_id: Option<&Uuid>,
         versions_map_id: Option<&Uuid>,
         posting_list_id: Option<&Uuid>,
+        max_head_id_bf_id: Option<&Uuid>,
         hnsw_params: Option<HnswIndexParams>,
         collection_id: &Uuid,
         distance_function: DistanceFunction,
@@ -159,7 +206,7 @@ impl SpannIndexWriter {
                 hnsw_provider,
                 hnsw_id,
                 collection_id,
-                distance_function,
+                distance_function.clone(),
                 dimensionality,
             )
             .await?
@@ -168,7 +215,7 @@ impl SpannIndexWriter {
             Self::create_hnsw_index(
                 hnsw_provider,
                 collection_id,
-                distance_function,
+                distance_function.clone(),
                 dimensionality,
                 hnsw_params.unwrap(), // Safe since caller should always provide this.
             )
@@ -180,7 +227,9 @@ impl SpannIndexWriter {
             Some(versions_map_id) => {
                 Self::load_versions_map(versions_map_id, blockfile_provider).await?
             }
-            None => HashMap::new(),
+            None => VersionsMapInner {
+                versions_map: HashMap::new(),
+            },
         };
         // Fork the posting list writer.
         let posting_list_writer = match posting_list_id {
@@ -189,18 +238,209 @@ impl SpannIndexWriter {
             }
             None => Self::create_posting_list(blockfile_provider).await?,
         };
+
+        let max_head_id = match max_head_id_bf_id {
+            Some(max_head_id_bf_id) => {
+                let reader = blockfile_provider
+                    .open::<&str, u32>(max_head_id_bf_id)
+                    .await;
+                match reader {
+                    Ok(reader) => reader.get("", MAX_HEAD_OFFSET_ID).await.map_err(|_| {
+                        SpannIndexWriterConstructionError::BlockfileReaderConstructionError
+                    })?,
+                    Err(_) => 0,
+                }
+            }
+            None => 0,
+        };
         Ok(Self::new(
             hnsw_index,
             hnsw_provider.clone(),
             posting_list_writer,
+            1 + max_head_id,
             versions_map,
+            distance_function,
+            dimensionality,
         ))
     }
 
-    pub fn add_versions_map(&self, id: u32) {
+    fn add_versions_map(&self, id: u32) -> u32 {
         // 0 means deleted. Version counting starts from 1.
-        self.versions_map.write().insert(id, 1);
+        let mut write_lock = self.versions_map.write();
+        write_lock.versions_map.insert(id, 1);
+        *write_lock.versions_map.get(&id).unwrap()
+    }
+
+    async fn rng_query(
+        &self,
+        query: &[f32],
+    ) -> Result<(Vec<usize>, Vec<f32>), SpannIndexWriterConstructionError> {
+        let mut normalized_query = query.to_vec();
+        // Normalize the query in case of cosine.
+        if self.distance_function == DistanceFunction::Cosine {
+            normalized_query = normalize(query)
+        }
+        let ids;
+        let distances;
+        let mut embeddings: Vec<Vec<f32>> = vec![];
+        {
+            let read_guard = self.hnsw_index.inner.read();
+            let allowed_ids = vec![];
+            let disallowed_ids = vec![];
+            (ids, distances) = read_guard
+                .query(
+                    &normalized_query,
+                    NUM_CENTROIDS_TO_SEARCH as usize,
+                    &allowed_ids,
+                    &disallowed_ids,
+                )
+                .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?;
+            // Get the embeddings also for distance computation.
+            for id in ids.iter() {
+                let emb = read_guard
+                    .get(*id)
+                    .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?
+                    .ok_or(SpannIndexWriterConstructionError::HnswIndexSearchError)?;
+                embeddings.push(emb);
+            }
+        }
+        // Apply the RNG rule to prune.
+        let mut res_ids = vec![];
+        let mut res_distances = vec![];
+        let mut res_embeddings: Vec<&Vec<f32>> = vec![];
+        // Embeddings that were obtained are already normalized.
+        for (id, (distance, embedding)) in ids.iter().zip(distances.iter().zip(embeddings.iter()))
+        {
+            let mut rng_accepted = true;
+            for nbr_embedding in res_embeddings.iter() {
+                let dist = self
+                    .distance_function
+                    .distance(&embedding[..], &nbr_embedding[..]);
+                if RNG_FACTOR * dist <= *distance {
+                    rng_accepted = false;
+                    break;
+                }
+            }
+            if !rng_accepted {
+                continue;
+            }
+            res_ids.push(*id);
+            res_distances.push(*distance);
+            res_embeddings.push(embedding);
+        }
+
+        Ok((res_ids, res_distances))
+    }
+
+    async fn append(
+        &self,
+        head_id: u32,
+        id: u32,
+        version: u32,
+        embedding: &[f32],
+    ) -> Result<(), SpannIndexWriterConstructionError> {
+        {
+            let write_guard = self.posting_list_writer.lock().await;
+            // TODO(Sanket): Check if the head is deleted; that can happen if another
+            // concurrent thread deletes it.
+            let current_pl = write_guard
+                .get_clone::<u32, &SpannPostingList<'_>>("", head_id)
+                .await
+                .map_err(|_| SpannIndexWriterConstructionError::PostingListSearchError)?
+                .ok_or(SpannIndexWriterConstructionError::PostingListSearchError)?;
+            // Clean up this posting list and append the new point to it.
+            // TODO(Sanket): There is an order in which we are acquiring locks here. Need
+            // to ensure the same order in the other places as well.
+            let mut updated_doc_offset_ids = vec![];
+            let mut updated_versions = vec![];
+            let mut updated_embeddings = vec![];
+            {
+                let version_map_guard = self.versions_map.read();
+                for (index, doc_version) in current_pl.1.iter().enumerate() {
+                    let current_version = version_map_guard
+                        .versions_map
+                        .get(&current_pl.0[index])
+                        .ok_or(SpannIndexWriterConstructionError::ExpectedDataNotFound)?;
+                    // Disregard if either deleted or on an older version.
+                    if *current_version == 0 || doc_version < current_version {
+                        continue;
+                    }
+                    updated_doc_offset_ids.push(current_pl.0[index]);
+                    updated_versions.push(*doc_version);
+                    // Slice from index * dimensionality to (index + 1) * dimensionality.
+                    updated_embeddings.push(
+                        &current_pl.2[index * self.dimensionality
+                            ..index * self.dimensionality + self.dimensionality],
+                    );
+                }
+            }
+            // Add the new point.
+            updated_doc_offset_ids.push(id);
+            updated_versions.push(version);
+            updated_embeddings.push(embedding);
+            // TODO(Sanket): Trigger a split and reassign if the size exceeds the threshold.
+            // Write the PL back to the blockfile and release the lock.
+            let posting_list = SpannPostingList {
+                doc_offset_ids: &updated_doc_offset_ids,
+                doc_versions: &updated_versions,
+                doc_embeddings: &updated_embeddings.concat(),
+            };
+            write_guard
+                .set("", head_id, &posting_list)
+                .await
+                .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?;
+        }
+        Ok(())
+    }
+
+    async fn add_postings_list(
+        &self,
+        id: u32,
+        version: u32,
+        embeddings: &[f32],
+    ) -> Result<(), SpannIndexWriterConstructionError> {
+        let (ids, distances) = self.rng_query(embeddings).await?;
+        // If no centroid is found, create one with just this point.
+        if ids.is_empty() {
+            let next_id = self
+                .next_head_id
+                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+            // First add to the posting list, then to hnsw. This order is important
+            // to ensure that if and when the center is discoverable, it also exists
+            // in the posting list. Otherwise, it will be a dangling center.
+            {
+                let posting_list = SpannPostingList {
+                    doc_offset_ids: &[id],
+                    doc_versions: &[version],
+                    doc_embeddings: embeddings,
+                };
+                let write_guard = self.posting_list_writer.lock().await;
+                write_guard
+                    .set("", next_id, &posting_list)
+                    .await
+                    .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?;
+            }
+            // Next add to hnsw.
+            // This shouldn't exceed the capacity since this will happen only for the first
+            // few points, so there is no need to check and increase the capacity.
+            {
+                let write_guard = self.hnsw_index.inner.write();
+                write_guard
+                    .add(next_id as usize, embeddings)
+                    .map_err(|_| SpannIndexWriterConstructionError::HnswIndexAddError)?;
+            }
+            return Ok(());
+        }
+        // Otherwise, append the point to the posting lists of these heads.
+        for head_id in ids.iter() {
+            self.append(*head_id as u32, id, version, embeddings)
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn add(&self, id: u32, embeddings: &[f32]) {
+        let version = self.add_versions_map(id);
+        // Route the point to the posting lists of its nearest centroid heads.
+        // TODO(Sanket): Handle the error.
+        let _ = self.add_postings_list(id, version, embeddings).await;
+    }
 }
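The RNG (relative neighborhood graph) rule in `rng_query` keeps a candidate head only
if it is closer to the query than (RNG_FACTOR times) its distance to every
already-accepted head, so the selected heads spread out around the query instead of
clustering on one side. A self-contained restatement of that loop with plain L2
distance (candidates must arrive sorted by ascending distance to the query, as HNSW
returns them; the names here are illustrative):

    fn l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// `candidates` pairs each head embedding with its distance to the query.
    fn rng_prune<'a>(candidates: &[(&'a [f32], f32)], rng_factor: f32) -> Vec<&'a [f32]> {
        let mut accepted: Vec<&[f32]> = vec![];
        for (emb, dist_to_query) in candidates {
            // Reject a candidate that is at least as close to an already-accepted
            // head as it is to the query; that head already covers it.
            let covered = accepted
                .iter()
                .any(|nbr| rng_factor * l2(emb, nbr) <= *dist_to_query);
            if !covered {
                accepted.push(emb);
            }
        }
        accepted
    }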
diff --git a/rust/index/src/utils.rs b/rust/index/src/utils.rs
index 04ffc7099bc..0d9b3fe9542 100644
--- a/rust/index/src/utils.rs
+++ b/rust/index/src/utils.rs
@@ -16,6 +16,8 @@ pub(super) fn generate_random_data(n: usize, d: usize) -> Vec<f32> {
     data
 }
 
+const EPS: f32 = 1e-30;
+
 pub fn merge_sorted_vecs_disjunction<T: PartialOrd + Clone>(a: &[T], b: &[T]) -> Vec<T> {
     let mut result = Vec::with_capacity(a.len() + b.len());
     let mut a_idx = 0;
@@ -77,6 +79,15 @@ pub fn merge_sorted_vecs_conjunction<T: PartialOrd + Clone>(a: &[T], b: &[T]) -> Vec<T> {
     result
 }
 
+pub fn normalize(vector: &[f32]) -> Vec<f32> {
+    let mut norm = 0.0;
+    for x in vector {
+        norm += x * x;
+    }
+    let norm = 1.0 / (norm.sqrt() + EPS);
+    vector.iter().map(|x| x * norm).collect()
+}
+
 #[cfg(test)]
 mod tests {
     #[test]
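For reference, `normalize` scales a vector to unit L2 norm (EPS only guards against
division by zero for an all-zero vector); this is what `rng_query` applies to queries
under cosine distance. A quick check with a 3-4-5 triangle:

    use chroma_index::utils::normalize;

    fn main() {
        let n = normalize(&[3.0_f32, 4.0]);
        // The unit vector is [0.6, 0.8] (up to EPS).
        assert!((n[0] - 0.6).abs() < 1e-6 && (n[1] - 0.8).abs() < 1e-6);
        let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }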
diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs
index 3fc855e8ec9..db233c50767 100644
--- a/rust/worker/src/execution/operators/brute_force_knn.rs
+++ b/rust/worker/src/execution/operators/brute_force_knn.rs
@@ -1,5 +1,4 @@
 use crate::execution::operator::Operator;
-use crate::execution::operators::normalize_vectors::normalize;
 use crate::segment::record_segment::RecordSegmentReader;
 use crate::segment::LogMaterializer;
 use crate::segment::LogMaterializerError;
@@ -7,6 +6,7 @@ use async_trait::async_trait;
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
+use chroma_index::utils::normalize;
 use chroma_types::Chunk;
 use chroma_types::{LogRecord, MaterializedLogOperation, Segment};
 use std::cmp::Ordering;
diff --git a/rust/worker/src/execution/operators/normalize_vectors.rs b/rust/worker/src/execution/operators/normalize_vectors.rs
index 631d787ff0c..f40b58a18eb 100644
--- a/rust/worker/src/execution/operators/normalize_vectors.rs
+++ b/rust/worker/src/execution/operators/normalize_vectors.rs
@@ -1,7 +1,6 @@
 use crate::execution::operator::Operator;
 use async_trait::async_trait;
-
-const EPS: f32 = 1e-30;
+use chroma_index::utils::normalize;
 
 #[derive(Debug)]
 pub struct NormalizeVectorOperator {}
@@ -14,15 +13,6 @@ pub struct NormalizeVectorOperatorOutput {
     pub _normalized_vectors: Vec<Vec<f32>>,
 }
 
-pub fn normalize(vector: &[f32]) -> Vec<f32> {
-    let mut norm = 0.0;
-    for x in vector {
-        norm += x * x;
-    }
-    let norm = 1.0 / (norm.sqrt() + EPS);
-    vector.iter().map(|x| x * norm).collect()
-}
-
 #[async_trait]
 impl Operator<NormalizeVectorOperatorInput, NormalizeVectorOperatorOutput>
     for NormalizeVectorOperator
diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs
index 2371fb28266..e677061425b 100644
--- a/rust/worker/src/execution/orchestration/hnsw.rs
+++ b/rust/worker/src/execution/orchestration/hnsw.rs
@@ -17,7 +17,6 @@ use crate::execution::operators::merge_knn_results::{
     MergeKnnBruteForceResultInput, MergeKnnResultsOperator, MergeKnnResultsOperatorInput,
     MergeKnnResultsOperatorOutput,
 };
-use crate::execution::operators::normalize_vectors::normalize;
 use crate::execution::operators::pull_log::PullLogsOutput;
 use crate::execution::operators::record_segment_prefetch::{
     Keys, OffsetIdToDataKeys, OffsetIdToUserIdKeys, RecordSegmentPrefetchIoInput,
@@ -40,6 +39,7 @@ use chroma_blockstore::provider::BlockfileProvider;
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::hnsw_provider::HnswIndexProvider;
+use chroma_index::utils::normalize;
 use chroma_index::IndexConfig;
 use chroma_types::{Chunk, Collection, LogRecord, Segment, VectorQueryResult};
 use std::collections::HashMap;
diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs
index d44a3261d28..156eee62363 100644
--- a/rust/worker/src/segment/spann_segment.rs
+++ b/rust/worker/src/segment/spann_segment.rs
@@ -1,11 +1,12 @@
 use std::collections::HashMap;
 
-use arrow::error;
+use arrow::{compute::max, error};
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
 use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType};
 use thiserror::Error;
+use tonic::async_trait;
 use uuid::Uuid;
 
 use super::{
@@ -17,6 +18,7 @@ use super::{
 const HNSW_PATH: &str = "hnsw_path";
 const VERSION_MAP_PATH: &str = "version_map_path";
 const POSTING_LIST_PATH: &str = "posting_list_path";
+const MAX_HEAD_ID_BF_PATH: &str = "max_head_id_path";
 
 pub(crate) struct SpannSegmentWriter {
     index: SpannIndexWriter,
@@ -37,8 +39,12 @@ pub enum SpannSegmentWriterError {
     VersionMapInvalidFilePath,
     #[error("Postings list invalid file path")]
     PostingListInvalidFilePath,
+    #[error("Max head id invalid file path")]
+    MaxHeadIdInvalidFilePath,
     #[error("Spann index creation error")]
     SpannIndexWriterConstructionError,
+    #[error("Not implemented")]
+    NotImplemented,
 }
 
 impl ChromaError for SpannSegmentWriterError {
@@ -51,6 +57,8 @@ impl ChromaError for SpannSegmentWriterError {
             Self::VersionMapInvalidFilePath => ErrorCodes::Internal,
             Self::PostingListInvalidFilePath => ErrorCodes::Internal,
             Self::SpannIndexWriterConstructionError => ErrorCodes::Internal,
+            Self::MaxHeadIdInvalidFilePath => ErrorCodes::Internal,
+            Self::NotImplemented => ErrorCodes::Internal,
         }
     }
 }
@@ -123,11 +131,30 @@ impl SpannSegmentWriter {
             None => None,
         };
 
+        let max_head_id_bf_id = match segment.file_path.get(MAX_HEAD_ID_BF_PATH) {
+            Some(max_head_id_bf_path) => match max_head_id_bf_path.first() {
+                Some(max_head_id_bf_id) => {
+                    let max_head_id_bf_uuid = match Uuid::parse_str(max_head_id_bf_id) {
+                        Ok(uuid) => uuid,
+                        Err(_) => {
+                            return Err(SpannSegmentWriterError::IndexIdParsingError);
+                        }
+                    };
+                    Some(max_head_id_bf_uuid)
+                }
+                None => {
+                    return Err(SpannSegmentWriterError::MaxHeadIdInvalidFilePath);
+                }
+            },
+            None => None,
+        };
+
         let index_writer = match SpannIndexWriter::from_id(
             hnsw_provider,
             hnsw_id.as_ref(),
             versions_map_id.as_ref(),
             posting_list_id.as_ref(),
+            max_head_id_bf_id.as_ref(),
             hnsw_params,
             &segment.collection,
             distance_function,
@@ -149,10 +176,9 @@ impl SpannSegmentWriter {
     }
 
     async fn add(&self, record: &MaterializedLogRecord<'_>) {
-        // Initialize the record with a version.
-        self.index.add_new_record_to_versions_map(record.offset_id);
         self.index
-            .add_new_record_to_postings_list(record.offset_id, record.merged_embeddings());
+            .add(record.offset_id, record.merged_embeddings())
+            .await;
     }
 }
@@ -176,6 +202,14 @@ impl<'a> SegmentWriter<'a> for SpannSegmentWriter {
     }
 
     async fn commit(self) -> Result<impl SegmentFlusher, Box<dyn ChromaError>> {
+        // TODO: Implement commit.
+        Ok(self)
+    }
+}
+
+#[async_trait]
+impl SegmentFlusher for SpannSegmentWriter {
+    async fn flush(self) -> Result<HashMap<String, Vec<String>>, Box<dyn ChromaError>> {
         todo!()
     }
 }
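For context, `SpannSegmentWriter::from_segment` reads the four blockfile/index ids out
of `segment.file_path` under the constants above, and `flush` will eventually return a
map of the same shape. A sketch of that layout (the UUIDs here are made up):

    use std::collections::HashMap;

    fn example_file_paths() -> HashMap<String, Vec<String>> {
        HashMap::from([
            // Each key maps to a single id; from_segment takes the first element.
            ("hnsw_path".into(), vec!["550e8400-e29b-41d4-a716-446655440000".into()]),
            ("version_map_path".into(), vec!["550e8400-e29b-41d4-a716-446655440001".into()]),
            ("posting_list_path".into(), vec!["550e8400-e29b-41d4-a716-446655440002".into()]),
            ("max_head_id_path".into(), vec!["550e8400-e29b-41d4-a716-446655440003".into()]),
        ])
    }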