diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index ff8b80bcbe1..6c116e5ae88 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1965,4 +1965,245 @@ mod tests { assert_eq!(pl.2.len(), 158); } } + + #[tokio::test] + async fn test_reassign() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let arrow_blockfile_provider = ArrowBlockfileProvider::new( + storage.clone(), + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + let blockfile_provider = + BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider); + let hnsw_cache = new_non_persistent_cache_for_test(); + let (_, rx) = tokio::sync::mpsc::unbounded_channel(); + let hnsw_provider = HnswIndexProvider::new( + storage.clone(), + PathBuf::from(tmp_dir.path().to_str().unwrap()), + hnsw_cache, + rx, + ); + let m = 16; + let ef_construction = 200; + let ef_search = 200; + let collection_id = CollectionUuid::new(); + let distance_function = chroma_distance::DistanceFunction::Euclidean; + let dimensionality = 2; + let writer = SpannIndexWriter::from_id( + &hnsw_provider, + None, + None, + None, + None, + Some(m), + Some(ef_construction), + Some(ef_search), + &collection_id, + distance_function, + dimensionality, + &blockfile_provider, + ) + .await + .expect("Error creating spann index writer"); + // Create three centers with ill placed points. + { + let hnsw_guard = writer.hnsw_index.inner.write(); + hnsw_guard + .add(1, &[0.0, 0.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(2, &[1000.0, 1000.0]) + .expect("Error adding to hnsw index"); + hnsw_guard + .add(3, &[10000.0, 10000.0]) + .expect("Error adding to hnsw index"); + } + // Insert 50 points within a radius of 1 to center 1. + let mut split_doc_offset_ids1 = vec![0u32; 50]; + let mut split_doc_versions1 = vec![0u32; 50]; + let mut split_doc_embeddings1 = vec![0.0; 100]; + let mut split_doc_offset_ids2 = vec![0u32; 50]; + let mut split_doc_versions2 = vec![0u32; 50]; + let mut split_doc_embeddings2 = vec![0.0; 100]; + let mut split_doc_offset_ids3 = vec![0u32; 50]; + let mut split_doc_versions3 = vec![0u32; 50]; + let mut split_doc_embeddings3 = vec![0.0; 100]; + { + let mut rng = rand::thread_rng(); + let pl_guard = writer.posting_list_writer.lock().await; + for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos(); + let y = r * theta.sin(); + + split_doc_offset_ids1[i - 1] = i as u32; + split_doc_versions1[i - 1] = 1; + split_doc_embeddings1[(i - 1) * 2] = x; + split_doc_embeddings1[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids1, + doc_versions: &split_doc_versions1, + doc_embeddings: &split_doc_embeddings1, + }; + pl_guard + .set("", 1, &posting_list) + .await + .expect("Error writing to posting list"); + // Insert 50 points within a radius of 1 to center 3 to center 2 and vice versa. + // This ensures that we test reassignment and that it shuffles the two fully. + for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos() + 1000.0; + let y = r * theta.sin() + 1000.0; + + split_doc_offset_ids3[i - 1] = 50 + i as u32; + split_doc_versions3[i - 1] = 1; + split_doc_embeddings3[(i - 1) * 2] = x; + split_doc_embeddings3[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids3, + doc_versions: &split_doc_versions3, + doc_embeddings: &split_doc_embeddings3, + }; + pl_guard + .set("", 3, &posting_list) + .await + .expect("Error writing to posting list"); + // Do the same for 10000. + for i in 1..=50 { + // Generate random radius between 0 and 1 + let r = rng.gen::().sqrt(); // sqrt for uniform distribution + + // Generate random angle between 0 and 2π + let theta = rng.gen::() * 2.0 * PI; + + // Convert to Cartesian coordinates + let x = r * theta.cos() + 10000.0; + let y = r * theta.sin() + 10000.0; + + split_doc_offset_ids2[i - 1] = 100 + i as u32; + split_doc_versions2[i - 1] = 1; + split_doc_embeddings2[(i - 1) * 2] = x; + split_doc_embeddings2[(i - 1) * 2 + 1] = y; + } + let posting_list = SpannPostingList { + doc_offset_ids: &split_doc_offset_ids2, + doc_versions: &split_doc_versions2, + doc_embeddings: &split_doc_embeddings2, + }; + pl_guard + .set("", 2, &posting_list) + .await + .expect("Error writing to posting list"); + } + // Insert these 150 points to version map. + { + let mut version_map_guard = writer.versions_map.write(); + for i in 1..=150 { + version_map_guard.versions_map.insert(i as u32, 1); + } + } + // Trigger reassign and see the results. + // Carefully construct the old head embedding so that NPA + // is violated for the second center. + writer + .collect_and_reassign( + &[1, 2], + &[Some(&vec![0.0, 0.0]), Some(&vec![1000.0, 1000.0])], + &[5000.0, 5000.0], + &[split_doc_offset_ids1.clone(), split_doc_offset_ids2.clone()], + &[split_doc_versions1.clone(), split_doc_versions2.clone()], + &[split_doc_embeddings1.clone(), split_doc_embeddings2.clone()], + ) + .await + .expect("Expected reassign to succeed"); + // See the reassigned points. + { + let pl_guard = writer.posting_list_writer.lock().await; + // Center 1 should remain unchanged. + let pl = pl_guard + .get_owned::>("", 1) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 50); + assert_eq!(pl.1.len(), 50); + assert_eq!(pl.2.len(), 100); + for i in 1..=50 { + assert_eq!(pl.0[i - 1], i as u32); + assert_eq!(pl.1[i - 1], 1); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings1[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings1[(i - 1) * 2 + 1] + ); + } + // Center 2 should get 50 points, all with version 2 migrating from center 3. + let pl = pl_guard + .get_owned::>("", 2) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 50); + assert_eq!(pl.1.len(), 50); + assert_eq!(pl.2.len(), 100); + for i in 1..=50 { + assert_eq!(pl.0[i - 1], 50 + i as u32); + assert_eq!(pl.1[i - 1], 2); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings3[(i - 1) * 2 + 1] + ); + } + // Center 3 should get 100 points. 50 points with version 1 which weere + // originally in center 3 and 50 points with version 2 which were originally + // in center 2. + let pl = pl_guard + .get_owned::>("", 3) + .await + .expect("Error getting posting list") + .unwrap(); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + for i in 1..=100 { + assert_eq!(pl.0[i - 1], 50 + i as u32); + if i <= 50 { + assert_eq!(pl.1[i - 1], 1); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings3[(i - 1) * 2 + 1] + ); + } else { + assert_eq!(pl.1[i - 1], 2); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings2[(i - 51) * 2]); + assert_eq!( + pl.2[(i - 1) * 2 + 1], + split_doc_embeddings2[(i - 51) * 2 + 1] + ); + } + } + } + } }