Skip to content

Commit

Permalink
[ENH] Full garbage collection (#3194)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
	 - Full garbage collection in SPANN

## Test plan
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia authored Dec 5, 2024
1 parent 8b63966 commit 1caa19f
Show file tree
Hide file tree
Showing 4 changed files with 1,253 additions and 2 deletions.
56 changes: 56 additions & 0 deletions rust/index/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,34 @@ class Index
}
}

// Writes the two label counts reported by the underlying index into
// ids_sizes: slot 0 receives the first element of getLabelCounts(), slot 1
// the second. The caller must supply room for at least two size_t values.
// Throws std::runtime_error if the index has not been initialised.
void get_all_ids_sizes(size_t *ids_sizes)
{
    if (!index_inited)
    {
        throw std::runtime_error("Index not inited");
    }
    const auto label_counts = appr_alg->getLabelCounts();
    ids_sizes[0] = label_counts.first;
    ids_sizes[1] = label_counts.second;
}

// Copies every label returned by getAllLabels() into the caller's buffers:
// the first collection goes to non_deleted_ids, the second to deleted_ids.
// No bounds are checked here, so each buffer must hold at least as many
// entries as the corresponding collection contains (see get_all_ids_sizes).
// Throws std::runtime_error if the index has not been initialised.
void get_all_ids(hnswlib::labeltype *non_deleted_ids, hnswlib::labeltype *deleted_ids)
{
    if (!index_inited)
    {
        throw std::runtime_error("Index not inited");
    }
    const auto res = appr_alg->getAllLabels();
    // Index with size_t: size() is unsigned, and an int counter would both
    // trigger a signed/unsigned comparison and overflow on very large indexes.
    for (size_t i = 0; i < res.first.size(); i++)
    {
        non_deleted_ids[i] = res.first[i];
    }
    for (size_t i = 0; i < res.second.size(); i++)
    {
        deleted_ids[i] = res.second[i];
    }
}

void mark_deleted(const hnswlib::labeltype id)
{
if (!index_inited)
Expand Down Expand Up @@ -307,6 +335,34 @@ extern "C"
last_error.clear();
}

// C entry point for Index::get_all_ids_sizes. Exceptions never cross the
// boundary: on failure the message is stored in last_error, on success
// last_error is cleared.
void get_all_ids_sizes(Index<float> *index, size_t *ids_sizes)
{
    try
    {
        index->get_all_ids_sizes(ids_sizes);
        // Only reached when the call above did not throw.
        last_error.clear();
    }
    catch (const std::exception &err)
    {
        last_error = err.what();
    }
}

// C entry point for Index::get_all_ids. Exceptions never cross the
// boundary: on failure the message is stored in last_error, on success
// last_error is cleared.
void get_all_ids(Index<float> *index, hnswlib::labeltype *non_deleted_ids, hnswlib::labeltype *deleted_ids)
{
    try
    {
        index->get_all_ids(non_deleted_ids, deleted_ids);
        // Only reached when the call above did not throw.
        last_error.clear();
    }
    catch (const std::exception &err)
    {
        last_error = err.what();
    }
}

// Can throw std::exception
void mark_deleted(Index<float> *index, const hnswlib::labeltype id)
{
Expand Down
24 changes: 24 additions & 0 deletions rust/index/src/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,28 @@ impl Index<HnswIndexConfig> for HnswIndex {
Ok(Some(data))
}
}

/// Queries the native index for its two label counts.
///
/// The result always has exactly two entries, written by the FFI call into
/// slots 0 and 1; `get_all_ids` uses them as the lengths of the non-deleted
/// and deleted id lists respectively.
///
/// # Errors
/// Returns whatever error `read_and_return_hnsw_error` reports after the call.
fn get_all_ids_sizes(&self) -> Result<Vec<usize>, Box<dyn ChromaError>> {
    let mut counts = [0usize; 2];
    // SAFETY: `counts` provides the two writable `usize` slots the native
    // function fills. Assumes `self.ffi_ptr` is a live index handle.
    unsafe { get_all_ids_sizes(self.ffi_ptr, counts.as_mut_ptr()) };
    read_and_return_hnsw_error(self.ffi_ptr)?;
    Ok(counts.to_vec())
}

/// Fetches every id known to the native index as `(non_deleted, deleted)`.
///
/// Buffer lengths come from a preceding `get_all_ids_sizes` call; the native
/// side then fills both buffers in one FFI call.
///
/// # Errors
/// Propagates a failure from `get_all_ids_sizes`, or whatever error
/// `read_and_return_hnsw_error` reports after the fill.
fn get_all_ids(&self) -> Result<(Vec<usize>, Vec<usize>), Box<dyn ChromaError>> {
    let counts = self.get_all_ids_sizes()?;
    let (live_len, deleted_len) = (counts[0], counts[1]);
    let mut live = vec![0usize; live_len];
    let mut deleted = vec![0usize; deleted_len];
    // SAFETY: both buffers are sized from the counts the index just reported.
    // NOTE(review): this presumes the index is not mutated between the size
    // query and this call; otherwise the native side could write past the
    // buffers. Confirm callers serialise access.
    unsafe { get_all_ids(self.ffi_ptr, live.as_mut_ptr(), deleted.as_mut_ptr()) };
    read_and_return_hnsw_error(self.ffi_ptr)?;
    Ok((live, deleted))
}
}

impl PersistentIndex<HnswIndexConfig> for HnswIndex {
Expand Down Expand Up @@ -359,6 +381,8 @@ extern "C" {
fn add_item(index: *const IndexPtrFFI, data: *const f32, id: usize, replace_deleted: bool);
fn mark_deleted(index: *const IndexPtrFFI, id: usize);
fn get_item(index: *const IndexPtrFFI, id: usize, data: *mut f32);
fn get_all_ids_sizes(index: *const IndexPtrFFI, sizes: *mut usize);
fn get_all_ids(index: *const IndexPtrFFI, non_deleted_ids: *mut usize, deleted_ids: *mut usize);
fn knn_query(
index: *const IndexPtrFFI,
query_vector: *const f32,
Expand Down
Loading

0 comments on commit 1caa19f

Please sign in to comment.