Skip to content

Commit

Permalink
[ENH] Add rust hnswlib bindings, index interface (#1516)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Fixes the ChromaError trait so that it requires std::error::Error as a trait bound
 - New functionality
	 - Adds the index module with traits for Index and Persistent Index types
	 - Adds bindings to the chroma-hnswlib C++ library, along with a Rust-y interface for it.
	 - Adds basic config injection for the index. In the future we can add dynamic/static field + watch behavior. I sketched out a plan for that while implementing this.

## Test plan
*How are these changes tested?*
Rudimentary unit tests.
- [x] Tests pass locally with `cargo test`

## Documentation Changes
None required.
  • Loading branch information
HammadB authored Jan 16, 2024
1 parent af37c9a commit 878f91a
Show file tree
Hide file tree
Showing 11 changed files with 827 additions and 1 deletion.
9 changes: 9 additions & 0 deletions .github/workflows/chroma-worker-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@ jobs:
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout chroma-hnswlib
uses: actions/checkout@v3
with:
repository: chroma-core/hnswlib
path: hnswlib
- name: Checkout
uses: actions/checkout@v3
with:
path: chroma
- name: Install Protoc
uses: arduino/setup-protoc@v2
- name: Build
run: cargo build --verbose
working-directory: chroma
- name: Test
run: cargo test --verbose
working-directory: chroma
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ num_cpus = "1.16.0"
murmur3 = "0.5.2"
thiserror = "1.0.50"
num-bigint = "0.4.4"
tempfile = "3.8.1"

[build-dependencies]
tonic-build = "0.10"
Expand Down
203 changes: 203 additions & 0 deletions rust/worker/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Assumes that chroma-hnswlib is checked out at the same level as chroma
#include "../../../hnswlib/hnswlib/hnswlib.h"

template <typename dist_t, typename data_t = float>
class Index
{
public:
std::string space_name;
int dim;
size_t seed;

bool normalize;
bool index_inited;

hnswlib::HierarchicalNSW<dist_t> *appr_alg;
hnswlib::SpaceInterface<float> *l2space;

Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim)
{
if (space_name == "l2")
{
l2space = new hnswlib::L2Space(dim);
normalize = false;
}
if (space_name == "ip")
{
l2space = new hnswlib::InnerProductSpace(dim);
// For IP, we expect the vectors to be normalized
normalize = false;
}
if (space_name == "cosine")
{
l2space = new hnswlib::InnerProductSpace(dim);
normalize = true;
}
appr_alg = NULL;
index_inited = false;
}

~Index()
{
delete l2space;
if (appr_alg)
{
delete appr_alg;
}
}

void init_index(const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const std::string &persistence_location)
{
if (index_inited)
{
std::runtime_error("Index already inited");
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, max_elements, M, ef_construction, random_seed, allow_replace_deleted, normalize, is_persistent_index, persistence_location);
appr_alg->ef_ = 10; // This is a default value for ef_
index_inited = true;
}

void load_index(const std::string &path_to_index, const bool allow_replace_deleted, const bool is_persistent_index)
{
if (index_inited)
{
std::runtime_error("Index already inited");
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, path_to_index, false, 0, allow_replace_deleted, normalize, is_persistent_index);
index_inited = true;
}

void persist_dirty()
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
appr_alg->persistDirty();
}

void add_item(const data_t *data, const hnswlib::labeltype id, const bool replace_deleted = false)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
appr_alg->addPoint(data, id);
}

void get_item(const hnswlib::labeltype id, data_t *data)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
std::vector<data_t> ret_data = appr_alg->template getDataByLabel<data_t>(id); // This checks if id is deleted
for (int i = 0; i < dim; i++)
{
data[i] = ret_data[i];
}
}

int mark_deleted(const hnswlib::labeltype id)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
appr_alg->markDelete(id);
return 0;
}

void knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
std::priority_queue<std::pair<dist_t, hnswlib::labeltype>> res = appr_alg->searchKnn(query_vector, k);
if (res.size() < k)
{
// TODO: This is ok and we should return < K results, but for maintining compatibility with the old API we throw an error for now
std::runtime_error("Not enough results");
}
int total_results = std::min(res.size(), k);
for (int i = total_results - 1; i >= 0; i--)
{
std::pair<dist_t, hnswlib::labeltype> res_i = res.top();
ids[i] = res_i.second;
distance[i] = res_i.first;
res.pop();
}
}

int get_ef()
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
return appr_alg->ef_;
}

void set_ef(const size_t ef)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
appr_alg->ef_ = ef;
}
};

// C ABI surface over Index<float>. Each function forwards to the member of the
// same name on `index`.
//
// NOTE(review): there is no try/catch in these wrappers, so any C++ exception
// raised by Index or hnswlib propagates out of the extern "C" function.
// NOTE(review): create_index allocates with `new`; no matching destroy
// function is defined in this block.
extern "C"
{
    // Allocates a new Index<float> for the given space name and dimension.
    Index<float> *create_index(const char *space_name, const int dim)
    {
        return new Index<float>(space_name, dim);
    }

    // Initialises `index` as a new, empty index.
    void init_index(Index<float> *index, const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const char *persistence_location)
    {
        index->init_index(max_elements, M, ef_construction, random_seed, allow_replace_deleted, is_persistent_index, persistence_location);
    }

    // Initialises `index` from the data stored at `path_to_index`.
    void load_index(Index<float> *index, const char *path_to_index, const bool allow_replace_deleted, const bool is_persistent_index)
    {
        index->load_index(path_to_index, allow_replace_deleted, is_persistent_index);
    }

    void persist_dirty(Index<float> *index)
    {
        index->persist_dirty();
    }

    // Forwards `replace_deleted` instead of dropping it at the boundary.
    void add_item(Index<float> *index, const float *data, const hnswlib::labeltype id, const bool replace_deleted)
    {
        index->add_item(data, id, replace_deleted);
    }

    // `data` must have room for index->dim floats.
    void get_item(Index<float> *index, const hnswlib::labeltype id, float *data)
    {
        index->get_item(id, data);
    }

    int mark_deleted(Index<float> *index, const hnswlib::labeltype id)
    {
        return index->mark_deleted(id);
    }

    // `ids` and `distance` must each have room for `k` entries.
    void knn_query(Index<float> *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance)
    {
        index->knn_query(query_vector, k, ids, distance);
    }

    // Goes through Index::get_ef() like every other wrapper rather than
    // reaching into appr_alg directly.
    int get_ef(Index<float> *index)
    {
        return index->get_ef();
    }

    void set_ef(Index<float> *index, const size_t ef)
    {
        index->set_ef(ef);
    }
}
13 changes: 13 additions & 0 deletions rust/worker/build.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
/// Build script: generates the gRPC/protobuf code and compiles the C++
/// hnswlib bindings into a static library named `bindings`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Declare every input explicitly. Once a build script prints any
    // `rerun-if-changed` directive, Cargo stops watching the rest of the
    // package, so bindings.cpp must be listed or edits to it will not
    // trigger a rebuild. The proto files are listed for the same reason.
    println!("cargo:rerun-if-changed=bindings.cpp");
    println!("cargo:rerun-if-changed=../../idl/chromadb/proto/chroma.proto");
    println!("cargo:rerun-if-changed=../../idl/chromadb/proto/coordinator.proto");

    // Compile the protobuf files in the chromadb proto directory.
    tonic_build::configure().compile(
        &[
            "../../idl/chromadb/proto/chroma.proto",
            "../../idl/chromadb/proto/coordinator.proto",
        ],
        &["../../idl/"],
    )?;

    // Compile the hnswlib bindings.
    cc::Build::new()
        .cpp(true)
        .file("bindings.cpp")
        .flag("-std=c++11")
        .flag("-Ofast")
        .flag("-DHAVE_CXX0X")
        .flag("-fpic")
        .flag("-ftree-vectorize")
        .compile("bindings");

    Ok(())
}
4 changes: 3 additions & 1 deletion rust/worker/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// gRPC spec. https://grpc.github.io/grpc/core/md_doc_statuscodes.html
// Custom errors can use these codes in order to allow for generic handling

use std::error::Error;

pub(crate) enum ErrorCodes {
// OK is returned on success, we use "Success" since Ok is a keyword in Rust.
Success = 0,
Expand Down Expand Up @@ -39,6 +41,6 @@ pub(crate) enum ErrorCodes {
DataLoss = 15,
}

pub(crate) trait ChromaError {
/// Crate-internal error trait. Implementors must also implement
/// `std::error::Error` and expose an `ErrorCodes` value via `code`.
pub(crate) trait ChromaError: Error {
/// Returns the `ErrorCodes` value associated with this error.
fn code(&self) -> ErrorCodes;
}
Loading

0 comments on commit 878f91a

Please sign in to comment.