diff --git a/rust/index/bindings.cpp b/rust/index/bindings.cpp index 6a0f9165b1b..280d24cbd7d 100644 --- a/rust/index/bindings.cpp +++ b/rust/index/bindings.cpp @@ -129,6 +129,34 @@ class Index } } + void get_all_ids_sizes(size_t *ids_sizes) + { + if (!index_inited) + { + throw std::runtime_error("Index not inited"); + } + auto res = appr_alg->getLabelCounts(); + ids_sizes[0] = res.first; + ids_sizes[1] = res.second; + } + + void get_all_ids(hnswlib::labeltype *non_deleted_ids, hnswlib::labeltype *deleted_ids) + { + if (!index_inited) + { + throw std::runtime_error("Index not inited"); + } + auto res = appr_alg->getAllLabels(); + for (int i = 0; i < res.first.size(); i++) + { + non_deleted_ids[i] = res.first[i]; + } + for (int i = 0; i < res.second.size(); i++) + { + deleted_ids[i] = res.second[i]; + } + } + void mark_deleted(const hnswlib::labeltype id) { if (!index_inited) @@ -307,6 +335,34 @@ extern "C" last_error.clear(); } + void get_all_ids_sizes(Index *index, size_t *ids_sizes) + { + try + { + index->get_all_ids_sizes(ids_sizes); + } + catch (std::exception &e) + { + last_error = e.what(); + return; + } + last_error.clear(); + } + + void get_all_ids(Index *index, hnswlib::labeltype *non_deleted_ids, hnswlib::labeltype *deleted_ids) + { + try + { + index->get_all_ids(non_deleted_ids, deleted_ids); + } + catch (std::exception &e) + { + last_error = e.what(); + return; + } + last_error.clear(); + } + // Can throw std::exception void mark_deleted(Index *index, const hnswlib::labeltype id) { diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index e67433e4eb1..68e5a72a1e2 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -230,6 +230,28 @@ impl Index for HnswIndex { Ok(Some(data)) } } + + fn get_all_ids_sizes(&self) -> Result, Box> { + let mut sizes = vec![0usize; 2]; + unsafe { get_all_ids_sizes(self.ffi_ptr, sizes.as_mut_ptr()) }; + read_and_return_hnsw_error(self.ffi_ptr)?; + Ok(sizes) + } + + fn get_all_ids(&self) -> Result<(Vec, 
Vec), Box> { + let sizes = self.get_all_ids_sizes()?; + let mut non_deleted_ids = vec![0usize; sizes[0]]; + let mut deleted_ids = vec![0usize; sizes[1]]; + unsafe { + get_all_ids( + self.ffi_ptr, + non_deleted_ids.as_mut_ptr(), + deleted_ids.as_mut_ptr(), + ); + } + read_and_return_hnsw_error(self.ffi_ptr)?; + Ok((non_deleted_ids, deleted_ids)) + } } impl PersistentIndex for HnswIndex { @@ -359,6 +381,8 @@ extern "C" { fn add_item(index: *const IndexPtrFFI, data: *const f32, id: usize, replace_deleted: bool); fn mark_deleted(index: *const IndexPtrFFI, id: usize); fn get_item(index: *const IndexPtrFFI, id: usize, data: *mut f32); + fn get_all_ids_sizes(index: *const IndexPtrFFI, sizes: *mut usize); + fn get_all_ids(index: *const IndexPtrFFI, non_deleted_ids: *mut usize, deleted_ids: *mut usize); fn knn_query( index: *const IndexPtrFFI, query_vector: *const f32, diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 5f7e0188807..9592744b68a 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -147,6 +147,8 @@ const NUM_SAMPLES_FOR_KMEANS: usize = 1000; const INITIAL_LAMBDA: f32 = 100.0; const REASSIGN_NBR_COUNT: usize = 8; const QUERY_EPSILON: f32 = 10.0; +const MERGE_THRESHOLD: usize = 50; +const NUM_CENTERS_TO_MERGE_TO: usize = 8; impl SpannIndexWriter { #[allow(clippy::too_many_arguments)] @@ -504,6 +506,7 @@ impl SpannIndexWriter { .query(head_embedding, k, &allowed_ids, &disallowed_ids) .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)?; // Get the embeddings also for distance computation. + // TODO(Sanket): Don't consider heads that are farther away than the closest. for id in nearest_ids.iter() { let emb = read_guard .get(*id) @@ -698,8 +701,9 @@ impl SpannIndexWriter { let clustering_output; { let write_guard = self.posting_list_writer.lock().await; - // TODO(Sanket): Check if head is deleted, can happen if another concurrent thread - // deletes it. 
+ if self.is_head_deleted(head_id as usize).await? { + return Ok(()); + } let (mut doc_offset_ids, mut doc_versions, mut doc_embeddings) = write_guard .get_owned::>("", head_id) .await @@ -989,6 +993,314 @@ impl SpannIndexWriter { Ok(()) } + async fn get_up_to_date_count( + &self, + doc_offset_ids: &[u32], + doc_versions: &[u32], + ) -> Result { + let mut up_to_date_index = 0; + let version_map_guard = self.versions_map.read(); + for (index, doc_version) in doc_versions.iter().enumerate() { + let current_version = version_map_guard + .versions_map + .get(&doc_offset_ids[index]) + .ok_or(SpannIndexWriterError::VersionNotFound)?; + // disregard if either deleted or on an older version. + if *current_version == 0 || doc_version < current_version { + continue; + } + up_to_date_index += 1; + } + Ok(up_to_date_index) + } + + async fn is_head_deleted(&self, head_id: usize) -> Result { + let hnsw_read_guard = self.hnsw_index.inner.read(); + let hnsw_emb = hnsw_read_guard.get(head_id); + // TODO(Sanket): Check for exact error. + // TODO(Sanket): We should get this information from hnswlib and not rely on error. + if hnsw_emb.is_err() || hnsw_emb.unwrap().is_none() { + return Ok(true); + } + Ok(false) + } + + async fn remove_outdated_entries( + &self, + mut doc_offset_ids: Vec, + mut doc_versions: Vec, + mut doc_embeddings: Vec, + ) -> Result<(Vec, Vec, Vec), SpannIndexWriterError> { + let mut cluster_len = 0; + let mut local_indices = vec![0; doc_offset_ids.len()]; + { + let version_map_guard = self.versions_map.read(); + for (index, doc_version) in doc_versions.iter().enumerate() { + let current_version = version_map_guard + .versions_map + .get(&doc_offset_ids[index]) + .ok_or(SpannIndexWriterError::VersionNotFound)?; + // disregard if either deleted or on an older version. 
+ if *current_version == 0 || doc_version < current_version { + continue; + } + local_indices[cluster_len] = index; + cluster_len += 1; + } + } + for idx in 0..cluster_len { + if local_indices[idx] == idx { + continue; + } + doc_offset_ids[idx] = doc_offset_ids[local_indices[idx]]; + doc_versions[idx] = doc_versions[local_indices[idx]]; + doc_embeddings.copy_within( + local_indices[idx] * self.dimensionality + ..(local_indices[idx] + 1) * self.dimensionality, + idx * self.dimensionality, + ); + } + doc_offset_ids.truncate(cluster_len); + doc_versions.truncate(cluster_len); + doc_embeddings.truncate(cluster_len * self.dimensionality); + Ok((doc_offset_ids, doc_versions, doc_embeddings)) + } + + #[allow(clippy::too_many_arguments)] + async fn merge_posting_lists( + &self, + mut source_doc_offset_ids: Vec, + mut source_doc_versions: Vec, + mut source_doc_embeddings: Vec, + target_doc_offset_ids: Vec, + target_doc_versions: Vec, + target_doc_embeddings: Vec, + target_cluster_len: usize, + ) -> Result<(Vec, Vec, Vec), SpannIndexWriterError> { + source_doc_embeddings.reserve_exact(target_cluster_len); + source_doc_versions.reserve_exact(target_cluster_len); + source_doc_embeddings.reserve_exact(target_cluster_len * self.dimensionality); + for (index, target_doc_offset_id) in target_doc_offset_ids.into_iter().enumerate() { + if self + .is_outdated(target_doc_offset_id, target_doc_versions[index]) + .await? + { + continue; + } + source_doc_offset_ids.push(target_doc_offset_id); + source_doc_versions.push(target_doc_versions[index]); + source_doc_embeddings.extend_from_slice( + &target_doc_embeddings + [index * self.dimensionality..(index + 1) * self.dimensionality], + ); + } + Ok(( + source_doc_offset_ids, + source_doc_versions, + source_doc_embeddings, + )) + } + + async fn garbage_collect_head( + &self, + head_id: usize, + head_embedding: &[f32], + ) -> Result<(), SpannIndexWriterError> { + // Get heads. 
+ let mut merged_with_a_nbr = false; + let source_cluster_len; + let mut target_cluster_len = 0; + let mut doc_offset_ids; + let mut doc_versions; + let mut doc_embeddings; + let mut target_embedding = vec![]; + let mut target_head = 0; + { + let pl_guard = self.posting_list_writer.lock().await; + // If head is concurrently deleted then skip. + if self.is_head_deleted(head_id).await? { + return Ok(()); + } + (doc_offset_ids, doc_versions, doc_embeddings) = pl_guard + .get_owned::>("", head_id as u32) + .await + .map_err(|_| SpannIndexWriterError::PostingListGetError)? + .ok_or(SpannIndexWriterError::PostingListGetError)?; + (doc_offset_ids, doc_versions, doc_embeddings) = self + .remove_outdated_entries(doc_offset_ids, doc_versions, doc_embeddings) + .await?; + source_cluster_len = doc_offset_ids.len(); + // Write the PL back and return if within the merge threshold. + if source_cluster_len > MERGE_THRESHOLD { + let posting_list = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", head_id as u32, &posting_list) + .await + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; + + return Ok(()); + } + // Find candidates for merge. + let (nearest_head_ids, _, nearest_head_embeddings) = self + .get_nearby_heads(head_embedding, NUM_CENTERS_TO_MERGE_TO) + .await?; + for (nearest_head_id, nearest_head_embedding) in nearest_head_ids + .into_iter() + .zip(nearest_head_embeddings.into_iter()) + { + // Skip if it is the current head. Can't a merge a head into itself. + if nearest_head_id == head_id { + continue; + } + // TODO(Sanket): If and when the lock is more fine grained, then + // need to acquire a lock on the nearest_head_id here. + // TODO(Sanket): Also need to check if the head is deleted concurrently then. 
+ let ( + nearest_head_doc_offset_ids, + nearest_head_doc_versions, + nearest_head_doc_embeddings, + ) = pl_guard + .get_owned::>("", nearest_head_id as u32) + .await + .map_err(|_| SpannIndexWriterError::PostingListGetError)? + .ok_or(SpannIndexWriterError::PostingListGetError)?; + target_cluster_len = self + .get_up_to_date_count(&nearest_head_doc_offset_ids, &nearest_head_doc_versions) + .await?; + // If the total count exceeds the max posting list size then skip. + if target_cluster_len + source_cluster_len >= SPLIT_THRESHOLD { + continue; + } + // Merge the two PLs. + (doc_offset_ids, doc_versions, doc_embeddings) = self + .merge_posting_lists( + doc_offset_ids, + doc_versions, + doc_embeddings, + nearest_head_doc_offset_ids, + nearest_head_doc_versions, + nearest_head_doc_embeddings, + target_cluster_len, + ) + .await?; + // Write the merged PL back. + // Merge into the larger of the two clusters. + let merged_posting_list = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + if target_cluster_len > source_cluster_len { + pl_guard + .set("", nearest_head_id as u32, &merged_posting_list) + .await + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; + // Delete from hnsw. + let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard + .delete(head_id) + .map_err(|_| SpannIndexWriterError::HnswIndexAddError)?; + } else { + pl_guard + .set("", head_id as u32, &merged_posting_list) + .await + .map_err(|_| SpannIndexWriterError::PostingListSetError)?; + // Delete from hnsw. + let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard + .delete(nearest_head_id) + .map_err(|_| SpannIndexWriterError::HnswIndexAddError)?; + } + // This center is now merged with a neighbor. 
+ target_head = nearest_head_id; + target_embedding = nearest_head_embedding; + merged_with_a_nbr = true; + break; + } + } + if !merged_with_a_nbr { + return Ok(()); + } + // Reassign points that were merged to neighbouring heads. + if source_cluster_len > target_cluster_len { + // target_cluster points were merged to source_cluster + // so they are candidates for reassignment. + for idx in source_cluster_len..(source_cluster_len + target_cluster_len) { + let origin_dist = self.distance_function.distance( + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + &target_embedding, + ); + let new_dist = self.distance_function.distance( + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + head_embedding, + ); + if new_dist > origin_dist { + self.reassign( + doc_offset_ids[idx], + doc_versions[idx], + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + head_id as u32, + ) + .await?; + } + } + } else { + // source_cluster points were merged to target_cluster + // so they are candidates for reassignment. + for idx in 0..source_cluster_len { + let origin_dist = self.distance_function.distance( + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + head_embedding, + ); + let new_dist = self.distance_function.distance( + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + &target_embedding, + ); + if new_dist > origin_dist { + self.reassign( + doc_offset_ids[idx], + doc_versions[idx], + &doc_embeddings[idx * self.dimensionality..(idx + 1) * self.dimensionality], + target_head as u32, + ) + .await?; + } + } + } + Ok(()) + } + + // TODO(Sanket): Hook in the gc policy. + // TODO(Sanket): Garbage collect HNSW also. + pub async fn garbage_collect(&self) -> Result<(), SpannIndexWriterError> { + // Get all the heads. 
+ let non_deleted_heads; + { + let hnsw_read_guard = self.hnsw_index.inner.read(); + (non_deleted_heads, _) = hnsw_read_guard + .get_all_ids() + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)?; + } + // Iterate over all the heads and gc heads. + for head_id in non_deleted_heads.into_iter() { + let head_embedding = self + .hnsw_index + .inner + .read() + .get(head_id) + .map_err(|_| SpannIndexWriterError::HnswIndexSearchError)? + .ok_or(SpannIndexWriterError::HnswIndexSearchError)?; + tracing::info!("Garbage collecting head {}", head_id); + self.garbage_collect_head(head_id, &head_embedding).await?; + } + Ok(()) + } + + // TODO(Sanket): Change the error types. pub async fn commit(self) -> Result { // Pl list. let pl_flusher = match Arc::try_unwrap(self.posting_list_writer) { @@ -1312,4 +1624,861 @@ mod tests { assert_eq!(pl.2.len(), 200); } } + + #[tokio::test] + async fn test_gc_deletes() { + // Insert a few entries in a couple of centers. Delete a few + // still keeping within the merge threshold. 
+ let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Insert a couple of centers. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + } + { + let pl_guard = writer.posting_list_writer.lock().await; + let mut doc_offset_ids = vec![0u32; 100]; + let mut doc_versions = vec![0; 100]; + let mut doc_embeddings = vec![0.0; 200]; + // Insert 100 points in each of the centers. 
+ for point in 1..=100 { + doc_offset_ids[point - 1] = point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = point as f32; + doc_embeddings[(point - 1) * 2 + 1] = point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 1, &pl) + .await + .expect("Error writing to posting list"); + for point in 1..=100 { + doc_offset_ids[point - 1] = 100 + point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = 1000.0 + point as f32; + doc_embeddings[(point - 1) * 2 + 1] = 1000.0 + point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 2, &pl) + .await + .expect("Error writing to posting list"); + } + // Insert the points in the version map as well. + { + let mut version_map_guard = writer.versions_map.write(); + for point in 1..=100 { + version_map_guard.versions_map.insert(point as u32, 1); + version_map_guard.versions_map.insert(100 + point as u32, 1); + } + } + // Delete 40 points each from the centers. + for point in 1..=40 { + writer + .delete(point) + .await + .expect("Error deleting from spann index writer"); + writer + .delete(100 + point) + .await + .expect("Error deleting from spann index writer"); + } + // Expect the version map to be properly updated. + { + let version_map_guard = writer.versions_map.read(); + for point in 1..=40 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } + // For the other 60 points, the version should be 1. + for point in 41..=100 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&1)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&1)); + } + } + { + // The posting lists should not be changed at all. 
+ let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + } + // Now garbage collect. + writer + .garbage_collect() + .await + .expect("Error garbage collecting"); + // Expect the posting lists to be 60. Also validate the ids, versions and embeddings + // individually. + { + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 60); + assert_eq!(pl.1.len(), 60); + assert_eq!(pl.2.len(), 120); + for point in 41..=100 { + assert_eq!(pl.0[point - 41], point as u32); + assert_eq!(pl.1[point - 41], 1); + assert_eq!(pl.2[(point - 41) * 2], point as f32); + assert_eq!(pl.2[(point - 41) * 2 + 1], point as f32); + } + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 60); + assert_eq!(pl.1.len(), 60); + assert_eq!(pl.2.len(), 120); + for point in 41..=100 { + assert_eq!(pl.0[point - 41], 100 + point as u32); + assert_eq!(pl.1[point - 41], 1); + assert_eq!(pl.2[(point - 41) * 2], 1000.0 + point as f32); + assert_eq!(pl.2[(point - 41) * 2 + 1], 1000.0 + point as f32); + } + } + } + + #[tokio::test] + async fn test_merge() { + // Insert a few entries in a couple of centers. Delete a few + // still keeping within the merge threshold. 
+ let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Insert a couple of centers. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + } + { + let pl_guard = writer.posting_list_writer.lock().await; + let mut doc_offset_ids = vec![0u32; 100]; + let mut doc_versions = vec![0; 100]; + let mut doc_embeddings = vec![0.0; 200]; + // Insert 100 points in each of the centers. 
+ for point in 1..=100 { + doc_offset_ids[point - 1] = point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = point as f32; + doc_embeddings[(point - 1) * 2 + 1] = point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 1, &pl) + .await + .expect("Error writing to posting list"); + for point in 1..=100 { + doc_offset_ids[point - 1] = 100 + point as u32; + doc_versions[point - 1] = 1; + doc_embeddings[(point - 1) * 2] = 1000.0 + point as f32; + doc_embeddings[(point - 1) * 2 + 1] = 1000.0 + point as f32; + } + let pl = SpannPostingList { + doc_offset_ids: &doc_offset_ids, + doc_versions: &doc_versions, + doc_embeddings: &doc_embeddings, + }; + pl_guard + .set("", 2, &pl) + .await + .expect("Error writing to posting list"); + } + // Insert the points in the version map as well. + { + let mut version_map_guard = writer.versions_map.write(); + for point in 1..=100 { + version_map_guard.versions_map.insert(point as u32, 1); + version_map_guard.versions_map.insert(100 + point as u32, 1); + } + } + // Delete 60 points each from the centers. Since merge_threshold is 40, this should + // trigger a merge between the two centers. + for point in 1..=60 { + writer + .delete(point) + .await + .expect("Error deleting from spann index writer"); + writer + .delete(100 + point) + .await + .expect("Error deleting from spann index writer"); + } + // Just one more point from the latter center. + writer + .delete(100 + 61) + .await + .expect("Error deleting from spann index writer"); + // Expect the version map to be properly updated. + { + let version_map_guard = writer.versions_map.read(); + for point in 1..=60 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&0)); + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } + // For the other 60 points, the version should be 1. 
+ for point in 61..=100 { + assert_eq!(version_map_guard.versions_map.get(&point), Some(&1)); + if point == 61 { + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&0)); + } else { + assert_eq!(version_map_guard.versions_map.get(&(100 + point)), Some(&1)); + } + } + } + { + // The posting lists should not be changed at all. + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + } + // Now garbage collect. + writer + .garbage_collect() + .await + .expect("Error garbage collecting"); + // Expect only one center now. [0.0, 0.0] + { + let hnsw_read_guard = writer.hnsw_index.inner.read(); + assert_eq!(hnsw_read_guard.len(), 1); + let (non_deleted_ids, deleted_ids) = hnsw_read_guard + .get_all_ids() + .expect("Error getting all ids"); + assert_eq!(non_deleted_ids.len(), 1); + assert_eq!(deleted_ids.len(), 1); + assert_eq!(non_deleted_ids[0], 1); + assert_eq!(deleted_ids[0], 2); + let emb = hnsw_read_guard + .get(non_deleted_ids[0]) + .expect("Error getting hnsw index") + .unwrap(); + assert_eq!(emb, &[0.0, 0.0]); + } + // Expect the posting lists with id 1 to be 79. 
+ { + let pl_guard = writer.posting_list_writer.lock().await; + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 79); + assert_eq!(pl.1.len(), 79); + assert_eq!(pl.2.len(), 158); + } + } + + #[tokio::test] + async fn test_reassign() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Create three centers with ill placed points. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(3, &[10000.0, 10000.0]) + .expect("Error adding to hnsw index"); + } + // Insert 50 points within a radius of 1 to center 1. 
+ let mut split_doc_offset_ids1 = vec![0u32; 50]; + let mut split_doc_versions1 = vec![0u32; 50]; + let mut split_doc_embeddings1 = vec![0.0; 100]; + let mut split_doc_offset_ids2 = vec![0u32; 50]; + let mut split_doc_versions2 = vec![0u32; 50]; + let mut split_doc_embeddings2 = vec![0.0; 100]; + let mut split_doc_offset_ids3 = vec![0u32; 50]; + let mut split_doc_versions3 = vec![0u32; 50]; + let mut split_doc_embeddings3 = vec![0.0; 100]; + { + let mut rng = rand::thread_rng(); + let pl_guard = writer.posting_list_writer.lock().await; + for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos(); + let y = r * theta.sin(); + + split_doc_offset_ids1[i - 1] = i as u32; + split_doc_versions1[i - 1] = 1; + split_doc_embeddings1[(i - 1) * 2] = x; + split_doc_embeddings1[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids1, + doc_versions: &split_doc_versions1, + doc_embeddings: &split_doc_embeddings1, + }; + pl_guard + .set("", 1, &posting_list) + .await + .expect("Error writing to posting list"); + // Insert 50 points within a radius of 1 to center 3 to center 2 and vice versa. + // This ensures that we test reassignment and that it shuffles the two fully. 
+ for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos() + 1000.0; + let y = r * theta.sin() + 1000.0; + + split_doc_offset_ids3[i - 1] = 50 + i as u32; + split_doc_versions3[i - 1] = 1; + split_doc_embeddings3[(i - 1) * 2] = x; + split_doc_embeddings3[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids3, + doc_versions: &split_doc_versions3, + doc_embeddings: &split_doc_embeddings3, + }; + pl_guard + .set("", 3, &posting_list) + .await + .expect("Error writing to posting list"); + // Do the same for 10000. + for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos() + 10000.0; + let y = r * theta.sin() + 10000.0; + + split_doc_offset_ids2[i - 1] = 100 + i as u32; + split_doc_versions2[i - 1] = 1; + split_doc_embeddings2[(i - 1) * 2] = x; + split_doc_embeddings2[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids2, + doc_versions: &split_doc_versions2, + doc_embeddings: &split_doc_embeddings2, + }; + pl_guard + .set("", 2, &posting_list) + .await + .expect("Error writing to posting list"); + } + // Insert these 150 points to version map. + { + let mut version_map_guard = writer.versions_map.write(); + for i in 1..=150 { + version_map_guard.versions_map.insert(i as u32, 1); + } + } + // Trigger reassign and see the results. + // Carefully construct the old head embedding so that NPA + // is violated for the second center. 
+ writer + .collect_and_reassign( + &[1, 2], + &[Some(&vec![0.0, 0.0]), Some(&vec![1000.0, 1000.0])], + &[5000.0, 5000.0], + &[split_doc_offset_ids1.clone(), split_doc_offset_ids2.clone()], + &[split_doc_versions1.clone(), split_doc_versions2.clone()], + &[split_doc_embeddings1.clone(), split_doc_embeddings2.clone()], + ) + .await + .expect("Expected reassign to succeed"); + // See the reassigned points. + { + let pl_guard = writer.posting_list_writer.lock().await; + // Center 1 should remain unchanged. + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 50); + assert_eq!(pl.1.len(), 50); + assert_eq!(pl.2.len(), 100); + for i in 1..=50 { + assert_eq!(pl.0[i - 1], i as u32); + assert_eq!(pl.1[i - 1], 1); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings1[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings1[(i - 1) * 2 + 1] + ); + } + // Center 2 should get 50 points, all with version 2 migrating from center 3. + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 50); + assert_eq!(pl.1.len(), 50); + assert_eq!(pl.2.len(), 100); + for i in 1..=50 { + assert_eq!(pl.0[i - 1], 50 + i as u32); + assert_eq!(pl.1[i - 1], 2); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings3[(i - 1) * 2 + 1] + ); + } + // Center 3 should get 100 points. 50 points with version 1 which weere + // originally in center 3 and 50 points with version 2 which were originally + // in center 2. 
+ let pl = pl_guard + .get_owned::>("", 3) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + for i in 1..=100 { + assert_eq!(pl.0[i - 1], 50 + i as u32); + if i <= 50 { + assert_eq!(pl.1[i - 1], 1); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings3[(i - 1) * 2 + 1] + ); + } else { + assert_eq!(pl.1[i - 1], 2); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings2[(i - 51) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings2[(i - 51) * 2 + 1] + ); + } + } + } + } + + #[tokio::test] + async fn test_reassign_merge() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Create three centers. 
2 of these are accurate wrt their centers and third
+ // is ill placed.
+ {
+ let hnsw_guard = writer.hnsw_index.inner.write();
+ hnsw_guard
+ .add(1, &[0.0, 0.0])
+ .expect("Error adding to hnsw index");
+ hnsw_guard
+ .add(2, &[1000.0, 1000.0])
+ .expect("Error adding to hnsw index");
+ hnsw_guard
+ .add(3, &[10000.0, 10000.0])
+ .expect("Error adding to hnsw index");
+ }
+ let mut doc_offset_ids1 = vec![0u32; 70];
+ let mut doc_versions1 = vec![0u32; 70];
+ let mut doc_embeddings1 = vec![0.0; 140];
+ let mut doc_offset_ids2 = vec![0u32; 20];
+ let mut doc_versions2 = vec![0u32; 20];
+ let mut doc_embeddings2 = vec![0.0; 40];
+ let mut doc_offset_ids3 = vec![0u32; 70];
+ let mut doc_versions3 = vec![0u32; 70];
+ let mut doc_embeddings3 = vec![0.0; 140];
+ {
+ let mut rng = rand::thread_rng();
+ let pl_guard = writer.posting_list_writer.lock().await;
+ // Insert 70 points within a radius of 1 to center 1.
+ for i in 1..=70 {
+ // Generate random radius between 0 and 1
+ let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+ // Generate random angle between 0 and 2π
+ let theta = rng.gen::<f32>() * 2.0 * PI;
+
+ // Convert to Cartesian coordinates
+ let x = r * theta.cos();
+ let y = r * theta.sin();
+
+ doc_offset_ids1[i - 1] = i as u32;
+ doc_versions1[i - 1] = 1;
+ doc_embeddings1[(i - 1) * 2] = x;
+ doc_embeddings1[(i - 1) * 2 + 1] = y;
+ }
+ // Insert 20 points within a radius of 1 to center 2.
+ for i in 71..=90 {
+ // Generate random radius between 0 and 1
+ let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+ // Generate random angle between 0 and 2π
+ let theta = rng.gen::<f32>() * 2.0 * PI;
+
+ // Convert to Cartesian coordinates
+ let x = r * theta.cos() + 10000.0;
+ let y = r * theta.sin() + 10000.0;
+
+ doc_offset_ids2[i - 71] = i as u32;
+ doc_versions2[i - 71] = 1;
+ doc_embeddings2[(i - 71) * 2] = x;
+ doc_embeddings2[(i - 71) * 2 + 1] = y;
+ }
+ // Insert 70 points within a radius of 1 to center 3.
+ for i in 91..=160 {
+ // Generate random radius between 0 and 1
+ let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution
+
+ // Generate random angle between 0 and 2π
+ let theta = rng.gen::<f32>() * 2.0 * PI;
+
+ // Convert to Cartesian coordinates
+ let x = r * theta.cos() + 10000.0;
+ let y = r * theta.sin() + 10000.0;
+
+ doc_offset_ids3[i - 91] = i as u32;
+ doc_versions3[i - 91] = 1;
+ doc_embeddings3[(i - 91) * 2] = x;
+ doc_embeddings3[(i - 91) * 2 + 1] = y;
+ }
+ let spann_posting_list = SpannPostingList {
+ doc_offset_ids: &doc_offset_ids1,
+ doc_versions: &doc_versions1,
+ doc_embeddings: &doc_embeddings1,
+ };
+ pl_guard
+ .set("", 1, &spann_posting_list)
+ .await
+ .expect("Error writing to posting list");
+ let spann_posting_list = SpannPostingList {
+ doc_offset_ids: &doc_offset_ids2,
+ doc_versions: &doc_versions2,
+ doc_embeddings: &doc_embeddings2,
+ };
+ pl_guard
+ .set("", 2, &spann_posting_list)
+ .await
+ .expect("Error writing to posting list");
+ let spann_posting_list = SpannPostingList {
+ doc_offset_ids: &doc_offset_ids3,
+ doc_versions: &doc_versions3,
+ doc_embeddings: &doc_embeddings3,
+ };
+ pl_guard
+ .set("", 3, &spann_posting_list)
+ .await
+ .expect("Error writing to posting list");
+ }
+ // Initialize the versions map appropriately.
+ {
+ let mut version_map_guard = writer.versions_map.write();
+ for i in 1..=160 {
+ version_map_guard.versions_map.insert(i as u32, 1);
+ }
+ }
+ // Run a GC now.
+ writer
+ .garbage_collect()
+ .await
+ .expect("Error garbage collecting");
+ // Run GC again to clean up the outdated points.
+ writer
+ .garbage_collect()
+ .await
+ .expect("Error garbage collecting");
+ // check the posting lists.
+ {
+ let pl_guard = writer.posting_list_writer.lock().await;
+ let pl = pl_guard
+ .get_owned::<u32, &SpannPostingList<'_>>("", 1)
+ .await
+ .expect("Error getting posting list")
+ .unwrap();
+ assert_eq!(pl.0.len(), 70);
+ assert_eq!(pl.1.len(), 70);
+ assert_eq!(pl.2.len(), 140);
+ for point in 1..=70 {
+ assert_eq!(pl.0[point - 1], point as u32);
+ assert_eq!(pl.1[point - 1], 1);
+ assert_eq!(pl.2[(point - 1) * 2], doc_embeddings1[(point - 1) * 2]);
+ assert_eq!(
+ pl.2[(point - 1) * 2 + 1],
+ doc_embeddings1[(point - 1) * 2 + 1]
+ );
+ }
+ let pl = pl_guard
+ .get_owned::<u32, &SpannPostingList<'_>>("", 3)
+ .await
+ .expect("Error getting posting list")
+ .unwrap();
+ // PL3 should be 90.
+ assert_eq!(pl.0.len(), 90);
+ assert_eq!(pl.1.len(), 90);
+ assert_eq!(pl.2.len(), 180);
+ for point in 1..=70 {
+ assert_eq!(pl.0[point - 1], 90 + point as u32);
+ assert_eq!(pl.1[point - 1], 1);
+ assert_eq!(pl.2[(point - 1) * 2], doc_embeddings3[(point - 1) * 2]);
+ assert_eq!(
+ pl.2[(point - 1) * 2 + 1],
+ doc_embeddings3[(point - 1) * 2 + 1]
+ );
+ }
+ for point in 71..=90 {
+ assert_eq!(pl.0[point - 1], point as u32);
+ assert_eq!(pl.1[point - 1], 2);
+ assert_eq!(pl.2[(point - 1) * 2], doc_embeddings2[(point - 71) * 2]);
+ assert_eq!(
+ pl.2[(point - 1) * 2 + 1],
+ doc_embeddings2[(point - 71) * 2 + 1]
+ );
+ }
+ }
+ // There should only be two heads.
+ {
+ let hnsw_read_guard = writer.hnsw_index.inner.read();
+ assert_eq!(hnsw_read_guard.len(), 2);
+ let (mut non_deleted_ids, deleted_ids) = hnsw_read_guard
+ .get_all_ids()
+ .expect("Error getting all ids");
+ non_deleted_ids.sort();
+ assert_eq!(non_deleted_ids.len(), 2);
+ assert_eq!(deleted_ids.len(), 1);
+ assert_eq!(non_deleted_ids[0], 1);
+ assert_eq!(non_deleted_ids[1], 3);
+ assert_eq!(deleted_ids[0], 2);
+ let emb = hnsw_read_guard
+ .get(non_deleted_ids[0])
+ .expect("Error getting hnsw index")
+ .unwrap();
+ assert_eq!(emb, &[0.0, 0.0]);
+ let emb = hnsw_read_guard
+ .get(non_deleted_ids[1])
+ .expect("Error getting hnsw index")
+ .unwrap();
+ assert_eq!(emb, &[10000.0, 10000.0]);
+ }
+ }
+ }
diff --git a/rust/index/src/types.rs b/rust/index/src/types.rs
index bcccaed2744..c2e0cf9243e 100644
--- a/rust/index/src/types.rs
+++ b/rust/index/src/types.rs
@@ -44,6 +44,8 @@ pub trait Index {
 disallow_ids: &[usize],
 ) -> Result<(Vec<usize>, Vec<f32>), Box<dyn ChromaError>>;
 fn get(&self, id: usize) -> Result<Option<Vec<f32>>, Box<dyn ChromaError>>;
+ fn get_all_ids(&self) -> Result<(Vec<usize>, Vec<usize>), Box<dyn ChromaError>>;
+ fn get_all_ids_sizes(&self) -> Result<Vec<usize>, Box<dyn ChromaError>>;
 }
 /// The persistent index trait.