diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 1b16e1c73a7..1810ee98190 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -4,20 +4,20 @@ use std::{ }; use chroma_blockstore::{ - provider::BlockfileProvider, BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions, + provider::{BlockfileProvider, CreateError, OpenError}, + BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions, }; -use chroma_distance::DistanceFunction; +use chroma_distance::{normalize, DistanceFunction}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::CollectionUuid; use chroma_types::SpannPostingList; -use parking_lot::RwLock; use thiserror::Error; -use tokio::sync::Mutex; use uuid::Uuid; use crate::{ - hnsw_provider::{HnswIndexProvider, HnswIndexRef}, - utils::normalize, + hnsw_provider::{ + HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexRef, + }, Index, IndexUuid, }; @@ -35,52 +35,95 @@ pub struct SpannIndexWriter { // Posting list of the centroids. // TODO(Sanket): For now the lock is very coarse grained. But this should // be changed in future if perf is not satisfactory. - pub posting_list_writer: Arc>, + pub posting_list_writer: Arc>, pub next_head_id: Arc, // Version number of each point. // TODO(Sanket): Finer grained locking for this map in future if perf is not satisfactory. - pub versions_map: Arc>, + pub versions_map: Arc>, pub distance_function: DistanceFunction, pub dimensionality: usize, } +// TODO(Sanket): Can compose errors whenever downstream returns Box. #[derive(Error, Debug)] -pub enum SpannIndexWriterConstructionError { - #[error("Error creating/forking hnsw index")] - HnswIndexConstructionError, - #[error("Error creating blockfile reader")] - BlockfileReaderConstructionError, - #[error("Error creating/forking blockfile writer")] - BlockfileWriterConstructionError, - #[error("Error loading version data from blockfile")] - BlockfileVersionDataLoadError, +pub enum SpannIndexWriterError { + #[error("Error forking hnsw index {0}")] + HnswIndexForkError(#[from] HnswIndexProviderForkError), + #[error("Error creating hnsw index {0}")] + HnswIndexCreateError(#[from] HnswIndexProviderCreateError), + #[error("Error opening reader for versions map blockfile {0}")] + VersionsMapOpenError(#[from] OpenError), + #[error("Error creating/forking postings list writer {0}")] + PostingsListCreateError(#[from] CreateError), + #[error("Error loading version data from blockfile {0}")] + VersionsMapDataLoadError(#[from] Box), + #[error("Error reading max offset id for heads")] + MaxHeadOffsetIdBlockfileGetError, #[error("Error resizing hnsw index")] HnswIndexResizeError, #[error("Error adding to hnsw index")] HnswIndexAddError, #[error("Error searching from hnsw")] HnswIndexSearchError, - #[error("Error adding to posting list")] - PostingListAddError, - #[error("Error searching for posting list")] - PostingListSearchError, - #[error("Expected data not found")] - ExpectedDataNotFound, + #[error("Error adding posting list for a head")] + PostingListSetError, + #[error("Error getting the posting list for a head")] + PostingListGetError, + #[error("Did not find the version for head id")] + VersionNotFound, + #[error("Error committing postings list blockfile")] + PostingListCommitError, + #[error("Error creating blockfile writer for versions map")] + VersionsMapWriterCreateError, + #[error("Error writing data to versions map blockfile")] + VersionsMapSetError, + #[error("Error committing versions map blockfile")] + VersionsMapCommitError, + #[error("Error creating blockfile writer for max head id")] + MaxHeadIdWriterCreateError, + #[error("Error writing data to max head id blockfile")] + MaxHeadIdSetError, + #[error("Error committing max head id blockfile")] + MaxHeadIdCommitError, + #[error("Error committing hnsw index")] + HnswIndexCommitError, + #[error("Error flushing postings list blockfile")] + PostingListFlushError, + #[error("Error flushing versions map blockfile")] + VersionsMapFlushError, + #[error("Error flushing max head id blockfile")] + MaxHeadIdFlushError, + #[error("Error flushing hnsw index")] + HnswIndexFlushError, } -impl ChromaError for SpannIndexWriterConstructionError { +impl ChromaError for SpannIndexWriterError { fn code(&self) -> ErrorCodes { match self { - Self::HnswIndexConstructionError => ErrorCodes::Internal, - Self::BlockfileReaderConstructionError => ErrorCodes::Internal, - Self::BlockfileWriterConstructionError => ErrorCodes::Internal, - Self::BlockfileVersionDataLoadError => ErrorCodes::Internal, + Self::HnswIndexForkError(e) => e.code(), + Self::HnswIndexCreateError(e) => e.code(), + Self::VersionsMapOpenError(e) => e.code(), + Self::PostingsListCreateError(e) => e.code(), + Self::VersionsMapDataLoadError(e) => e.code(), + Self::MaxHeadOffsetIdBlockfileGetError => ErrorCodes::Internal, Self::HnswIndexResizeError => ErrorCodes::Internal, Self::HnswIndexAddError => ErrorCodes::Internal, - Self::PostingListAddError => ErrorCodes::Internal, + Self::PostingListSetError => ErrorCodes::Internal, Self::HnswIndexSearchError => ErrorCodes::Internal, - Self::PostingListSearchError => ErrorCodes::Internal, - Self::ExpectedDataNotFound => ErrorCodes::Internal, + Self::PostingListGetError => ErrorCodes::Internal, + Self::VersionNotFound => ErrorCodes::Internal, + Self::PostingListCommitError => ErrorCodes::Internal, + Self::VersionsMapSetError => ErrorCodes::Internal, + Self::VersionsMapCommitError => ErrorCodes::Internal, + Self::MaxHeadIdSetError => ErrorCodes::Internal, + Self::MaxHeadIdCommitError => ErrorCodes::Internal, + Self::HnswIndexCommitError => ErrorCodes::Internal, + Self::PostingListFlushError => ErrorCodes::Internal, + Self::VersionsMapFlushError => ErrorCodes::Internal, + Self::MaxHeadIdFlushError => ErrorCodes::Internal, + Self::HnswIndexFlushError => ErrorCodes::Internal, + Self::VersionsMapWriterCreateError => ErrorCodes::Internal, + Self::MaxHeadIdWriterCreateError => ErrorCodes::Internal, } } } @@ -111,9 +154,9 @@ impl SpannIndexWriter { hnsw_index, hnsw_provider, blockfile_provider, - posting_list_writer: Arc::new(Mutex::new(posting_list_writer)), + posting_list_writer: Arc::new(tokio::sync::Mutex::new(posting_list_writer)), next_head_id: Arc::new(AtomicU32::new(next_head_id)), - versions_map: Arc::new(RwLock::new(versions_map)), + versions_map: Arc::new(parking_lot::RwLock::new(versions_map)), distance_function, dimensionality, } @@ -125,13 +168,13 @@ impl SpannIndexWriter { collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, - ) -> Result { + ) -> Result { match hnsw_provider .fork(id, collection_id, dimensionality as i32, distance_function) .await { Ok(index) => Ok(index), - Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + Err(e) => Err(SpannIndexWriterError::HnswIndexForkError(*e)), } } @@ -143,7 +186,7 @@ impl SpannIndexWriter { m: usize, ef_construction: usize, ef_search: usize, - ) -> Result { + ) -> Result { match hnsw_provider .create( collection_id, @@ -156,27 +199,25 @@ impl SpannIndexWriter { .await { Ok(index) => Ok(index), - Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + Err(e) => Err(SpannIndexWriterError::HnswIndexCreateError(*e)), } } async fn load_versions_map( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { // Create a reader for the blockfile. Load all the data into the versions map. let mut versions_map = HashMap::new(); let reader = match blockfile_provider.read::(blockfile_id).await { Ok(reader) => reader, - Err(_) => { - return Err(SpannIndexWriterConstructionError::BlockfileReaderConstructionError) - } + Err(e) => return Err(SpannIndexWriterError::VersionsMapOpenError(*e)), }; // Load data using the reader. let versions_data = reader .get_range(.., ..) .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileVersionDataLoadError)?; + .map_err(SpannIndexWriterError::VersionsMapDataLoadError)?; versions_data.iter().for_each(|(key, value)| { versions_map.insert(*key, *value); }); @@ -186,7 +227,7 @@ impl SpannIndexWriter { async fn fork_postings_list( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); bf_options = bf_options.fork(*blockfile_id); @@ -195,13 +236,13 @@ impl SpannIndexWriter { .await { Ok(writer) => Ok(writer), - Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + Err(e) => Err(SpannIndexWriterError::PostingsListCreateError(*e)), } } async fn create_posting_list( blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); match blockfile_provider @@ -209,7 +250,7 @@ impl SpannIndexWriter { .await { Ok(writer) => Ok(writer), - Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + Err(e) => Err(SpannIndexWriterError::PostingsListCreateError(*e)), } } @@ -227,7 +268,7 @@ impl SpannIndexWriter { distance_function: DistanceFunction, dimensionality: usize, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { // Create the HNSW index. let hnsw_index = match hnsw_id { Some(hnsw_id) => { @@ -279,9 +320,7 @@ impl SpannIndexWriter { Ok(reader) => reader .get("", MAX_HEAD_OFFSET_ID) .await - .map_err(|_| { - SpannIndexWriterConstructionError::BlockfileReaderConstructionError - })? + .map_err(|_| SpannIndexWriterError::MaxHeadOffsetIdBlockfileGetError)? .unwrap(), Err(_) => 1, } @@ -311,12 +350,7 @@ impl SpannIndexWriter { async fn rng_query( &self, query: &[f32], - ) -> Result<(Vec, Vec), SpannIndexWriterConstructionError> { - let mut normalized_query = query.to_vec(); - // Normalize the query in case of cosine. - if self.distance_function == DistanceFunction::Cosine { - normalized_query = normalize(query) - } + ) -> Result<(Vec, Vec), SpannIndexWriterError> { let ids; let distances; let mut embeddings: Vec> = vec![]; @@ -324,20 +358,23 @@ impl SpannIndexWriter { let read_guard = self.hnsw_index.inner.read(); let allowed_ids = vec![]; let disallowed_ids = vec![]; + // Query is already normalized so no need to normalize again. (ids, distances) = read_guard .query( - &normalized_query, + query, NUM_CENTROIDS_TO_SEARCH as usize, &allowed_ids, &disallowed_ids, ) - .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?; + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)?; // Get the embeddings also for distance computation. + // Normalization is idempotent and since we write normalized embeddings + // to the hnsw index, we'll get the same embeddings after denormalization. for id in ids.iter() { let emb = read_guard .get(*id) - .map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)? - .ok_or(SpannIndexWriterConstructionError::HnswIndexSearchError)?; + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)? + .ok_or(SpannIndexWriterError::HnswIndexSearchError)?; embeddings.push(emb); } } @@ -345,10 +382,10 @@ impl SpannIndexWriter { let mut res_ids = vec![]; let mut res_distances = vec![]; let mut res_embeddings: Vec<&Vec> = vec![]; - // Embeddings that were obtained are already normalized. for (id, (distance, embedding)) in ids.iter().zip(distances.iter().zip(embeddings.iter())) { let mut rng_accepted = true; for nbr_embedding in res_embeddings.iter() { + // Embeddings are already normalized so no need to normalize again. let dist = self .distance_function .distance(&embedding[..], &nbr_embedding[..]); @@ -375,7 +412,7 @@ impl SpannIndexWriter { id: u32, version: u32, embedding: &[f32], - ) -> Result<(), SpannIndexWriterConstructionError> { + ) -> Result<(), SpannIndexWriterError> { { let write_guard = self.posting_list_writer.lock().await; // TODO(Sanket): Check if head is deleted, can happen if another concurrent thread @@ -383,8 +420,8 @@ impl SpannIndexWriter { let current_pl = write_guard .get_owned::>("", head_id) .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListSearchError)? - .ok_or(SpannIndexWriterConstructionError::PostingListSearchError)?; + .map_err(|_| SpannIndexWriterError::PostingListGetError)? + .ok_or(SpannIndexWriterError::PostingListGetError)?; // Cleanup this posting list and append the new point to it. // TODO(Sanket): There is an order in which we are acquiring locks here. Need // to ensure the same order in the other places as well. @@ -397,7 +434,7 @@ impl SpannIndexWriter { let current_version = version_map_guard .versions_map .get(¤t_pl.0[index]) - .ok_or(SpannIndexWriterConstructionError::ExpectedDataNotFound)?; + .ok_or(SpannIndexWriterError::VersionNotFound)?; // disregard if either deleted or on an older version. if *current_version == 0 || doc_version < current_version { continue; @@ -426,20 +463,25 @@ impl SpannIndexWriter { write_guard .set("", head_id, &posting_list) .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?; + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; } Ok(()) } #[allow(dead_code)] - async fn add_postings_list( + async fn add_to_postings_list( &self, id: u32, version: u32, embeddings: &[f32], - ) -> Result<(), SpannIndexWriterConstructionError> { + ) -> Result<(), SpannIndexWriterError> { let (ids, _) = self.rng_query(embeddings).await?; - // Create a centroid with just this point. + // The only cases when this can happen is initially when no data exists in the + // index or if all the data that was added to the index was deleted later. + // In both the cases, in the worst case, it can happen that ids is empty + // for the first few points getting inserted concurrently by different threads. + // It's fine to create new centers for each of them since the number of such points + // will be very small and we can also run GC to merge them later if needed. if ids.is_empty() { let next_id = self .next_head_id @@ -457,7 +499,7 @@ impl SpannIndexWriter { write_guard .set("", next_id, &posting_list) .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?; + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; } // Next add to hnsw. // This shouldn't exceed the capacity since this will happen only for the first few points @@ -466,7 +508,7 @@ impl SpannIndexWriter { let write_guard = self.hnsw_index.inner.write(); write_guard .add(next_id as usize, embeddings) - .map_err(|_| SpannIndexWriterConstructionError::HnswIndexAddError)?; + .map_err(|_| SpannIndexWriterError::HnswIndexAddError)?; } return Ok(()); } @@ -479,25 +521,26 @@ impl SpannIndexWriter { Ok(()) } - pub async fn add( - &self, - id: u32, - embedding: &[f32], - ) -> Result<(), SpannIndexWriterConstructionError> { + pub async fn add(&self, id: u32, embedding: &[f32]) -> Result<(), SpannIndexWriterError> { let version = self.add_versions_map(id); + // Normalize the embedding in case of cosine. + let mut normalized_embedding = embedding.to_vec(); + if self.distance_function == DistanceFunction::Cosine { + normalized_embedding = normalize(embedding); + } // Add to the posting list. - self.add_postings_list(id, version, embedding).await + self.add_to_postings_list(id, version, &normalized_embedding) + .await } - // TODO(Sanket): Change the error types. - pub async fn commit(self) -> Result { + pub async fn commit(self) -> Result { // Pl list. let pl_flusher = match Arc::try_unwrap(self.posting_list_writer) { Ok(writer) => writer .into_inner() .commit::>() .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?, + .map_err(|_| SpannIndexWriterError::PostingListCommitError)?, Err(_) => { // This should never happen. panic!("Failed to unwrap posting list writer"); @@ -510,7 +553,7 @@ impl SpannIndexWriter { .blockfile_provider .write::(bf_options) .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileWriterConstructionError)?; + .map_err(|_| SpannIndexWriterError::VersionsMapWriterCreateError)?; let versions_map_flusher = match Arc::try_unwrap(self.versions_map) { Ok(writer) => { let writer = writer.into_inner(); @@ -518,12 +561,12 @@ impl SpannIndexWriter { versions_map_bf_writer .set("", doc_offset_id, doc_version) .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?; + .map_err(|_| SpannIndexWriterError::VersionsMapSetError)?; } versions_map_bf_writer .commit::() .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)? + .map_err(|_| SpannIndexWriterError::VersionsMapCommitError)? } Err(_) => { // This should never happen. @@ -537,18 +580,18 @@ impl SpannIndexWriter { .blockfile_provider .write::<&str, u32>(bf_options) .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileWriterConstructionError)?; + .map_err(|_| SpannIndexWriterError::MaxHeadIdWriterCreateError)?; let max_head_id_flusher = match Arc::try_unwrap(self.next_head_id) { Ok(value) => { let value = value.into_inner(); max_head_id_bf .set("", MAX_HEAD_OFFSET_ID, value) .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)?; + .map_err(|_| SpannIndexWriterError::MaxHeadIdSetError)?; max_head_id_bf .commit::<&str, u32>() .await - .map_err(|_| SpannIndexWriterConstructionError::PostingListAddError)? + .map_err(|_| SpannIndexWriterError::MaxHeadIdCommitError)? } Err(_) => { // This should never happen. @@ -561,7 +604,7 @@ impl SpannIndexWriter { // Hnsw. self.hnsw_provider .commit(self.hnsw_index) - .map_err(|_| SpannIndexWriterConstructionError::HnswIndexConstructionError)?; + .map_err(|_| SpannIndexWriterError::HnswIndexCommitError)?; Ok(SpannIndexFlusher { pl_flusher, @@ -588,9 +631,8 @@ pub struct SpannIndexIds { pub hnsw_id: IndexUuid, } -// TODO(Sanket): Change the error types. impl SpannIndexFlusher { - pub async fn flush(self) -> Result { + pub async fn flush(self) -> Result { let res = SpannIndexIds { pl_id: self.pl_flusher.id(), versions_map_id: self.versions_map_flusher.id(), @@ -600,19 +642,19 @@ impl SpannIndexFlusher { self.pl_flusher .flush::>() .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileWriterConstructionError)?; + .map_err(|_| SpannIndexWriterError::PostingListFlushError)?; self.versions_map_flusher .flush::() .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileWriterConstructionError)?; + .map_err(|_| SpannIndexWriterError::VersionsMapFlushError)?; self.max_head_id_flusher .flush::<&str, u32>() .await - .map_err(|_| SpannIndexWriterConstructionError::BlockfileWriterConstructionError)?; + .map_err(|_| SpannIndexWriterError::MaxHeadIdFlushError)?; self.hnsw_flusher .flush(&self.hnsw_id) .await - .map_err(|_| SpannIndexWriterConstructionError::HnswIndexConstructionError)?; + .map_err(|_| SpannIndexWriterError::HnswIndexFlushError)?; Ok(res) } } diff --git a/rust/index/src/utils.rs b/rust/index/src/utils.rs index 0d9b3fe9542..04ffc7099bc 100644 --- a/rust/index/src/utils.rs +++ b/rust/index/src/utils.rs @@ -16,8 +16,6 @@ pub(super) fn generate_random_data(n: usize, d: usize) -> Vec { data } -const EPS: f32 = 1e-30; - pub fn merge_sorted_vecs_disjunction(a: &[T], b: &[T]) -> Vec { let mut result = Vec::with_capacity(a.len() + b.len()); let mut a_idx = 0; @@ -79,15 +77,6 @@ pub fn merge_sorted_vecs_conjunction(a: &[T], b: &[T]) -> Vec result } -pub fn normalize(vector: &[f32]) -> Vec { - let mut norm = 0.0; - for x in vector { - norm += x * x; - } - let norm = 1.0 / (norm.sqrt() + EPS); - vector.iter().map(|x| x * norm).collect() -} - #[cfg(test)] mod tests { #[test] diff --git a/rust/worker/src/execution/operators/brute_force_knn.rs b/rust/worker/src/execution/operators/brute_force_knn.rs index c82992c295f..b83aff738ed 100644 --- a/rust/worker/src/execution/operators/brute_force_knn.rs +++ b/rust/worker/src/execution/operators/brute_force_knn.rs @@ -3,9 +3,9 @@ use crate::segment::record_segment::RecordSegmentReader; use crate::segment::{materialize_logs, LogMaterializerError}; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::normalize; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::utils::normalize; use chroma_types::Chunk; use chroma_types::{LogRecord, MaterializedLogOperation, Segment}; use std::cmp::Ordering; diff --git a/rust/worker/src/execution/operators/normalize_vectors.rs b/rust/worker/src/execution/operators/normalize_vectors.rs index f40b58a18eb..88a9e27c1c8 100644 --- a/rust/worker/src/execution/operators/normalize_vectors.rs +++ b/rust/worker/src/execution/operators/normalize_vectors.rs @@ -1,6 +1,6 @@ use crate::execution::operator::Operator; use async_trait::async_trait; -use chroma_index::utils::normalize; +use chroma_distance::normalize; #[derive(Debug)] pub struct NormalizeVectorOperator {} diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 1cf412c1ea0..f7dba01e825 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -36,10 +36,10 @@ use crate::{ }; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::normalize; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::HnswIndexProvider; -use chroma_index::utils::normalize; use chroma_index::IndexConfig; use chroma_types::{Chunk, Collection, CollectionUuid, LogRecord, Segment, VectorQueryResult}; use std::collections::HashMap; diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index 379cc918749..b27f6361162 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -1,3 +1,4 @@ +use super::spann_segment::SpannSegmentWriterError; use super::types::{MaterializedLogRecord, SegmentWriter}; use super::SegmentFlusher; use async_trait::async_trait; @@ -322,6 +323,8 @@ pub enum ApplyMaterializedLogError { FullTextIndex(#[from] FullTextIndexError), #[error("Error writing to hnsw index")] HnswIndex(#[from] Box), + #[error("Error applying materialized records to spann segment: {0}")] + SpannSegmentError(#[from] SpannSegmentWriterError), } impl ChromaError for ApplyMaterializedLogError { @@ -333,6 +336,7 @@ impl ChromaError for ApplyMaterializedLogError { ApplyMaterializedLogError::Allocation => ErrorCodes::Internal, ApplyMaterializedLogError::FullTextIndex(e) => e.code(), ApplyMaterializedLogError::HnswIndex(_) => ErrorCodes::Internal, + ApplyMaterializedLogError::SpannSegmentError(e) => e.code(), } } } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index dda25deb406..1a037a0c007 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; use chroma_blockstore::provider::BlockfileProvider; +use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::spann::types::SpannIndexFlusher; +use chroma_index::spann::types::{SpannIndexFlusher, SpannIndexWriterError}; use chroma_index::IndexUuid; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; use chroma_types::SegmentUuid; @@ -28,12 +29,13 @@ pub(crate) struct SpannSegmentWriter { id: SegmentUuid, } +// TODO(Sanket): Better error composability here. #[derive(Error, Debug)] pub enum SpannSegmentWriterError { #[error("Invalid argument")] InvalidArgument, - #[error("Segment metadata does not contain distance function")] - DistanceFunctionNotFound, + #[error("Segment metadata does not contain distance function {0}")] + DistanceFunctionNotFound(#[from] DistanceFunctionError), #[error("Error parsing index uuid from string")] IndexIdParsingError, #[error("Invalid file path for HNSW index")] @@ -46,8 +48,8 @@ pub enum SpannSegmentWriterError { MaxHeadIdInvalidFilePath, #[error("Error constructing spann index writer")] SpannSegmentWriterCreateError, - #[error("Error adding record to spann index writer")] - SpannSegmentWriterAddRecordError, + #[error("Error adding record to spann index writer {0}")] + SpannSegmentWriterAddRecordError(#[from] SpannIndexWriterError), #[error("Error committing spann index writer")] SpannSegmentWriterCommitError, #[error("Error flushing spann index writer")] @@ -61,7 +63,7 @@ impl ChromaError for SpannSegmentWriterError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::DistanceFunctionNotFound => ErrorCodes::Internal, + Self::DistanceFunctionNotFound(e) => e.code(), Self::HnswInvalidFilePath => ErrorCodes::Internal, Self::VersionMapInvalidFilePath => ErrorCodes::Internal, Self::PostingListInvalidFilePath => ErrorCodes::Internal, @@ -70,7 +72,7 @@ impl ChromaError for SpannSegmentWriterError { Self::NotImplemented => ErrorCodes::Internal, Self::SpannSegmentWriterCommitError => ErrorCodes::Internal, Self::SpannSegmentWriterFlushError => ErrorCodes::Internal, - Self::SpannSegmentWriterAddRecordError => ErrorCodes::Internal, + Self::SpannSegmentWriterAddRecordError(e) => e.code(), } } } @@ -88,8 +90,8 @@ impl SpannSegmentWriter { } let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, - Err(_) => { - return Err(SpannSegmentWriterError::DistanceFunctionNotFound); + Err(e) => { + return Err(SpannSegmentWriterError::DistanceFunctionNotFound(*e)); } }; let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) { @@ -202,7 +204,7 @@ impl SpannSegmentWriter { self.index .add(record.offset_id, record.merged_embeddings()) .await - .map_err(|_| SpannSegmentWriterError::SpannSegmentWriterAddRecordError) + .map_err(SpannSegmentWriterError::SpannSegmentWriterAddRecordError) } } @@ -210,17 +212,17 @@ struct SpannSegmentFlusher { index_flusher: SpannIndexFlusher, } -impl<'a> SegmentWriter<'a> for SpannSegmentWriter { +impl<'referred_data> SegmentWriter<'referred_data> for SpannSegmentWriter { async fn apply_materialized_log_chunk( &self, - records: chroma_types::Chunk>, + records: chroma_types::Chunk>, ) -> Result<(), ApplyMaterializedLogError> { for (record, _) in records.iter() { match record.final_operation { MaterializedLogOperation::AddNew => { self.add(record) .await - .map_err(|_| ApplyMaterializedLogError::BlockfileSet)?; + .map_err(ApplyMaterializedLogError::SpannSegmentError)?; } // TODO(Sanket): Implement other operations. _ => {