From 00c5f8c5432440fcfa97505aec49c7f297002fa2 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Thu, 31 Oct 2024 15:02:15 -0700 Subject: [PATCH] Cleanup hnsw provider to not know about segments --- rust/index/src/hnsw.rs | 124 +++-------- rust/index/src/hnsw_provider.rs | 123 ++++++----- rust/index/src/types.rs | 23 +- rust/types/src/metadata.rs | 17 ++ .../src/execution/orchestration/hnsw.rs | 27 +-- .../src/segment/distributed_hnsw_segment.rs | 202 ++++++++++++++---- 6 files changed, 280 insertions(+), 236 deletions(-) diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index d9a2aefc659..0fc48720647 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -3,15 +3,16 @@ use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Metadata, MetadataValue, MetadataValueConversionError, Segment}; use std::ffi::CString; use std::ffi::{c_char, c_int}; +use std::path::Path; use std::str::Utf8Error; use thiserror::Error; use tracing::instrument; use uuid::Uuid; -const DEFAULT_MAX_ELEMENTS: usize = 10000; -const DEFAULT_HNSW_M: usize = 16; -const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100; -const DEFAULT_HNSW_EF_SEARCH: usize = 10; +pub const DEFAULT_MAX_ELEMENTS: usize = 10000; +pub const DEFAULT_HNSW_M: usize = 16; +pub const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100; +pub const DEFAULT_HNSW_EF_SEARCH: usize = 10; // https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs #[repr(C)] @@ -51,10 +52,23 @@ impl ChromaError for HnswIndexFromSegmentError { } impl HnswIndexConfig { - pub fn from_segment( - segment: &Segment, - persist_path: &std::path::Path, - ) -> Result> { + pub fn new_default() -> Self { + HnswIndexConfig { + max_elements: DEFAULT_MAX_ELEMENTS, + m: DEFAULT_HNSW_M, + ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, + ef_search: DEFAULT_HNSW_EF_SEARCH, + random_seed: 0, + persist_path: "".to_string(), + } + } + + pub fn new( + m: usize, + ef_construction: usize, + ef_search: usize, + persist_path: &Path, + ) -> Result> { let persist_path = match persist_path.to_str() { Some(persist_path) => persist_path, None => { @@ -63,53 +77,11 @@ impl HnswIndexConfig { ))) } }; - let metadata = match &segment.metadata { - Some(metadata) => metadata, - None => { - // TODO: This should error, but the configuration is not stored correctly - // after the configuration is refactored to be always stored and doesn't rely on defaults we can fix this - return Ok(HnswIndexConfig { - max_elements: DEFAULT_MAX_ELEMENTS, - m: DEFAULT_HNSW_M, - ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, - ef_search: DEFAULT_HNSW_EF_SEARCH, - random_seed: 0, - persist_path: persist_path.to_string(), - }); - } - }; - - fn get_metadata_value_as<'a, T>( - metadata: &'a Metadata, - key: &str, - ) -> Result> - where - T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>, - { - let res = match metadata.get(key) { - Some(value) => T::try_from(value), - None => { - return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( - key.to_string(), - ))) - } - }; - match res { - Ok(value) => Ok(value), - Err(e) => Err(Box::new(HnswIndexFromSegmentError::MetadataValueError(e))), - } - } - - let m = get_metadata_value_as::(metadata, "hnsw:M").unwrap_or(DEFAULT_HNSW_M as i64); - let ef_construction = get_metadata_value_as::(metadata, "hnsw:construction_ef") - .unwrap_or(DEFAULT_HNSW_EF_CONSTRUCTION as i64); - let ef_search = get_metadata_value_as::(metadata, "hnsw:search_ef") - .unwrap_or(DEFAULT_HNSW_EF_SEARCH as i64); Ok(HnswIndexConfig { max_elements: DEFAULT_MAX_ELEMENTS, - m: m as usize, - ef_construction: ef_construction as usize, - ef_search: ef_search as usize, + m, + ef_construction, + ef_search, random_seed: 0, persist_path: persist_path.to_string(), }) @@ -824,52 +796,6 @@ pub mod test { }); } - #[test] - fn parameter_defaults() { - let segment = Segment { - id: Uuid::new_v4(), - r#type: chroma_types::SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(HashMap::new()), - collection: Uuid::new_v4(), - file_path: HashMap::new(), - }; - - let persist_path = tempdir().unwrap().path().to_owned(); - let config = HnswIndexConfig::from_segment(&segment, &persist_path) - .expect("Failed to create config from segment"); - - assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); - assert_eq!(config.m, DEFAULT_HNSW_M); - assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); - assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); - assert_eq!(config.random_seed, 0); - assert_eq!(config.persist_path, persist_path.to_str().unwrap()); - - // Try partial metadata - let mut metadata = HashMap::new(); - metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10_i64)); - - let segment = Segment { - id: Uuid::new_v4(), - r#type: chroma_types::SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(metadata), - collection: Uuid::new_v4(), - file_path: HashMap::new(), - }; - - let config = HnswIndexConfig::from_segment(&segment, &persist_path) - .expect("Failed to create config from segment"); - - assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); - assert_eq!(config.m, 10); - assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); - assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); - assert_eq!(config.random_seed, 0); - assert_eq!(config.persist_path, persist_path.to_str().unwrap()); - } - #[test] fn it_can_catch_error() { let n = 10; diff --git a/rust/index/src/hnsw_provider.rs b/rust/index/src/hnsw_provider.rs index e64c24875fa..2291fdd1823 100644 --- a/rust/index/src/hnsw_provider.rs +++ b/rust/index/src/hnsw_provider.rs @@ -9,10 +9,10 @@ use super::{ use async_trait::async_trait; use chroma_cache::Cache; use chroma_config::Configurable; +use chroma_distance::DistanceFunction; use chroma_error::ChromaError; use chroma_error::ErrorCodes; use chroma_storage::Storage; -use chroma_types::Segment; use parking_lot::RwLock; use std::fmt::Debug; use std::path::Path; @@ -32,6 +32,12 @@ const FILES: [&str; 4] = [ "link_lists.bin", ]; +pub type HnswIndexParams = ( + usize, /* m */ + usize, /* ef_construction */ + usize, /* ef_search */ +); + // The key of the cache is the collection id and the value is // the HNSW index for that collection. This restricts the cache to // contain atmost one index per collection. Ideally, we would like @@ -143,8 +149,9 @@ impl HnswIndexProvider { pub async fn fork( &self, source_id: &Uuid, - segment: &Segment, + collection_id: &Uuid, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let new_id = Uuid::new_v4(); let new_storage_path = self.temporary_storage_path.join(new_id.to_string()); @@ -167,22 +174,7 @@ impl HnswIndexProvider { } } - let index_config = IndexConfig::from_segment(segment, dimensionality); - - let index_config = match index_config { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderForkError::IndexConfigError(*e))); - } - }; - - let hnsw_config = HnswIndexConfig::from_segment(segment, &new_storage_path); - match hnsw_config { - Ok(hnsw_config) => hnsw_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderForkError::HnswConfigError(*e))); - } - }; + let index_config = IndexConfig::new(dimensionality, distance_function); let storage_path_str = match new_storage_path.to_str() { Some(storage_path_str) => storage_path_str, @@ -196,13 +188,15 @@ impl HnswIndexProvider { match HnswIndex::load(storage_path_str, &index_config, new_id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(&new_id, &segment.collection).await { + match self.get(&new_id, collection_id).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache + .insert(collection_id.clone(), index.clone()) + .await; Ok(index) } } @@ -281,8 +275,9 @@ impl HnswIndexProvider { pub async fn open( &self, id: &Uuid, - segment: &Segment, + collection_id: &Uuid, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let index_storage_path = self.temporary_storage_path.join(id.to_string()); @@ -305,17 +300,7 @@ impl HnswIndexProvider { } // Thread safe. - let index_config = IndexConfig::from_segment(segment, dimensionality); - let index_config = match index_config { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderOpenError::IndexConfigError(*e))); - } - }; - - // Thread safe. - let _hnsw_config = HnswIndexConfig::from_segment(segment, &index_storage_path) - .map_err(|e| Box::new(HnswIndexProviderOpenError::HnswConfigError(*e)))?; + let index_config = IndexConfig::new(dimensionality, distance_function); let index_storage_path_str = match index_storage_path.to_str() { Some(index_storage_path_str) => index_storage_path_str, @@ -329,13 +314,13 @@ impl HnswIndexProvider { match HnswIndex::load(index_storage_path_str, &index_config, *id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(id, &segment.collection).await { + match self.get(id, collection_id).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache.insert(*collection_id, index.clone()).await; Ok(index) } } @@ -356,9 +341,11 @@ impl HnswIndexProvider { // A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id pub async fn create( &self, - // TODO: This should not take Segment. The index layer should not know about the segment concept - segment: &Segment, + collection_id: &Uuid, + hnsw_params: HnswIndexParams, + persist_path: &std::path::Path, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let id = Uuid::new_v4(); let index_storage_path = self.temporary_storage_path.join(id.to_string()); @@ -370,31 +357,30 @@ impl HnswIndexProvider { } } - let index_config = match IndexConfig::from_segment(segment, dimensionality) { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderCreateError::IndexConfigError(*e))); - } - }; + let index_config = IndexConfig::new(dimensionality, distance_function); + + let hnsw_config = + match HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, persist_path) { + Ok(hnsw_config) => hnsw_config, + Err(e) => { + return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e))); + } + }; - let hnsw_config = match HnswIndexConfig::from_segment(segment, &index_storage_path) { - Ok(hnsw_config) => hnsw_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e))); - } - }; // HnswIndex init is not thread safe. We should not call it from multiple threads let index = HnswIndex::init(&index_config, Some(&hnsw_config), id) .map_err(|e| Box::new(HnswIndexProviderCreateError::IndexInitError(e)))?; let _guard = self.write_mutex.lock().await; - match self.get(&id, &segment.collection).await { + match self.get(&id, &collection_id).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache + .insert(collection_id.clone(), index.clone()) + .await; Ok(index) } } @@ -586,11 +572,11 @@ pub enum HnswIndexProviderFileError { #[cfg(test)] mod tests { + use crate::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; + use super::*; use chroma_cache::new_non_persistent_cache_for_test; use chroma_storage::local::LocalStorage; - use chroma_types::SegmentType; - use std::collections::HashMap; #[tokio::test] async fn test_fork() { @@ -604,21 +590,34 @@ mod tests { let cache = new_non_persistent_cache_for_test(); let (_tx, rx) = tokio::sync::mpsc::unbounded_channel(); let provider = HnswIndexProvider::new(storage, hnsw_tmp_path, cache, rx); - let segment = Segment { - id: Uuid::new_v4(), - r#type: SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - collection: Uuid::new_v4(), - metadata: None, - file_path: HashMap::new(), - }; + let collection_id = Uuid::new_v4(); let dimensionality = 128; - let created_index = provider.create(&segment, dimensionality).await.unwrap(); + let hnsw_params = ( + DEFAULT_HNSW_M, + DEFAULT_HNSW_EF_CONSTRUCTION, + DEFAULT_HNSW_EF_SEARCH, + ); + let distance_function = DistanceFunction::Euclidean; + let created_index = provider + .create( + &collection_id, + hnsw_params, + &provider.temporary_storage_path, + dimensionality, + distance_function.clone(), + ) + .await + .unwrap(); let created_index_id = created_index.inner.read().id; let forked_index = provider - .fork(&created_index_id, &segment, dimensionality) + .fork( + &created_index_id, + &collection_id, + dimensionality, + distance_function, + ) .await .unwrap(); let forked_index_id = forked_index.inner.read().id; diff --git a/rust/index/src/types.rs b/rust/index/src/types.rs index e6eae433477..be5962703f3 100644 --- a/rust/index/src/types.rs +++ b/rust/index/src/types.rs @@ -25,25 +25,10 @@ impl ChromaError for IndexConfigFromSegmentError { } impl IndexConfig { - pub fn from_segment( - segment: &Segment, - dimensionality: i32, - ) -> Result> { - let space = match segment.metadata { - Some(ref metadata) => match metadata.get("hnsw:space") { - Some(MetadataValue::Str(space)) => space, - _ => "l2", - }, - None => "l2", - }; - match DistanceFunction::try_from(space) { - Ok(distance_function) => Ok(IndexConfig { - dimensionality, - distance_function, - }), - Err(e) => Err(Box::new( - IndexConfigFromSegmentError::InvalidDistanceFunction(e), - )), + pub fn new(dimensionality: i32, distance_function: DistanceFunction) -> Self { + IndexConfig { + dimensionality, + distance_function, } } } diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index dc5293d22bf..1401228645e 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -278,6 +278,23 @@ Metadata pub type Metadata = HashMap; pub type DeletedMetadata = HashSet; +pub fn get_metadata_value_as<'a, T>( + metadata: &'a Metadata, + key: &str, +) -> Result> +where + T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>, +{ + let res = match metadata.get(key) { + Some(value) => T::try_from(value), + None => return Err(Box::new(MetadataValueConversionError::InvalidValue)), + }; + match res { + Ok(value) => Ok(value), + Err(e) => Err(Box::new(MetadataValueConversionError::InvalidValue)), + } +} + impl TryFrom for Metadata { type Error = MetadataValueConversionError; diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 8b59e593290..f0021986ca4 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -26,7 +26,8 @@ use crate::execution::operators::record_segment_prefetch::{ }; use crate::log::log::PullLogsError; use crate::segment::distributed_hnsw_segment::{ - DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader, + distance_function_from_segment, DistributedHNSWSegmentFromSegmentError, + DistributedHNSWSegmentReader, }; use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}; use crate::system::{ComponentContext, ComponentHandle, System}; @@ -616,22 +617,22 @@ impl Component for HnswQueryOrchestrator { } }; - match IndexConfig::from_segment(&hnsw_segment, collection.dimension.unwrap()) { - Ok(index_config) => { - self.index_config = Some(index_config); - - // Normalize the query vectors if we are using the cosine similarity - if self.index_config.as_ref().unwrap().distance_function == DistanceFunction::Cosine - { - for query_vector in self.query_vectors.iter_mut() { - *query_vector = normalize(query_vector); - } - } - } + let distance_function = match distance_function_from_segment(&hnsw_segment) { + Ok(distance_function) => distance_function, Err(e) => { terminate_with_error(self.result_channel.take(), e, ctx); return; } + }; + self.index_config = Some(IndexConfig::new( + collection.dimension.unwrap(), + distance_function, + )); + // Normalize the query vectors if we are using the cosine similarity + if self.index_config.as_ref().unwrap().distance_function == DistanceFunction::Cosine { + for query_vector in self.query_vectors.iter_mut() { + *query_vector = normalize(query_vector); + } } self.record_segment = Some(record_segment); diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 413a9e8c591..a86d969e372 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -1,15 +1,14 @@ use super::record_segment::ApplyMaterializedLogError; use super::{SegmentFlusher, SegmentWriter}; use async_trait::async_trait; +use chroma_distance::{DistanceFunction, DistanceFunctionError}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::{ - HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, + HnswIndexParams, HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexProviderOpenError, HnswIndexRef, }; -use chroma_index::{ - HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig, IndexConfigFromSegmentError, -}; -use chroma_types::{MaterializedLogOperation, Segment}; +use chroma_index::{Index, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; +use chroma_types::{get_metadata_value_as, MaterializedLogOperation, MetadataValue, Segment}; use std::collections::HashMap; use std::fmt::Debug; use thiserror::Error; @@ -40,16 +39,14 @@ pub enum DistributedHNSWSegmentFromSegmentError { InvalidUUID, #[error("HNSW segment uninitialized")] Uninitialized, - #[error("Index configuration error")] - IndexConfigError(#[from] IndexConfigFromSegmentError), - #[error("HNSW index configuration error")] - HnswIndexConfigError(#[from] HnswIndexFromSegmentError), #[error("HNSW index provider open error")] HnswIndexProviderOpenError(#[from] HnswIndexProviderOpenError), #[error("HNSW index provider fork error")] HnswIndexProviderForkError(#[from] HnswIndexProviderForkError), #[error("HNSW index provider create error")] HnswIndexProviderCreateError(#[from] HnswIndexProviderCreateError), + #[error("Error extracting distance function")] + DistanceFunctionError(#[from] DistanceFunctionError), } impl ChromaError for DistributedHNSWSegmentFromSegmentError { @@ -58,15 +55,60 @@ impl ChromaError for DistributedHNSWSegmentFromSegmentError { DistributedHNSWSegmentFromSegmentError::NoHnswFileFound => ErrorCodes::NotFound, DistributedHNSWSegmentFromSegmentError::InvalidUUID => ErrorCodes::InvalidArgument, DistributedHNSWSegmentFromSegmentError::Uninitialized => ErrorCodes::InvalidArgument, - DistributedHNSWSegmentFromSegmentError::IndexConfigError(e) => e.code(), - DistributedHNSWSegmentFromSegmentError::HnswIndexConfigError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderOpenError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderForkError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderCreateError(e) => e.code(), + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(e) => e.code(), } } } +fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { + let metadata = match &segment.metadata { + Some(metadata) => metadata, + None => { + return ( + DEFAULT_HNSW_M, + DEFAULT_HNSW_EF_CONSTRUCTION, + DEFAULT_HNSW_EF_SEARCH, + ); + } + }; + + let m = match get_metadata_value_as::(metadata, "hnsw:M") { + Ok(m) => m as usize, + Err(_) => DEFAULT_HNSW_M, + }; + let ef_construction = match get_metadata_value_as::(metadata, "hnsw:construction_ef") { + Ok(ef_construction) => ef_construction as usize, + Err(_) => DEFAULT_HNSW_EF_CONSTRUCTION, + }; + let ef_search = match get_metadata_value_as::(metadata, "hnsw:search_ef") { + Ok(ef_search) => ef_search as usize, + Err(_) => DEFAULT_HNSW_EF_SEARCH, + }; + + (m, ef_construction, ef_search) +} + +pub fn distance_function_from_segment( + segment: &Segment, +) -> Result> { + let space = match segment.metadata { + Some(ref metadata) => match metadata.get("hnsw:space") { + Some(MetadataValue::Str(space)) => space, + _ => "l2", + }, + None => "l2", + }; + match DistanceFunction::try_from(space) { + Ok(distance_function) => Ok(distance_function), + Err(e) => Err(Box::new( + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(e), + )), + } +} + impl DistributedHNSWSegmentWriter { pub(crate) fn new( index: HnswIndexRef, @@ -86,25 +128,7 @@ impl DistributedHNSWSegmentWriter { hnsw_index_provider: HnswIndexProvider, ) -> Result, Box> { - let _index_config = match IndexConfig::from_segment(segment, dimensionality as i32) { - Ok(ic) => ic, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::IndexConfigError(*e), - )); - } - }; let persist_path = &hnsw_index_provider.temporary_storage_path; - - let _hnsw_config = match HnswIndexConfig::from_segment(segment, persist_path) { - Ok(hc) => hc, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::HnswIndexConfigError(*e), - )); - } - }; - // TODO: this is hacky, we use the presence of files to determine if we need to load or create the index // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this @@ -137,8 +161,20 @@ impl DistributedHNSWSegmentWriter { } }; + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; + let index = match hnsw_index_provider - .fork(&index_uuid, segment, dimensionality as i32) + .fork( + &index_uuid, + &segment.collection, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -155,8 +191,22 @@ impl DistributedHNSWSegmentWriter { segment.id, ))) } else { + let hnsw_params = hnsw_params_from_segment(segment); + + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; let index = match hnsw_index_provider - .create(segment, dimensionality as i32) + .create( + &segment.collection, + hnsw_params, + persist_path, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -277,17 +327,6 @@ impl DistributedHNSWSegmentReader { hnsw_index_provider: HnswIndexProvider, ) -> Result, Box> { - let index_config = IndexConfig::from_segment(segment, dimensionality as i32); - let _index_config = match index_config { - Ok(ic) => ic, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::IndexConfigError(*e), - )); - } - }; - let _persist_path = &hnsw_index_provider.temporary_storage_path; - // TODO: this is hacky, we use the presence of files to determine if we need to load or create the index // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this @@ -327,8 +366,19 @@ impl DistributedHNSWSegmentReader { { Some(index) => index, None => { + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; match hnsw_index_provider - .open(&index_uuid, segment, dimensionality as i32) + .open( + &index_uuid, + &segment.collection, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -362,3 +412,69 @@ impl DistributedHNSWSegmentReader { index.query(vector, k, allowed_ids, disallowd_ids) } } + +#[cfg(test)] +pub mod test { + use std::collections::HashMap; + + use chroma_index::{ + HnswIndexConfig, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M, + DEFAULT_MAX_ELEMENTS, + }; + use chroma_types::{MetadataValue, Segment}; + use tempfile::tempdir; + use uuid::Uuid; + + use crate::segment::distributed_hnsw_segment::hnsw_params_from_segment; + + #[test] + fn parameter_defaults() { + let persist_path = tempdir().unwrap().path().to_owned(); + + let segment = Segment { + id: Uuid::new_v4(), + r#type: chroma_types::SegmentType::HnswDistributed, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(HashMap::new()), + collection: Uuid::new_v4(), + file_path: HashMap::new(), + }; + + let hnsw_params = hnsw_params_from_segment(&segment); + let config = + HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path) + .expect("Error creating hnsw index config"); + + assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); + assert_eq!(config.m, DEFAULT_HNSW_M); + assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); + assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); + assert_eq!(config.random_seed, 0); + assert_eq!(config.persist_path, persist_path.to_str().unwrap()); + + // Try partial metadata + let mut metadata = HashMap::new(); + metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10_i64)); + + let segment = Segment { + id: Uuid::new_v4(), + r#type: chroma_types::SegmentType::HnswDistributed, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(metadata), + collection: Uuid::new_v4(), + file_path: HashMap::new(), + }; + + let hnsw_params = hnsw_params_from_segment(&segment); + let config = + HnswIndexConfig::new(hnsw_params.0, hnsw_params.1, hnsw_params.2, &persist_path) + .expect("Error creating hnsw index config"); + + assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); + assert_eq!(config.m, 10); + assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); + assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); + assert_eq!(config.random_seed, 0); + assert_eq!(config.persist_path, persist_path.to_str().unwrap()); + } +}