From 65b5e01b31920351969aa8e1da89de81c4c31000 Mon Sep 17 00:00:00 2001
From: Sanket Kedia
Date: Thu, 5 Dec 2024 12:05:10 -0800
Subject: [PATCH] [ENH] Introduce spann segment reader (#3212)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - `Readable` and `Value` impls (plus `Clone`/`Debug` derives) for `SpannPostingList`
   - Made the posting-list split test assertions order-independent and corrected the merge_threshold comment
 - New functionality
   - Spann segment reader with `from_segment` impl

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
---
 rust/blockstore/src/memory/storage.rs    |  33 +++
 rust/blockstore/src/types/value.rs       |   6 +
 rust/index/src/spann/types.rs            | 148 ++++++++++-
 rust/types/src/spann_posting_list.rs     |   1 +
 rust/worker/src/segment/spann_segment.rs | 301 ++++++++++++++++++++++-
 5 files changed, 477 insertions(+), 12 deletions(-)

diff --git a/rust/blockstore/src/memory/storage.rs b/rust/blockstore/src/memory/storage.rs
index 829740e2dea..0b0a297b3c5 100644
--- a/rust/blockstore/src/memory/storage.rs
+++ b/rust/blockstore/src/memory/storage.rs
@@ -680,6 +680,39 @@ impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> {
     }
 }
 
+impl<'referred_data> Readable<'referred_data> for SpannPostingList<'referred_data> {
+    fn read_from_storage(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> Option<Self> {
+        todo!()
+    }
+
+    fn read_range_from_storage<'prefix, PrefixRange, KeyRange>(
+        _: PrefixRange,
+        _: KeyRange,
+        _: &'referred_data Storage,
+    ) -> Vec<(&'referred_data CompositeKey, Self)>
+    where
+        PrefixRange: std::ops::RangeBounds<&'prefix str>,
+        KeyRange: std::ops::RangeBounds<KeyWrapper>,
+    {
+        todo!()
+    }
+
+    fn get_at_index(
+        _: &'referred_data Storage,
+        _: usize,
+    ) -> Option<(&'referred_data CompositeKey, Self)> {
+        todo!()
+    }
+
+    fn count(_: &Storage) -> Result<usize, Box<dyn ChromaError>> {
+        todo!()
+    }
+
+    fn contains(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> bool {
+        todo!()
+    }
+}
+
 #[derive(Clone)]
 pub struct StorageBuilder {
     bool_storage: Arc<RwLock<Option<HashMap<CompositeKey, bool>>>>,
diff --git a/rust/blockstore/src/types/value.rs b/rust/blockstore/src/types/value.rs
index a3f647e38bf..9e88e2462a9 100644
--- a/rust/blockstore/src/types/value.rs
+++ b/rust/blockstore/src/types/value.rs
@@ -61,6 +61,12 @@ impl Value for &DataRecord<'_> {
     }
 }
 
+impl Value for SpannPostingList<'_> {
+    fn get_size(&self) -> usize {
+        self.compute_size()
+    }
+}
+
 impl Value for &SpannPostingList<'_> {
     fn get_size(&self) -> usize {
         self.compute_size()
diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs
index 9592744b68a..63dd6f08afb 100644
--- a/rust/index/src/spann/types.rs
+++ b/rust/index/src/spann/types.rs
@@ -5,7 +5,7 @@ use std::{
 
 use chroma_blockstore::{
     provider::{BlockfileProvider, CreateError, OpenError},
-    BlockfileFlusher, BlockfileWriter, BlockfileWriterOptions,
+    BlockfileFlusher, BlockfileReader, BlockfileWriter, BlockfileWriterOptions,
 };
 use chroma_distance::{normalize, DistanceFunction};
 use chroma_error::{ChromaError, ErrorCodes};
@@ -1427,6 +1427,124 @@ impl SpannIndexFlusher {
     }
 }
 
+#[derive(Error, Debug)]
+pub enum SpannIndexReaderError {
+    #[error("Error creating/opening hnsw index")]
+    HnswIndexConstructionError,
+    #[error("Error creating/opening blockfile reader")]
+    BlockfileReaderConstructionError,
+    #[error("Spann index uninitialized")]
+    UninitializedIndex,
+}
+
+impl ChromaError for SpannIndexReaderError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            Self::HnswIndexConstructionError => ErrorCodes::Internal,
+            Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
+            Self::UninitializedIndex => ErrorCodes::Internal,
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct SpannIndexReader<'me> {
+    pub posting_lists: BlockfileReader<'me, u32, SpannPostingList<'me>>,
+    pub hnsw_index: HnswIndexRef,
+    pub versions_map: BlockfileReader<'me, u32, u32>,
+}
+
+impl<'me> SpannIndexReader<'me> {
+    async fn hnsw_index_from_id(
+        hnsw_provider: &HnswIndexProvider,
+        id: &IndexUuid,
+        cache_key: &CollectionUuid,
+        distance_function: DistanceFunction,
+        dimensionality: usize,
+    ) -> Result<HnswIndexRef, SpannIndexReaderError> {
+        match hnsw_provider.get(id, cache_key).await {
+            Some(index) => Ok(index),
+            None => {
+                match hnsw_provider
+                    .open(id, cache_key, dimensionality as i32, distance_function)
+                    .await
+                {
+                    Ok(index) => Ok(index),
+                    Err(_) => Err(SpannIndexReaderError::HnswIndexConstructionError),
+                }
+            }
+        }
+    }
+
+    async fn posting_list_reader_from_id(
+        blockfile_id: &Uuid,
+        blockfile_provider: &BlockfileProvider,
+    ) -> Result<BlockfileReader<'me, u32, SpannPostingList<'me>>, SpannIndexReaderError> {
+        match blockfile_provider
+            .read::<u32, SpannPostingList<'me>>(blockfile_id)
+            .await
+        {
+            Ok(reader) => Ok(reader),
+            Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
+        }
+    }
+
+    async fn versions_map_reader_from_id(
+        blockfile_id: &Uuid,
+        blockfile_provider: &BlockfileProvider,
+    ) -> Result<BlockfileReader<'me, u32, u32>, SpannIndexReaderError> {
+        match blockfile_provider.read::<u32, u32>(blockfile_id).await {
+            Ok(reader) => Ok(reader),
+            Err(_) => Err(SpannIndexReaderError::BlockfileReaderConstructionError),
+        }
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub async fn from_id(
+        hnsw_id: Option<&IndexUuid>,
+        hnsw_provider: &HnswIndexProvider,
+        hnsw_cache_key: &CollectionUuid,
+        distance_function: DistanceFunction,
+        dimensionality: usize,
+        pl_blockfile_id: Option<&Uuid>,
+        versions_map_blockfile_id: Option<&Uuid>,
+        blockfile_provider: &BlockfileProvider,
+    ) -> Result<SpannIndexReader<'me>, SpannIndexReaderError> {
+        let hnsw_reader = match hnsw_id {
+            Some(hnsw_id) => {
+                Self::hnsw_index_from_id(
+                    hnsw_provider,
+                    hnsw_id,
+                    hnsw_cache_key,
+                    distance_function,
+                    dimensionality,
+                )
+                .await?
+            }
+            None => {
+                return Err(SpannIndexReaderError::UninitializedIndex);
+            }
+        };
+        let postings_list_reader = match pl_blockfile_id {
+            Some(pl_id) => Self::posting_list_reader_from_id(pl_id, blockfile_provider).await?,
+            None => return Err(SpannIndexReaderError::UninitializedIndex),
+        };
+
+        let versions_map_reader = match versions_map_blockfile_id {
+            Some(versions_id) => {
+                Self::versions_map_reader_from_id(versions_id, blockfile_provider).await?
+            }
+            None => return Err(SpannIndexReaderError::UninitializedIndex),
+        };
+
+        Ok(Self {
+            posting_lists: postings_list_reader,
+            hnsw_index: hnsw_reader,
+            versions_map: versions_map_reader,
+        })
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{f32::consts::PI, path::PathBuf};
@@ -1556,22 +1674,32 @@
         {
             // Posting list should have 100 points.
             let pl_read_guard = writer.posting_list_writer.lock().await;
-            let pl = pl_read_guard
+            let pl1 = pl_read_guard
                 .get_owned::<u32, &SpannPostingList<'_>>("", emb_1_id)
                 .await
                 .expect("Error getting posting list")
                 .unwrap();
-            assert_eq!(pl.0.len(), 100);
-            assert_eq!(pl.1.len(), 100);
-            assert_eq!(pl.2.len(), 200);
-            let pl = pl_read_guard
+            let pl2 = pl_read_guard
                 .get_owned::<u32, &SpannPostingList<'_>>("", emb_2_id)
                 .await
                 .expect("Error getting posting list")
                 .unwrap();
-            assert_eq!(pl.0.len(), 1);
-            assert_eq!(pl.1.len(), 1);
-            assert_eq!(pl.2.len(), 2);
+            // Only two combinations possible.
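+            // Which head ends up owning the 100-point cluster is not
+            // deterministic, so accept either assignment.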
+ if pl1.0.len() == 100 { + assert_eq!(pl1.1.len(), 100); + assert_eq!(pl1.2.len(), 200); + assert_eq!(pl2.0.len(), 1); + assert_eq!(pl2.1.len(), 1); + assert_eq!(pl2.2.len(), 2); + } else if pl2.0.len() == 100 { + assert_eq!(pl2.1.len(), 100); + assert_eq!(pl2.2.len(), 200); + assert_eq!(pl1.0.len(), 1); + assert_eq!(pl1.1.len(), 1); + assert_eq!(pl1.2.len(), 2); + } else { + panic!("Invalid posting list lengths"); + } } // Next insert 99 points in the region of (1000.0, 1000.0) for i in 102..=200 { @@ -1911,7 +2039,7 @@ mod tests { version_map_guard.versions_map.insert(100 + point as u32, 1); } } - // Delete 60 points each from the centers. Since merge_threshold is 40, this should + // Delete 60 points each from the centers. Since merge_threshold is 50, this should // trigger a merge between the two centers. for point in 1..=60 { writer diff --git a/rust/types/src/spann_posting_list.rs b/rust/types/src/spann_posting_list.rs index c57ec8ad75c..413234c63f8 100644 --- a/rust/types/src/spann_posting_list.rs +++ b/rust/types/src/spann_posting_list.rs @@ -1,3 +1,4 @@ +#[derive(Clone, Debug)] pub struct SpannPostingList<'referred_data> { pub doc_offset_ids: &'referred_data [u32], pub doc_versions: &'referred_data [u32], diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 801c7e565c7..7b073949486 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -3,7 +3,9 @@ use std::collections::HashMap; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::spann::types::{SpannIndexFlusher, SpannIndexWriterError}; +use chroma_index::spann::types::{ + SpannIndexFlusher, SpannIndexReader, SpannIndexReaderError, SpannIndexWriterError, +}; use chroma_index::IndexUuid; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; use chroma_types::SegmentUuid; @@ -307,6 +309,147 @@ impl SegmentFlusher for SpannSegmentFlusher { } } +#[derive(Error, Debug)] +pub enum SpannSegmentReaderError { + #[error("Invalid argument")] + InvalidArgument, + #[error("Segment metadata does not contain distance function")] + DistanceFunctionNotFound, + #[error("Error parsing index uuid from string")] + IndexIdParsingError, + #[error("Invalid file path for HNSW index")] + HnswInvalidFilePath, + #[error("Invalid file path for version map")] + VersionMapInvalidFilePath, + #[error("Invalid file path for posting list")] + PostingListInvalidFilePath, + #[error("Error constructing spann index reader")] + SpannSegmentReaderCreateError, + #[error("Spann segment is uninitialized")] + UninitializedSegment, +} + +impl ChromaError for SpannSegmentReaderError { + fn code(&self) -> ErrorCodes { + match self { + Self::InvalidArgument => ErrorCodes::InvalidArgument, + Self::IndexIdParsingError => ErrorCodes::Internal, + Self::DistanceFunctionNotFound => ErrorCodes::Internal, + Self::HnswInvalidFilePath => ErrorCodes::Internal, + Self::VersionMapInvalidFilePath => ErrorCodes::Internal, + Self::PostingListInvalidFilePath => ErrorCodes::Internal, + Self::SpannSegmentReaderCreateError => ErrorCodes::Internal, + Self::UninitializedSegment => ErrorCodes::Internal, + } + } +} + +#[derive(Clone)] +#[allow(dead_code)] +pub(crate) struct SpannSegmentReader<'me> { + index_reader: SpannIndexReader<'me>, + id: SegmentUuid, +} + +impl<'me> SpannSegmentReader<'me> { + #[allow(dead_code)] + pub async fn from_segment( + segment: 
&Segment,
+        blockfile_provider: &BlockfileProvider,
+        hnsw_provider: &HnswIndexProvider,
+        dimensionality: usize,
+    ) -> Result<SpannSegmentReader<'me>, SpannSegmentReaderError> {
+        if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR {
+            return Err(SpannSegmentReaderError::InvalidArgument);
+        }
+        let distance_function = match distance_function_from_segment(segment) {
+            Ok(distance_function) => distance_function,
+            Err(_) => {
+                return Err(SpannSegmentReaderError::DistanceFunctionNotFound);
+            }
+        };
+        let hnsw_id = match segment.file_path.get(HNSW_PATH) {
+            Some(hnsw_path) => match hnsw_path.first() {
+                Some(index_id) => {
+                    let index_uuid = match Uuid::parse_str(index_id) {
+                        Ok(uuid) => uuid,
+                        Err(_) => {
+                            return Err(SpannSegmentReaderError::IndexIdParsingError);
+                        }
+                    };
+                    Some(IndexUuid(index_uuid))
+                }
+                None => {
+                    return Err(SpannSegmentReaderError::HnswInvalidFilePath);
+                }
+            },
+            None => None,
+        };
+        let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) {
+            Some(version_map_path) => match version_map_path.first() {
+                Some(version_map_id) => {
+                    let version_map_uuid = match Uuid::parse_str(version_map_id) {
+                        Ok(uuid) => uuid,
+                        Err(_) => {
+                            return Err(SpannSegmentReaderError::IndexIdParsingError);
+                        }
+                    };
+                    Some(version_map_uuid)
+                }
+                None => {
+                    return Err(SpannSegmentReaderError::VersionMapInvalidFilePath);
+                }
+            },
+            None => None,
+        };
+        let posting_list_id = match segment.file_path.get(POSTING_LIST_PATH) {
+            Some(posting_list_path) => match posting_list_path.first() {
+                Some(posting_list_id) => {
+                    let posting_list_uuid = match Uuid::parse_str(posting_list_id) {
+                        Ok(uuid) => uuid,
+                        Err(_) => {
+                            return Err(SpannSegmentReaderError::IndexIdParsingError);
+                        }
+                    };
+                    Some(posting_list_uuid)
+                }
+                None => {
+                    return Err(SpannSegmentReaderError::PostingListInvalidFilePath);
+                }
+            },
+            None => None,
+        };
+
+        let index_reader = match SpannIndexReader::from_id(
+            hnsw_id.as_ref(),
+            hnsw_provider,
+            &segment.collection,
+            distance_function,
+            dimensionality,
+            posting_list_id.as_ref(),
+            versions_map_id.as_ref(),
+            blockfile_provider,
+        )
+        .await
+        {
+            Ok(index_reader) => index_reader,
+            Err(e) => match e {
+                SpannIndexReaderError::UninitializedIndex => {
+                    return Err(SpannSegmentReaderError::UninitializedSegment);
+                }
+                _ => {
+                    return Err(SpannSegmentReaderError::SpannSegmentReaderCreateError);
+                }
+            },
+        };
+
+        Ok(SpannSegmentReader {
+            index_reader,
+            id: segment.id,
+        })
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::{collections::HashMap, path::PathBuf};
@@ -325,7 +468,9 @@ mod test {
     };
 
     use crate::segment::{
-        materialize_logs, spann_segment::SpannSegmentWriter, SegmentFlusher, SegmentWriter,
+        materialize_logs,
+        spann_segment::{SpannSegmentReader, SpannSegmentWriter},
+        SegmentFlusher, SegmentWriter,
     };
 
     #[tokio::test]
@@ -512,4 +657,156 @@ mod test {
         assert_eq!(res.2[4], 5.0);
         assert_eq!(res.2[5], 6.0);
     }
+
+    #[tokio::test]
+    async fn test_spann_segment_reader() {
+        // Tests that after the writer writes and flushes data, the reader can
+        // read it back.
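+        // Setup mirrors the writer test: local storage in a temp dir backs
+        // both the arrow blockfile provider and the hnsw provider.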
+ let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let collection_id = CollectionUuid::new(); + let segment_id = SegmentUuid::new(); + let mut metadata_hash_map = Metadata::new(); + metadata_hash_map.insert( + "hnsw:space".to_string(), + MetadataValue::Str("l2".to_string()), + ); + metadata_hash_map.insert("hnsw:M".to_string(), MetadataValue::Int(16)); + metadata_hash_map.insert("hnsw:construction_ef".to_string(), MetadataValue::Int(100)); + metadata_hash_map.insert("hnsw:search_ef".to_string(), MetadataValue::Int(100)); + let mut spann_segment = chroma_types::Segment { + id: segment_id, + collection: collection_id, + r#type: chroma_types::SegmentType::Spann, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(metadata_hash_map), + file_path: HashMap::new(), + }; + let spann_writer = SpannSegmentWriter::from_segment( + &spann_segment, + &blockfile_provider, + &hnsw_provider, + 3, + ) + .await + .expect("Error creating spann segment writer"); + let data = vec![ + LogRecord { + log_offset: 1, + record: OperationRecord { + id: "embedding_id_1".to_string(), + embedding: Some(vec![1.0, 2.0, 3.0]), + encoding: None, + metadata: None, + document: Some(String::from("This is a document about cats.")), + operation: Operation::Add, + }, + }, + LogRecord { + log_offset: 2, + record: OperationRecord { + id: "embedding_id_2".to_string(), + embedding: Some(vec![4.0, 5.0, 6.0]), + encoding: None, + metadata: None, + document: Some(String::from("This is a document about dogs.")), + operation: Operation::Add, + }, + }, + ]; + let chunked_log = Chunk::new(data.into()); + // Materialize the logs. + let materialized_log = materialize_logs(&None, &chunked_log, None) + .await + .expect("Error materializing logs"); + spann_writer + .apply_materialized_log_chunk(materialized_log) + .await + .expect("Error applying materialized log"); + let flusher = spann_writer + .commit() + .await + .expect("Error committing spann writer"); + spann_segment.file_path = flusher.flush().await.expect("Error flushing spann writer"); + assert_eq!(spann_segment.file_path.len(), 4); + assert!(spann_segment.file_path.contains_key("hnsw_path")); + assert!(spann_segment.file_path.contains_key("version_map_path"),); + assert!(spann_segment.file_path.contains_key("posting_list_path"),); + assert!(spann_segment.file_path.contains_key("max_head_id_path"),); + // Load this segment and check if the embeddings are present. New cache + // so that the previous cache is not used. 
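+        // Rebuilding the providers forces the reader to hit the flushed files
+        // instead of blocks still cached by the writer's providers.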
+        let block_cache = new_cache_for_test();
+        let sparse_index_cache = new_cache_for_test();
+        let arrow_blockfile_provider = ArrowBlockfileProvider::new(
+            storage.clone(),
+            TEST_MAX_BLOCK_SIZE_BYTES,
+            block_cache,
+            sparse_index_cache,
+        );
+        let blockfile_provider =
+            BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider);
+        let hnsw_cache = new_non_persistent_cache_for_test();
+        let (_, rx) = tokio::sync::mpsc::unbounded_channel();
+        let hnsw_provider = HnswIndexProvider::new(
+            storage,
+            PathBuf::from(tmp_dir.path().to_str().unwrap()),
+            hnsw_cache,
+            rx,
+        );
+        let spann_reader = SpannSegmentReader::from_segment(
+            &spann_segment,
+            &blockfile_provider,
+            &hnsw_provider,
+            3,
+        )
+        .await
+        .expect("Error creating segment reader");
+        let (non_deleted_centers, deleted_centers) = spann_reader
+            .index_reader
+            .hnsw_index
+            .inner
+            .read()
+            .get_all_ids()
+            .expect("Error getting all ids from hnsw index");
+        assert_eq!(non_deleted_centers.len(), 1);
+        assert_eq!(deleted_centers.len(), 0);
+        assert_eq!(non_deleted_centers[0], 1);
+        let mut pl = spann_reader
+            .index_reader
+            .posting_lists
+            .get_range(.., ..)
+            .await
+            .expect("Error getting all data from reader");
+        pl.sort_by(|a, b| a.0.cmp(&b.0));
+        assert_eq!(pl.len(), 1);
+        assert_eq!(pl[0].1.doc_offset_ids, &[1, 2]);
+        assert_eq!(pl[0].1.doc_versions, &[1, 1]);
+        assert_eq!(pl[0].1.doc_embeddings, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
+        let mut versions_map = spann_reader
+            .index_reader
+            .versions_map
+            .get_range(.., ..)
+            .await
+            .expect("Error getting all data from reader");
+        versions_map.sort_by(|a, b| a.0.cmp(&b.0));
+        assert_eq!(versions_map, vec![(1, 1), (2, 1)]);
+    }
 }
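
## Usage sketch (not part of the patch)

A minimal sketch of how the new reader is meant to be consumed once a
`SpannSegmentWriter` has flushed a segment. It mirrors `test_spann_segment_reader`
above: the `Segment` is assumed to carry the four `file_path` entries produced by
the flush, and the providers are built as in that test. Since `SpannSegmentReader`
and its `index_reader` field are private to `spann_segment.rs`, a helper like this
(`scan_spann_segment` is a hypothetical name) would have to live in that module,
where the types below are already in scope.

```rust
use chroma_blockstore::provider::BlockfileProvider;
use chroma_index::hnsw_provider::HnswIndexProvider;
use chroma_types::Segment;

// Hypothetical helper; the dimensionality (3) must match what the writer indexed.
async fn scan_spann_segment(
    segment: &Segment,
    blockfile_provider: &BlockfileProvider,
    hnsw_provider: &HnswIndexProvider,
) -> Result<(), SpannSegmentReaderError> {
    let reader =
        SpannSegmentReader::from_segment(segment, blockfile_provider, hnsw_provider, 3).await?;

    // Posting lists are exposed as a plain blockfile reader keyed by head id,
    // so an unbounded range scan returns every (head id, posting list) pair.
    let posting_lists = reader
        .index_reader
        .posting_lists
        .get_range(.., ..)
        .await
        .expect("Error reading posting lists");
    for (head_id, pl) in posting_lists {
        println!(
            "head {}: {} points, versions {:?}",
            head_id,
            pl.doc_offset_ids.len(),
            pl.doc_versions
        );
    }
    Ok(())
}
```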