From 08bcd82d5a79603ceb21b955cb39a0a55c36f94e Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Tue, 26 Nov 2024 15:17:21 -0800 Subject: [PATCH] Add tests --- rust/index/src/hnsw.rs | 4 +- rust/index/src/spann/types.rs | 381 ++++++++++++++++++++++++++++++++++ 2 files changed, 383 insertions(+), 2 deletions(-) diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index aff3848ee184..68e5a72a1e22 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -233,7 +233,7 @@ impl Index for HnswIndex { fn get_all_ids_sizes(&self) -> Result, Box> { let mut sizes = vec![0usize; 2]; - unsafe { get_all_ids_size(self.ffi_ptr, sizes.as_mut_ptr()) }; + unsafe { get_all_ids_sizes(self.ffi_ptr, sizes.as_mut_ptr()) }; read_and_return_hnsw_error(self.ffi_ptr)?; Ok(sizes) } @@ -381,7 +381,7 @@ extern "C" { fn add_item(index: *const IndexPtrFFI, data: *const f32, id: usize, replace_deleted: bool); fn mark_deleted(index: *const IndexPtrFFI, id: usize); fn get_item(index: *const IndexPtrFFI, id: usize, data: *mut f32); - fn get_all_ids_size(index: *const IndexPtrFFI, sizes: *mut usize); + fn get_all_ids_sizes(index: *const IndexPtrFFI, sizes: *mut usize); fn get_all_ids(index: *const IndexPtrFFI, non_deleted_ids: *mut usize, deleted_ids: *mut usize); fn knn_query( index: *const IndexPtrFFI, diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index bc9e0a9932d6..ff8b80bcbe18 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1118,6 +1118,10 @@ impl SpannIndexWriter { .into_iter() .zip(nearest_head_embeddings.into_iter()) { + // Skip if it is the current head. Can't a merge a head into itself. + if nearest_head_id == head_id { + continue; + } // TODO(Sanket): If and when the lock is more fine grained, then // need to acquire a lock on the nearest_head_id here. // TODO(Sanket): Also need to check if the head is deleted concurrently then. @@ -1255,6 +1259,7 @@ impl SpannIndexWriter { .get(head_id) .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)? .ok_or(SpannIndexWriterError::HnswIndexSearchError)?; + tracing::info!("Garbage collecting head {}", head_id); self.garbage_collect_head(head_id, &head_embedding).await?; } Ok(()) @@ -1584,4 +1589,380 @@ mod tests { assert_eq!(pl.2.len(), 200); } } + + #[tokio::test] + async fn test_gc_deletes() { + // Insert a few entries in a couple of centers. Delete a few + // still keeping within the merge threshold. + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Insert a couple of centers. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + } + { + let pl_guard = writer.posting_list_writer.lock().await; + let mut doc_offset_ids = vec![0u32; 100]; + let mut doc_versions = vec![0; 100]; + let mut doc_embeddings = vec![0.0; 200]; + // Insert 100 points in each of the centers. + for point in 1..=100 { + doc_offset_ids[point - 1] = point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = point as f32; + doc_embeddings[(point - 1) * 2 + 1] = point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 1, &pl) + .await + .expect("Error writing to posting list"); + for point in 1..=100 { + doc_offset_ids[point - 1] = 100 + point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = 1000.0 + point as f32; + doc_embeddings[(point - 1) * 2 + 1] = 1000.0 + point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 2, &pl) + .await + .expect("Error writing to posting list"); + } + // Insert the points in the version map as well. + { + let mut version_map_guard = writer.versions_map.write(); + for point in 1..=100 { + version_map_guard.versions_map.insert(point as u32, 1); + version_map_guard.versions_map.insert(100 + point as u32, 1); + } + } + // Delete 40 points each from the centers. + for point in 1..=40 { + writer + .delete(point) + .await + .expect("Error deleting from spann index writer"); + writer + .delete(100 + point) + .await + .expect("Error deleting from spann index writer"); + } + // Expect the version map to be properly updated. + { + let version_map_guard = writer.versions_map.read(); + for point in 1..=40 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } + // For the other 60 points, the version should be 1. + for point in 41..=100 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&1)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&1)); + } + } + { + // The posting lists should not be changed at all. + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + } + // Now garbage collect. + writer + .garbage_collect() + .await + .expect("Error garbage collecting"); + // Expect the posting lists to be 60. Also validate the ids, versions and embeddings + // individually. + { + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 60); + assert_eq!(pl.1.len(), 60); + assert_eq!(pl.2.len(), 120); + for point in 41..=100 { + assert_eq!(pl.0[point - 41], point as u32); + assert_eq!(pl.1[point - 41], 1); + assert_eq!(pl.2[(point - 41) * 2], point as f32); + assert_eq!(pl.2[(point - 41) * 2 + 1], point as f32); + } + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 60); + assert_eq!(pl.1.len(), 60); + assert_eq!(pl.2.len(), 120); + for point in 41..=100 { + assert_eq!(pl.0[point - 41], 100 + point as u32); + assert_eq!(pl.1[point - 41], 1); + assert_eq!(pl.2[(point - 41) * 2], 1000.0 + point as f32); + assert_eq!(pl.2[(point - 41) * 2 + 1], 1000.0 + point as f32); + } + } + } + + #[tokio::test] + async fn test_merge() { + // Insert a few entries in a couple of centers. Delete a few + // still keeping within the merge threshold. + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Insert a couple of centers. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + } + { + let pl_guard = writer.posting_list_writer.lock().await; + let mut doc_offset_ids = vec![0u32; 100]; + let mut doc_versions = vec![0; 100]; + let mut doc_embeddings = vec![0.0; 200]; + // Insert 100 points in each of the centers. + for point in 1..=100 { + doc_offset_ids[point - 1] = point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = point as f32; + doc_embeddings[(point - 1) * 2 + 1] = point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 1, &pl) + .await + .expect("Error writing to posting list"); + for point in 1..=100 { + doc_offset_ids[point - 1] = 100 + point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = 1000.0 + point as f32; + doc_embeddings[(point - 1) * 2 + 1] = 1000.0 + point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 2, &pl) + .await + .expect("Error writing to posting list"); + } + // Insert the points in the version map as well. + { + let mut version_map_guard = writer.versions_map.write(); + for point in 1..=100 { + version_map_guard.versions_map.insert(point as u32, 1); + version_map_guard.versions_map.insert(100 + point as u32, 1); + } + } + // Delete 60 points each from the centers. Since merge_threshold is 40, this should + // trigger a merge between the two centers. + for point in 1..=60 { + writer + .delete(point) + .await + .expect("Error deleting from spann index writer"); + writer + .delete(100 + point) + .await + .expect("Error deleting from spann index writer"); + } + // Just one more point from the latter center. + writer + .delete(100 + 61) + .await + .expect("Error deleting from spann index writer"); + // Expect the version map to be properly updated. + { + let version_map_guard = writer.versions_map.read(); + for point in 1..=60 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } + // For the other 60 points, the version should be 1. + for point in 61..=100 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&1)); + if point == 61 { + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } else { + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&1)); + } + } + } + { + // The posting lists should not be changed at all. + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + } + // Now garbage collect. + writer + .garbage_collect() + .await + .expect("Error garbage collecting"); + // Expect only one center now. [0.0, 0.0] + { + let hnsw_read_guard = writer.hnsw_index.inner.read(); + assert_eq!(hnsw_read_guard.len(), 1); + let (non_deleted_ids, deleted_ids) = hnsw_read_guard + .get_all_ids() + .expect("Error getting all ids"); + assert_eq!(non_deleted_ids.len(), 1); + assert_eq!(deleted_ids.len(), 1); + assert_eq!(non_deleted_ids[0], 1); + assert_eq!(deleted_ids[0], 2); + let emb = hnsw_read_guard + .get(non_deleted_ids[0]) + .expect("Error getting hnsw index") + .unwrap(); + assert_eq!(emb, &[0.0, 0.0]); + } + // Expect the posting lists with id 1 to be 79. + { + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 79); + assert_eq!(pl.1.len(), 79); + assert_eq!(pl.2.len(), 158); + } + } }