From 878f91a4442d5f5b3aa7bb80c54e70b044620277 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 15 Jan 2024 17:26:40 -0800 Subject: [PATCH] [ENH] Add rust hnswlib bindings, index interface (#1516) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Fixes ChromaError to Trait Bound on std Error - New functionality - Adds the index module with traits for Index and Persistent Index types - Adds bindings to chroma-hnswlib c++, along with a rust-y interface for it. - Adds basic config injection for the index. In the future we can add dynamic/static field + watch behavior. I sketched out a plan for that while implementing this. ## Test plan *How are these changes tested?* Rudimentary unit tests. - [x] Tests pass locally with `cargo test` ## Documentation Changes None required. --- .github/workflows/chroma-worker-test.yml | 9 + Cargo.lock | 1 + rust/worker/Cargo.toml | 1 + rust/worker/bindings.cpp | 203 ++++++++++ rust/worker/build.rs | 13 + rust/worker/src/errors.rs | 4 +- rust/worker/src/index/hnsw.rs | 479 +++++++++++++++++++++++ rust/worker/src/index/mod.rs | 6 + rust/worker/src/index/types.rs | 98 +++++ rust/worker/src/index/utils.rs | 13 + rust/worker/src/lib.rs | 1 + 11 files changed, 827 insertions(+), 1 deletion(-) create mode 100644 rust/worker/bindings.cpp create mode 100644 rust/worker/src/index/hnsw.rs create mode 100644 rust/worker/src/index/mod.rs create mode 100644 rust/worker/src/index/types.rs create mode 100644 rust/worker/src/index/utils.rs diff --git a/.github/workflows/chroma-worker-test.yml b/.github/workflows/chroma-worker-test.yml index 2cfce1b6d4a..33e1012e0c8 100644 --- a/.github/workflows/chroma-worker-test.yml +++ b/.github/workflows/chroma-worker-test.yml @@ -17,11 +17,20 @@ jobs: platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} steps: + - name: Checkout chroma-hnswlib + uses: actions/checkout@v3 + with: + repository: chroma-core/hnswlib + path: hnswlib - name: Checkout uses: actions/checkout@v3 + with: + path: chroma - name: Install Protoc uses: arduino/setup-protoc@v2 - name: Build run: cargo build --verbose + working-directory: chroma - name: Test run: cargo test --verbose + working-directory: chroma diff --git a/Cargo.lock b/Cargo.lock index 8077c626d8d..a65f55f113d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1490,6 +1490,7 @@ dependencies = [ "rand", "rayon", "serde", + "tempfile", "thiserror", "tokio", "tokio-util", diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index c1c2776078a..a304c2288b4 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -19,6 +19,7 @@ num_cpus = "1.16.0" murmur3 = "0.5.2" thiserror = "1.0.50" num-bigint = "0.4.4" +tempfile = "3.8.1" [build-dependencies] tonic-build = "0.10" diff --git a/rust/worker/bindings.cpp b/rust/worker/bindings.cpp new file mode 100644 index 00000000000..982d14dd5d8 --- /dev/null +++ b/rust/worker/bindings.cpp @@ -0,0 +1,203 @@ +// Assumes that chroma-hnswlib is checked out at the same level as chroma +#include "../../../hnswlib/hnswlib/hnswlib.h" + +template +class Index +{ +public: + std::string space_name; + int dim; + size_t seed; + + bool normalize; + bool index_inited; + + hnswlib::HierarchicalNSW *appr_alg; + hnswlib::SpaceInterface *l2space; + + Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) + { + if (space_name == "l2") + { + l2space = new hnswlib::L2Space(dim); + normalize = false; + } + if (space_name == "ip") + { + l2space = new hnswlib::InnerProductSpace(dim); + // For IP, we expect the vectors to be normalized + normalize = false; + } + if (space_name == "cosine") + { + l2space = new hnswlib::InnerProductSpace(dim); + normalize = true; + } + appr_alg = NULL; + index_inited = false; + } + + ~Index() + { + delete l2space; + if (appr_alg) + { + delete appr_alg; + } + } + + void init_index(const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const std::string &persistence_location) + { + if (index_inited) + { + std::runtime_error("Index already inited"); + } + appr_alg = new hnswlib::HierarchicalNSW(l2space, max_elements, M, ef_construction, random_seed, allow_replace_deleted, normalize, is_persistent_index, persistence_location); + appr_alg->ef_ = 10; // This is a default value for ef_ + index_inited = true; + } + + void load_index(const std::string &path_to_index, const bool allow_replace_deleted, const bool is_persistent_index) + { + if (index_inited) + { + std::runtime_error("Index already inited"); + } + appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, 0, allow_replace_deleted, normalize, is_persistent_index); + index_inited = true; + } + + void persist_dirty() + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + appr_alg->persistDirty(); + } + + void add_item(const data_t *data, const hnswlib::labeltype id, const bool replace_deleted = false) + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + appr_alg->addPoint(data, id); + } + + void get_item(const hnswlib::labeltype id, data_t *data) + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + std::vector ret_data = appr_alg->template getDataByLabel(id); // This checks if id is deleted + for (int i = 0; i < dim; i++) + { + data[i] = ret_data[i]; + } + } + + int mark_deleted(const hnswlib::labeltype id) + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + appr_alg->markDelete(id); + return 0; + } + + void knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance) + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + std::priority_queue> res = appr_alg->searchKnn(query_vector, k); + if (res.size() < k) + { + // TODO: This is ok and we should return < K results, but for maintining compatibility with the old API we throw an error for now + std::runtime_error("Not enough results"); + } + int total_results = std::min(res.size(), k); + for (int i = total_results - 1; i >= 0; i--) + { + std::pair res_i = res.top(); + ids[i] = res_i.second; + distance[i] = res_i.first; + res.pop(); + } + } + + int get_ef() + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + return appr_alg->ef_; + } + + void set_ef(const size_t ef) + { + if (!index_inited) + { + std::runtime_error("Index not inited"); + } + appr_alg->ef_ = ef; + } +}; + +extern "C" +{ + Index *create_index(const char *space_name, const int dim) + { + return new Index(space_name, dim); + } + + void init_index(Index *index, const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const char *persistence_location) + { + index->init_index(max_elements, M, ef_construction, random_seed, allow_replace_deleted, is_persistent_index, persistence_location); + } + + void load_index(Index *index, const char *path_to_index, const bool allow_replace_deleted, const bool is_persistent_index) + { + index->load_index(path_to_index, allow_replace_deleted, is_persistent_index); + } + + void persist_dirty(Index *index) + { + index->persist_dirty(); + } + + void add_item(Index *index, const float *data, const hnswlib::labeltype id, const bool replace_deleted) + { + index->add_item(data, id); + } + + void get_item(Index *index, const hnswlib::labeltype id, float *data) + { + index->get_item(id, data); + } + + int mark_deleted(Index *index, const hnswlib::labeltype id) + { + return index->mark_deleted(id); + } + + void knn_query(Index *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance) + { + index->knn_query(query_vector, k, ids, distance); + } + + int get_ef(Index *index) + { + return index->appr_alg->ef_; + } + + void set_ef(Index *index, const size_t ef) + { + index->set_ef(ef); + } +} diff --git a/rust/worker/build.rs b/rust/worker/build.rs index 78b226a0e0c..315f75d381b 100644 --- a/rust/worker/build.rs +++ b/rust/worker/build.rs @@ -1,4 +1,5 @@ fn main() -> Result<(), Box> { + // Compile the protobuf files in the chromadb proto directory. tonic_build::configure().compile( &[ "../../idl/chromadb/proto/chroma.proto", @@ -6,5 +7,17 @@ fn main() -> Result<(), Box> { ], &["../../idl/"], )?; + + // Compile the hnswlib bindings. + cc::Build::new() + .cpp(true) + .file("bindings.cpp") + .flag("-std=c++11") + .flag("-Ofast") + .flag("-DHAVE_CXX0X") + .flag("-fpic") + .flag("-ftree-vectorize") + .compile("bindings"); + Ok(()) } diff --git a/rust/worker/src/errors.rs b/rust/worker/src/errors.rs index 5ae2b067707..c28d39ba9b7 100644 --- a/rust/worker/src/errors.rs +++ b/rust/worker/src/errors.rs @@ -2,6 +2,8 @@ // gRPC spec. https://grpc.github.io/grpc/core/md_doc_statuscodes.html // Custom errors can use these codes in order to allow for generic handling +use std::error::Error; + pub(crate) enum ErrorCodes { // OK is returned on success, we use "Success" since Ok is a keyword in Rust. Success = 0, @@ -39,6 +41,6 @@ pub(crate) enum ErrorCodes { DataLoss = 15, } -pub(crate) trait ChromaError { +pub(crate) trait ChromaError: Error { fn code(&self) -> ErrorCodes; } diff --git a/rust/worker/src/index/hnsw.rs b/rust/worker/src/index/hnsw.rs new file mode 100644 index 00000000000..3046b19d645 --- /dev/null +++ b/rust/worker/src/index/hnsw.rs @@ -0,0 +1,479 @@ +use std::ffi::CString; +use std::ffi::{c_char, c_int}; + +use crate::errors::{ChromaError, ErrorCodes}; + +use super::{Index, IndexConfig, PersistentIndex}; +use thiserror::Error; + +// https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs +#[repr(C)] +struct IndexPtrFFI { + _data: [u8; 0], + _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +// TODO: Make this config: +// - Watchable - for dynamic updates +// - Have a notion of static vs dynamic config +// - Have a notion of default config +// - HNSWIndex should store a ref to the config so it can look up the config values. +// deferring this for a config pass +#[derive(Clone, Debug)] +pub(crate) struct HnswIndexConfig { + pub(crate) max_elements: usize, + pub(crate) m: usize, + pub(crate) ef_construction: usize, + pub(crate) ef_search: usize, + pub(crate) random_seed: usize, + pub(crate) persist_path: String, +} + +#[repr(C)] +/// The HnswIndex struct. +/// # Description +/// This struct wraps a pointer to the C++ HnswIndex class and presents a safe Rust interface. +/// # Notes +/// This struct is not thread safe for concurrent reads and writes. Callers should +/// synchronize access to the index between reads and writes. +pub(crate) struct HnswIndex { + ffi_ptr: *const IndexPtrFFI, + dimensionality: i32, +} + +// Make index sync, we should wrap index so that it is sync in the way we expect but for now this implements the trait +unsafe impl Sync for HnswIndex {} +unsafe impl Send for HnswIndex {} + +#[derive(Error, Debug)] + +pub(crate) enum HnswIndexInitError { + #[error("No config provided")] + NoConfigProvided, + #[error("Invalid distance function `{0}`")] + InvalidDistanceFunction(String), + #[error("Invalid path `{0}`. Are you sure the path exists?")] + InvalidPath(String), +} + +impl ChromaError for HnswIndexInitError { + fn code(&self) -> ErrorCodes { + crate::errors::ErrorCodes::InvalidArgument + } +} + +impl Index for HnswIndex { + fn init( + index_config: &IndexConfig, + hnsw_config: Option<&HnswIndexConfig>, + ) -> Result> { + match hnsw_config { + None => return Err(Box::new(HnswIndexInitError::NoConfigProvided)), + Some(config) => { + let distance_function_string: String = + index_config.distance_function.clone().into(); + + let space_name = match CString::new(distance_function_string) { + Ok(space_name) => space_name, + Err(e) => { + return Err(Box::new(HnswIndexInitError::InvalidDistanceFunction( + e.to_string(), + ))) + } + }; + + let ffi_ptr = + unsafe { create_index(space_name.as_ptr(), index_config.dimensionality) }; + + let path = match CString::new(config.persist_path.clone()) { + Ok(path) => path, + Err(e) => return Err(Box::new(HnswIndexInitError::InvalidPath(e.to_string()))), + }; + + unsafe { + init_index( + ffi_ptr, + config.max_elements, + config.m, + config.ef_construction, + config.random_seed, + true, + true, + path.as_ptr(), + ); + } + + let hnsw_index = HnswIndex { + ffi_ptr: ffi_ptr, + dimensionality: index_config.dimensionality, + }; + hnsw_index.set_ef(config.ef_search); + Ok(hnsw_index) + } + } + } + + fn add(&self, id: usize, vector: &[f32]) { + unsafe { add_item(self.ffi_ptr, vector.as_ptr(), id, false) } + } + + fn query(&self, vector: &[f32], k: usize) -> (Vec, Vec) { + let mut ids = vec![0usize; k]; + let mut distance = vec![0.0f32; k]; + unsafe { + knn_query( + self.ffi_ptr, + vector.as_ptr(), + k, + ids.as_mut_ptr(), + distance.as_mut_ptr(), + ); + } + return (ids, distance); + } + + fn get(&self, id: usize) -> Option> { + unsafe { + let mut data: Vec = vec![0.0f32; self.dimensionality as usize]; + get_item(self.ffi_ptr, id, data.as_mut_ptr()); + return Some(data); + } + } +} + +impl PersistentIndex for HnswIndex { + fn save(&self) -> Result<(), Box> { + unsafe { persist_dirty(self.ffi_ptr) }; + Ok(()) + } + + fn load(path: &str, index_config: &IndexConfig) -> Result> { + let distance_function_string: String = index_config.distance_function.clone().into(); + let space_name = match CString::new(distance_function_string) { + Ok(space_name) => space_name, + Err(e) => { + return Err(Box::new(HnswIndexInitError::InvalidDistanceFunction( + e.to_string(), + ))) + } + }; + let ffi_ptr = unsafe { create_index(space_name.as_ptr(), index_config.dimensionality) }; + let path = match CString::new(path.to_string()) { + Ok(path) => path, + Err(e) => return Err(Box::new(HnswIndexInitError::InvalidPath(e.to_string()))), + }; + unsafe { + load_index(ffi_ptr, path.as_ptr(), true, true); + } + let hnsw_index = HnswIndex { + ffi_ptr: ffi_ptr, + dimensionality: index_config.dimensionality, + }; + Ok(hnsw_index) + } +} + +impl HnswIndex { + pub fn set_ef(&self, ef: usize) { + unsafe { set_ef(self.ffi_ptr, ef as c_int) } + } + + pub fn get_ef(&self) -> usize { + unsafe { get_ef(self.ffi_ptr) as usize } + } +} + +#[link(name = "bindings", kind = "static")] +extern "C" { + fn create_index(space_name: *const c_char, dim: c_int) -> *const IndexPtrFFI; + + fn init_index( + index: *const IndexPtrFFI, + max_elements: usize, + M: usize, + ef_construction: usize, + random_seed: usize, + allow_replace_deleted: bool, + is_persistent: bool, + path: *const c_char, + ); + + fn load_index( + index: *const IndexPtrFFI, + path: *const c_char, + allow_replace_deleted: bool, + is_persistent_index: bool, + ); + + fn persist_dirty(index: *const IndexPtrFFI); + + fn add_item(index: *const IndexPtrFFI, data: *const f32, id: usize, replace_deleted: bool); + fn get_item(index: *const IndexPtrFFI, id: usize, data: *mut f32); + fn knn_query( + index: *const IndexPtrFFI, + query_vector: *const f32, + k: usize, + ids: *mut usize, + distance: *mut f32, + ); + + fn get_ef(index: *const IndexPtrFFI) -> c_int; + fn set_ef(index: *const IndexPtrFFI, ef: c_int); + +} + +#[cfg(test)] +pub mod test { + use super::*; + + use crate::index::types::DistanceFunction; + use crate::index::utils; + use rand::Rng; + use rayon::prelude::*; + use rayon::ThreadPoolBuilder; + use tempfile::tempdir; + + #[test] + fn it_initializes_and_can_set_get_ef() { + let n = 1000; + let d: usize = 960; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path().to_str().unwrap().to_string(); + let distance_function = DistanceFunction::Euclidean; + let mut index = HnswIndex::init( + &IndexConfig { + dimensionality: d as i32, + distance_function: distance_function, + }, + Some(&HnswIndexConfig { + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 10, + random_seed: 0, + persist_path: persist_path, + }), + ); + match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => { + assert_eq!(index.get_ef(), 10); + index.set_ef(100); + assert_eq!(index.get_ef(), 100); + } + } + } + + #[test] + fn it_can_add_parallel() { + let n = 10; + let d: usize = 960; + let distance_function = DistanceFunction::InnerProduct; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path().to_str().unwrap().to_string(); + let index = HnswIndex::init( + &IndexConfig { + dimensionality: d as i32, + distance_function: distance_function, + }, + Some(&HnswIndexConfig { + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 100, + random_seed: 0, + persist_path: persist_path, + }), + ); + + let index = match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => index, + }; + + let ids: Vec = (0..n).collect(); + + // Add data in parallel, using global pool for testing + ThreadPoolBuilder::new() + .num_threads(12) + .build_global() + .unwrap(); + + let mut rng: rand::prelude::ThreadRng = rand::thread_rng(); + let mut datas = Vec::new(); + for i in 0..n { + let mut data: Vec = Vec::new(); + for i in 0..960 { + data.push(rng.gen()); + } + datas.push(data); + } + + (0..n).into_par_iter().for_each(|i| { + let data = &datas[i]; + index.add(ids[i], data); + }); + + // Get the data and check it + let mut i = 0; + for id in ids { + let actual_data = index.get(id); + match actual_data { + None => panic!("No data found for id: {}", id), + Some(actual_data) => { + assert_eq!(actual_data.len(), d); + for j in 0..d { + // Floating point epsilon comparison + assert!((actual_data[j] - datas[i][j]).abs() < 0.00001); + } + } + } + i += 1; + } + } + + #[test] + fn it_can_add_and_basic_query() { + let n = 1; + let d: usize = 960; + let distance_function = DistanceFunction::Euclidean; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path().to_str().unwrap().to_string(); + let index = HnswIndex::init( + &IndexConfig { + dimensionality: d as i32, + distance_function: distance_function, + }, + Some(&HnswIndexConfig { + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 100, + random_seed: 0, + persist_path: persist_path, + }), + ); + + let index = match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => index, + }; + assert_eq!(index.get_ef(), 100); + + let data: Vec = utils::generate_random_data(n, d); + let ids: Vec = (0..n).collect(); + + (0..n).into_iter().for_each(|i| { + let data = &data[i * d..(i + 1) * d]; + index.add(ids[i], data); + }); + + // Get the data and check it + let mut i = 0; + for id in ids { + let actual_data = index.get(id); + match actual_data { + None => panic!("No data found for id: {}", id), + Some(actual_data) => { + assert_eq!(actual_data.len(), d); + for j in 0..d { + // Floating point epsilon comparison + assert!((actual_data[j] - data[i * d + j]).abs() < 0.00001); + } + } + } + i += 1; + } + + // Query the data + let query = &data[0..d]; + let (ids, distances) = index.query(query, 1); + assert_eq!(ids.len(), 1); + assert_eq!(distances.len(), 1); + assert_eq!(ids[0], 0); + assert_eq!(distances[0], 0.0); + } + + #[test] + fn it_can_persist_and_load() { + let n = 1000; + let d: usize = 960; + let distance_function = DistanceFunction::Euclidean; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path().to_str().unwrap().to_string(); + let index = HnswIndex::init( + &IndexConfig { + dimensionality: d as i32, + distance_function: distance_function.clone(), + }, + Some(&HnswIndexConfig { + max_elements: n, + m: 32, + ef_construction: 100, + ef_search: 100, + random_seed: 0, + persist_path: persist_path.clone(), + }), + ); + + let index = match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => index, + }; + + let data: Vec = utils::generate_random_data(n, d); + let ids: Vec = (0..n).collect(); + + (0..n).into_iter().for_each(|i| { + let data = &data[i * d..(i + 1) * d]; + index.add(ids[i], data); + }); + + // Persist the index + let res = index.save(); + match res { + Err(e) => panic!("Error saving index: {}", e), + Ok(_) => {} + } + + // Load the index + let index = HnswIndex::load( + &persist_path, + &IndexConfig { + dimensionality: d as i32, + distance_function: distance_function, + }, + ); + + let index = match index { + Err(e) => panic!("Error loading index: {}", e), + Ok(index) => index, + }; + // TODO: This should be set by the load + index.set_ef(100); + + // Query the data + let query = &data[0..d]; + let (ids, distances) = index.query(query, 1); + assert_eq!(ids.len(), 1); + assert_eq!(distances.len(), 1); + assert_eq!(ids[0], 0); + assert_eq!(distances[0], 0.0); + + // Get the data and check it + let mut i = 0; + for id in ids { + let actual_data = index.get(id); + match actual_data { + None => panic!("No data found for id: {}", id), + Some(actual_data) => { + assert_eq!(actual_data.len(), d); + for j in 0..d { + assert_eq!(actual_data[j], data[i * d + j]); + } + } + } + i += 1; + } + } +} diff --git a/rust/worker/src/index/mod.rs b/rust/worker/src/index/mod.rs new file mode 100644 index 00000000000..00738758407 --- /dev/null +++ b/rust/worker/src/index/mod.rs @@ -0,0 +1,6 @@ +mod hnsw; +mod types; +mod utils; + +// Re-export types +pub(crate) use types::*; diff --git a/rust/worker/src/index/types.rs b/rust/worker/src/index/types.rs new file mode 100644 index 00000000000..953e863b2a4 --- /dev/null +++ b/rust/worker/src/index/types.rs @@ -0,0 +1,98 @@ +use crate::errors::{ChromaError, ErrorCodes}; +use thiserror::Error; + +#[derive(Clone, Debug)] +pub(crate) struct IndexConfig { + pub(crate) dimensionality: i32, + pub(crate) distance_function: DistanceFunction, +} + +/// The index trait. +/// # Description +/// This trait defines the interface for a KNN index. +/// # Methods +/// - `init` - Initialize the index with a given dimension and distance function. +/// - `add` - Add a vector to the index. +/// - `query` - Query the index for the K nearest neighbors of a given vector. +pub(crate) trait Index { + fn init( + index_config: &IndexConfig, + custom_config: Option<&C>, + ) -> Result> + where + Self: Sized; + fn add(&self, id: usize, vector: &[f32]); + fn query(&self, vector: &[f32], k: usize) -> (Vec, Vec); + fn get(&self, id: usize) -> Option>; +} + +/// The persistent index trait. +/// # Description +/// This trait defines the interface for a persistent KNN index. +/// # Methods +/// - `save` - Save the index to a given path. Configuration of the destination is up to the implementation. +/// - `load` - Load the index from a given path. +/// # Notes +/// This defines a rudimentary interface for saving and loading indices. +/// TODO: Right now load() takes IndexConfig because we don't implement save/load of the config. +pub(crate) trait PersistentIndex: Index { + fn save(&self) -> Result<(), Box>; + fn load(path: &str, index_config: &IndexConfig) -> Result> + where + Self: Sized; +} + +/// The distance function enum. +/// # Description +/// This enum defines the distance functions supported by indices in Chroma. +/// # Variants +/// - `Euclidean` - The Euclidean or l2 norm. +/// - `Cosine` - The cosine distance. Specifically, 1 - cosine. +/// - `InnerProduct` - The inner product. Specifically, 1 - inner product. +/// # Notes +/// See https://docs.trychroma.com/usage-guide#changing-the-distance-function +#[derive(Clone, Debug)] +pub(crate) enum DistanceFunction { + Euclidean, + Cosine, + InnerProduct, +} + +#[derive(Error, Debug)] +pub(crate) enum DistanceFunctionError { + #[error("Invalid distance function `{0}`")] + InvalidDistanceFunction(String), +} + +impl ChromaError for DistanceFunctionError { + fn code(&self) -> ErrorCodes { + match self { + DistanceFunctionError::InvalidDistanceFunction(_) => ErrorCodes::InvalidArgument, + } + } +} + +impl TryFrom<&str> for DistanceFunction { + type Error = DistanceFunctionError; + + fn try_from(value: &str) -> Result { + match value { + "l2" => Ok(DistanceFunction::Euclidean), + "cosine" => Ok(DistanceFunction::Cosine), + "ip" => Ok(DistanceFunction::InnerProduct), + _ => Err(DistanceFunctionError::InvalidDistanceFunction( + value.to_string(), + )), + } + } +} + +impl Into for DistanceFunction { + fn into(self) -> String { + match self { + DistanceFunction::Euclidean => "l2".to_string(), + DistanceFunction::Cosine => "cosine".to_string(), + DistanceFunction::InnerProduct => "ip".to_string(), + } + } +} diff --git a/rust/worker/src/index/utils.rs b/rust/worker/src/index/utils.rs new file mode 100644 index 00000000000..35d27a76e84 --- /dev/null +++ b/rust/worker/src/index/utils.rs @@ -0,0 +1,13 @@ +use rand::Rng; + +pub(super) fn generate_random_data(n: usize, d: usize) -> Vec { + let mut rng: rand::prelude::ThreadRng = rand::thread_rng(); + let mut data = vec![0.0f32; n * d]; + // Generate random data + for i in 0..n { + for j in 0..d { + data[i * d + j] = rng.gen(); + } + } + return data; +} diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index d48649febd1..7bb37357dc4 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,6 +1,7 @@ mod assignment; mod config; mod errors; +mod index; mod types; mod chroma_proto {