Commit
Add merge reassign test
sanketkedia committed Dec 4, 2024
1 parent 3ecda71 commit 54e68c3
Showing 1 changed file with 242 additions and 2 deletions.
244 changes: 242 additions & 2 deletions rust/index/src/spann/types.rs
@@ -1114,7 +1114,7 @@ impl SpannIndexWriter {
let (nearest_head_ids, _, nearest_head_embeddings) = self
.get_nearby_heads(head_embedding, NUM_CENTERS_TO_MERGE_TO)
.await?;
-        for (nearest_head_id, head_embedding) in nearest_head_ids
+        for (nearest_head_id, nearest_head_embedding) in nearest_head_ids
.into_iter()
.zip(nearest_head_embeddings.into_iter())
{
@@ -1183,7 +1183,7 @@ impl SpannIndexWriter {
}
// This center is now merged with a neighbor.
target_head = nearest_head_id;
-                target_embedding = head_embedding;
+                target_embedding = nearest_head_embedding;
merged_with_a_nbr = true;
break;
}
@@ -2206,4 +2206,244 @@ mod tests {
}
}
}

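    // Verifies that garbage collection merges an ill-placed head into a nearby
    // head and reassigns its points: center 2's 20 points should end up in
    // center 3's posting list, and head 2 should be removed from the HNSW index.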
#[tokio::test]
async fn test_reassign_merge() {
let tmp_dir = tempfile::tempdir().unwrap();
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = new_cache_for_test();
let sparse_index_cache = new_cache_for_test();
let arrow_blockfile_provider = ArrowBlockfileProvider::new(
storage.clone(),
TEST_MAX_BLOCK_SIZE_BYTES,
block_cache,
sparse_index_cache,
);
let blockfile_provider =
BlockfileProvider::ArrowBlockfileProvider(arrow_blockfile_provider);
let hnsw_cache = new_non_persistent_cache_for_test();
let (_, rx) = tokio::sync::mpsc::unbounded_channel();
let hnsw_provider = HnswIndexProvider::new(
storage.clone(),
PathBuf::from(tmp_dir.path().to_str().unwrap()),
hnsw_cache,
rx,
);
let m = 16;
let ef_construction = 200;
let ef_search = 200;
let collection_id = CollectionUuid::new();
let distance_function = chroma_distance::DistanceFunction::Euclidean;
let dimensionality = 2;
let writer = SpannIndexWriter::from_id(
&hnsw_provider,
None,
None,
None,
None,
Some(m),
Some(ef_construction),
Some(ef_search),
&collection_id,
distance_function,
dimensionality,
&blockfile_provider,
)
.await
.expect("Error creating spann index writer");
        // Create three centers. Centers 1 and 3 are placed accurately with
        // respect to their points; center 2 is ill placed.
{
let hnsw_guard = writer.hnsw_index.inner.write();
hnsw_guard
.add(1, &[0.0, 0.0])
.expect("Error adding to hnsw index");
hnsw_guard
.add(2, &[1000.0, 1000.0])
.expect("Error adding to hnsw index");
hnsw_guard
.add(3, &[10000.0, 10000.0])
.expect("Error adding to hnsw index");
}
let mut doc_offset_ids1 = vec![0u32; 70];
let mut doc_versions1 = vec![0u32; 70];
let mut doc_embeddings1 = vec![0.0; 140];
let mut doc_offset_ids2 = vec![0u32; 20];
let mut doc_versions2 = vec![0u32; 20];
let mut doc_embeddings2 = vec![0.0; 40];
let mut doc_offset_ids3 = vec![0u32; 70];
let mut doc_versions3 = vec![0u32; 70];
let mut doc_embeddings3 = vec![0.0; 140];
{
let mut rng = rand::thread_rng();
let pl_guard = writer.posting_list_writer.lock().await;
            // Insert 70 points within a radius of 1 of center 1.
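            // To sample uniformly inside the unit disk, draw the radius as
            // sqrt(u) with u ~ U(0, 1) (disk area grows as r^2, so the square
            // root compensates) and an independent angle theta ~ U(0, 2*PI).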
for i in 1..=70 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos();
let y = r * theta.sin();

doc_offset_ids1[i - 1] = i as u32;
doc_versions1[i - 1] = 1;
doc_embeddings1[(i - 1) * 2] = x;
doc_embeddings1[(i - 1) * 2 + 1] = y;
}
            // Assign 20 points to center 2, but place them within a radius of 1
            // of (10000.0, 10000.0), far from center 2's own embedding.
for i in 71..=90 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos() + 10000.0;
let y = r * theta.sin() + 10000.0;

doc_offset_ids2[i - 71] = i as u32;
doc_versions2[i - 71] = 1;
doc_embeddings2[(i - 71) * 2] = x;
doc_embeddings2[(i - 71) * 2 + 1] = y;
}
            // Insert 70 points within a radius of 1 of center 3.
for i in 91..=160 {
// Generate random radius between 0 and 1
let r = rng.gen::<f32>().sqrt(); // sqrt for uniform distribution

// Generate random angle between 0 and 2π
let theta = rng.gen::<f32>() * 2.0 * PI;

// Convert to Cartesian coordinates
let x = r * theta.cos() + 10000.0;
let y = r * theta.sin() + 10000.0;

doc_offset_ids3[i - 91] = i as u32;
doc_versions3[i - 91] = 1;
doc_embeddings3[(i - 91) * 2] = x;
doc_embeddings3[(i - 91) * 2 + 1] = y;
}
let spann_posting_list = SpannPostingList {
doc_offset_ids: &doc_offset_ids1,
doc_versions: &doc_versions1,
doc_embeddings: &doc_embeddings1,
};
pl_guard
.set("", 1, &spann_posting_list)
.await
.expect("Error writing to posting list");
let spann_posting_list = SpannPostingList {
doc_offset_ids: &doc_offset_ids2,
doc_versions: &doc_versions2,
doc_embeddings: &doc_embeddings2,
};
pl_guard
.set("", 2, &spann_posting_list)
.await
.expect("Error writing to posting list");
let spann_posting_list = SpannPostingList {
doc_offset_ids: &doc_offset_ids3,
doc_versions: &doc_versions3,
doc_embeddings: &doc_embeddings3,
};
pl_guard
.set("", 3, &spann_posting_list)
.await
.expect("Error writing to posting list");
}
        // Initialize the versions map so that all 160 points are at version 1.
{
let mut version_map_guard = writer.versions_map.write();
for i in 1..=160 {
version_map_guard.versions_map.insert(i as u32, 1);
}
}
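        // Expected outcome (asserted below): GC merges the ill-placed center 2
        // into center 3, appending its 20 points to center 3's posting list
        // with bumped versions, and deletes head 2 from the HNSW index.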
// Run a GC now.
writer
.garbage_collect()
.await
.expect("Error garbage collecting");
// Run GC again to clean up the outdated points.
writer
.garbage_collect()
.await
.expect("Error garbage collecting");
        // Check the posting lists.
{
let pl_guard = writer.posting_list_writer.lock().await;
let pl = pl_guard
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 70);
assert_eq!(pl.1.len(), 70);
assert_eq!(pl.2.len(), 140);
for point in 1..=70 {
assert_eq!(pl.0[point - 1], point as u32);
assert_eq!(pl.1[point - 1], 1);
assert_eq!(pl.2[(point - 1) * 2], doc_embeddings1[(point - 1) * 2]);
assert_eq!(
pl.2[(point - 1) * 2 + 1],
doc_embeddings1[(point - 1) * 2 + 1]
);
}
let pl = pl_guard
.get_owned::<u32, &SpannPostingList<'_>>("", 3)
.await
.expect("Error getting posting list")
.unwrap();
            // PL3 should now contain 90 points: its original 70 plus the 20
            // reassigned from center 2.
assert_eq!(pl.0.len(), 90);
assert_eq!(pl.1.len(), 90);
assert_eq!(pl.2.len(), 180);
for point in 1..=70 {
assert_eq!(pl.0[point - 1], 90 + point as u32);
assert_eq!(pl.1[point - 1], 1);
assert_eq!(pl.2[(point - 1) * 2], doc_embeddings3[(point - 1) * 2]);
assert_eq!(
pl.2[(point - 1) * 2 + 1],
doc_embeddings3[(point - 1) * 2 + 1]
);
}
for point in 71..=90 {
assert_eq!(pl.0[point - 1], point as u32);
assert_eq!(pl.1[point - 1], 2);
assert_eq!(pl.2[(point - 1) * 2], doc_embeddings2[(point - 71) * 2]);
assert_eq!(
pl.2[(point - 1) * 2 + 1],
doc_embeddings2[(point - 71) * 2 + 1]
);
}
}
// There should only be two heads.
{
let hnsw_read_guard = writer.hnsw_index.inner.read();
assert_eq!(hnsw_read_guard.len(), 2);
let (mut non_deleted_ids, deleted_ids) = hnsw_read_guard
.get_all_ids()
.expect("Error getting all ids");
non_deleted_ids.sort();
assert_eq!(non_deleted_ids.len(), 2);
assert_eq!(deleted_ids.len(), 1);
assert_eq!(non_deleted_ids[0], 1);
assert_eq!(non_deleted_ids[1], 3);
assert_eq!(deleted_ids[0], 2);
let emb = hnsw_read_guard
.get(non_deleted_ids[0])
.expect("Error getting hnsw index")
.unwrap();
assert_eq!(emb, &[0.0, 0.0]);
let emb = hnsw_read_guard
.get(non_deleted_ids[1])
.expect("Error getting hnsw index")
.unwrap();
assert_eq!(emb, &[10000.0, 10000.0]);
}
}
}
