diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index 9a3ee89d879..e97b04a28fe 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -1,16 +1,17 @@ use super::{Index, IndexConfig, IndexUuid, PersistentIndex}; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Metadata, MetadataValue, MetadataValueConversionError, Segment}; +use chroma_types::MetadataValueConversionError; use std::ffi::CString; use std::ffi::{c_char, c_int}; +use std::path::Path; use std::str::Utf8Error; use thiserror::Error; use tracing::instrument; -const DEFAULT_MAX_ELEMENTS: usize = 10000; -const DEFAULT_HNSW_M: usize = 16; -const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100; -const DEFAULT_HNSW_EF_SEARCH: usize = 10; +pub const DEFAULT_MAX_ELEMENTS: usize = 10000; +pub const DEFAULT_HNSW_M: usize = 16; +pub const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 100; +pub const DEFAULT_HNSW_EF_SEARCH: usize = 10; // https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs #[repr(C)] @@ -50,10 +51,12 @@ impl ChromaError for HnswIndexFromSegmentError { } impl HnswIndexConfig { - pub fn from_segment( - segment: &Segment, - persist_path: &std::path::Path, - ) -> Result> { + pub fn new( + m: usize, + ef_construction: usize, + ef_search: usize, + persist_path: &Path, + ) -> Result> { let persist_path = match persist_path.to_str() { Some(persist_path) => persist_path, None => { @@ -62,53 +65,11 @@ impl HnswIndexConfig { ))) } }; - let metadata = match &segment.metadata { - Some(metadata) => metadata, - None => { - // TODO: This should error, but the configuration is not stored correctly - // after the configuration is refactored to be always stored and doesn't rely on defaults we can fix this - return Ok(HnswIndexConfig { - max_elements: DEFAULT_MAX_ELEMENTS, - m: DEFAULT_HNSW_M, - ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, - ef_search: DEFAULT_HNSW_EF_SEARCH, - random_seed: 0, - persist_path: persist_path.to_string(), - }); - } - }; - - fn get_metadata_value_as<'a, T>( - metadata: &'a Metadata, - key: &str, - ) -> Result> - where - T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>, - { - let res = match metadata.get(key) { - Some(value) => T::try_from(value), - None => { - return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( - key.to_string(), - ))) - } - }; - match res { - Ok(value) => Ok(value), - Err(e) => Err(Box::new(HnswIndexFromSegmentError::MetadataValueError(e))), - } - } - - let m = get_metadata_value_as::(metadata, "hnsw:M").unwrap_or(DEFAULT_HNSW_M as i64); - let ef_construction = get_metadata_value_as::(metadata, "hnsw:construction_ef") - .unwrap_or(DEFAULT_HNSW_EF_CONSTRUCTION as i64); - let ef_search = get_metadata_value_as::(metadata, "hnsw:search_ef") - .unwrap_or(DEFAULT_HNSW_EF_SEARCH as i64); Ok(HnswIndexConfig { max_elements: DEFAULT_MAX_ELEMENTS, - m: m as usize, - ef_construction: ef_construction as usize, - ef_search: ef_search as usize, + m, + ef_construction, + ef_search, random_seed: 0, persist_path: persist_path.to_string(), }) @@ -428,12 +389,9 @@ pub mod test { use super::*; use crate::utils; use chroma_distance::DistanceFunction; - use chroma_types::CollectionUuid; - use chroma_types::SegmentUuid; use rand::seq::IteratorRandom; use rayon::prelude::*; use rayon::ThreadPoolBuilder; - use std::collections::HashMap; use tempfile::tempdir; use uuid::Uuid; @@ -826,52 +784,6 @@ pub mod test { }); } - #[test] - fn parameter_defaults() { - let segment = Segment { - id: SegmentUuid::new(), - r#type: chroma_types::SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(HashMap::new()), - collection: CollectionUuid(Uuid::new_v4()), - file_path: HashMap::new(), - }; - - let persist_path = tempdir().unwrap().path().to_owned(); - let config = HnswIndexConfig::from_segment(&segment, &persist_path) - .expect("Failed to create config from segment"); - - assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); - assert_eq!(config.m, DEFAULT_HNSW_M); - assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); - assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); - assert_eq!(config.random_seed, 0); - assert_eq!(config.persist_path, persist_path.to_str().unwrap()); - - // Try partial metadata - let mut metadata = HashMap::new(); - metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10_i64)); - - let segment = Segment { - id: SegmentUuid::new(), - r#type: chroma_types::SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - metadata: Some(metadata), - collection: CollectionUuid(Uuid::new_v4()), - file_path: HashMap::new(), - }; - - let config = HnswIndexConfig::from_segment(&segment, &persist_path) - .expect("Failed to create config from segment"); - - assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); - assert_eq!(config.m, 10); - assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); - assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); - assert_eq!(config.random_seed, 0); - assert_eq!(config.persist_path, persist_path.to_str().unwrap()); - } - #[test] fn it_can_catch_error() { let n = 10; diff --git a/rust/index/src/hnsw_provider.rs b/rust/index/src/hnsw_provider.rs index 63c9f252cb8..c4f4a463045 100644 --- a/rust/index/src/hnsw_provider.rs +++ b/rust/index/src/hnsw_provider.rs @@ -9,10 +9,11 @@ use super::{ use async_trait::async_trait; use chroma_cache::{Cache, Weighted}; use chroma_config::Configurable; +use chroma_distance::DistanceFunction; use chroma_error::ChromaError; use chroma_error::ErrorCodes; use chroma_storage::Storage; -use chroma_types::{CollectionUuid, Segment}; +use chroma_types::CollectionUuid; use parking_lot::RwLock; use std::fmt::Debug; use std::path::Path; @@ -32,6 +33,8 @@ const FILES: [&str; 4] = [ "link_lists.bin", ]; +type CacheKey = CollectionUuid; + // The key of the cache is the collection id and the value is // the HNSW index for that collection. This restricts the cache to // contain atmost one index per collection. Ideally, we would like @@ -133,12 +136,8 @@ impl HnswIndexProvider { } } - pub async fn get( - &self, - index_id: &IndexUuid, - collection_id: &CollectionUuid, - ) -> Option { - match self.cache.get(collection_id).await.ok().flatten() { + pub async fn get(&self, index_id: &IndexUuid, cache_key: &CacheKey) -> Option { + match self.cache.get(cache_key).await.ok().flatten() { Some(index) => { let index_with_lock = index.inner.read(); if index_with_lock.id == *index_id { @@ -159,8 +158,9 @@ impl HnswIndexProvider { pub async fn fork( &self, source_id: &IndexUuid, - segment: &Segment, + cache_key: &CacheKey, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let new_id = IndexUuid(Uuid::new_v4()); let new_storage_path = self.temporary_storage_path.join(new_id.to_string()); @@ -183,22 +183,7 @@ impl HnswIndexProvider { } } - let index_config = IndexConfig::from_segment(segment, dimensionality); - - let index_config = match index_config { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderForkError::IndexConfigError(*e))); - } - }; - - let hnsw_config = HnswIndexConfig::from_segment(segment, &new_storage_path); - match hnsw_config { - Ok(hnsw_config) => hnsw_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderForkError::HnswConfigError(*e))); - } - }; + let index_config = IndexConfig::new(dimensionality, distance_function); let storage_path_str = match new_storage_path.to_str() { Some(storage_path_str) => storage_path_str, @@ -212,13 +197,13 @@ impl HnswIndexProvider { match HnswIndex::load(storage_path_str, &index_config, new_id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(&new_id, &segment.collection).await { + match self.get(&new_id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -303,8 +288,9 @@ impl HnswIndexProvider { pub async fn open( &self, id: &IndexUuid, - segment: &Segment, + cache_key: &CacheKey, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let index_storage_path = self.temporary_storage_path.join(id.to_string()); @@ -327,17 +313,7 @@ impl HnswIndexProvider { } // Thread safe. - let index_config = IndexConfig::from_segment(segment, dimensionality); - let index_config = match index_config { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderOpenError::IndexConfigError(*e))); - } - }; - - // Thread safe. - let _hnsw_config = HnswIndexConfig::from_segment(segment, &index_storage_path) - .map_err(|e| Box::new(HnswIndexProviderOpenError::HnswConfigError(*e)))?; + let index_config = IndexConfig::new(dimensionality, distance_function); let index_storage_path_str = match index_storage_path.to_str() { Some(index_storage_path_str) => index_storage_path_str, @@ -351,13 +327,13 @@ impl HnswIndexProvider { match HnswIndex::load(index_storage_path_str, &index_config, *id) { Ok(index) => { let _guard = self.write_mutex.lock().await; - match self.get(id, &segment.collection).await { + match self.get(id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -378,9 +354,12 @@ impl HnswIndexProvider { // A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id pub async fn create( &self, - // TODO: This should not take Segment. The index layer should not know about the segment concept - segment: &Segment, + cache_key: &CacheKey, + m: usize, + ef_construction: usize, + ef_search: usize, dimensionality: i32, + distance_function: DistanceFunction, ) -> Result> { let id = IndexUuid(Uuid::new_v4()); let index_storage_path = self.temporary_storage_path.join(id.to_string()); @@ -392,31 +371,28 @@ impl HnswIndexProvider { } } - let index_config = match IndexConfig::from_segment(segment, dimensionality) { - Ok(index_config) => index_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderCreateError::IndexConfigError(*e))); - } - }; + let index_config = IndexConfig::new(dimensionality, distance_function); + + let hnsw_config = + match HnswIndexConfig::new(m, ef_construction, ef_search, &index_storage_path) { + Ok(hnsw_config) => hnsw_config, + Err(e) => { + return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e))); + } + }; - let hnsw_config = match HnswIndexConfig::from_segment(segment, &index_storage_path) { - Ok(hnsw_config) => hnsw_config, - Err(e) => { - return Err(Box::new(HnswIndexProviderCreateError::HnswConfigError(*e))); - } - }; // HnswIndex init is not thread safe. We should not call it from multiple threads let index = HnswIndex::init(&index_config, Some(&hnsw_config), id) .map_err(|e| Box::new(HnswIndexProviderCreateError::IndexInitError(e)))?; let _guard = self.write_mutex.lock().await; - match self.get(&id, &segment.collection).await { + match self.get(&id, cache_key).await { Some(index) => Ok(index.clone()), None => { let index = HnswIndexRef { inner: Arc::new(RwLock::new(index)), }; - self.cache.insert(segment.collection, index.clone()).await; + self.cache.insert(*cache_key, index.clone()).await; Ok(index) } } @@ -455,8 +431,8 @@ impl HnswIndexProvider { } /// Purge entries from the cache by index ID and remove temporary files from disk. - pub async fn purge_by_id(&mut self, collection_uuids: &[CollectionUuid]) { - for collection_uuid in collection_uuids { + pub async fn purge_by_id(&mut self, cache_keys: &[CacheKey]) { + for collection_uuid in cache_keys { let Some(index_id) = self .cache .get(collection_uuid) @@ -619,11 +595,11 @@ pub enum HnswIndexProviderFileError { #[cfg(test)] mod tests { + use crate::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; + use super::*; use chroma_cache::new_non_persistent_cache_for_test; use chroma_storage::local::LocalStorage; - use chroma_types::{SegmentType, SegmentUuid}; - use std::collections::HashMap; #[tokio::test] async fn test_fork() { @@ -637,21 +613,30 @@ mod tests { let cache = new_non_persistent_cache_for_test(); let (_tx, rx) = tokio::sync::mpsc::unbounded_channel(); let provider = HnswIndexProvider::new(storage, hnsw_tmp_path, cache, rx); - let segment = Segment { - id: SegmentUuid::new(), - r#type: SegmentType::HnswDistributed, - scope: chroma_types::SegmentScope::VECTOR, - collection: CollectionUuid(Uuid::new_v4()), - metadata: None, - file_path: HashMap::new(), - }; + let collection_id = CollectionUuid(Uuid::new_v4()); let dimensionality = 128; - let created_index = provider.create(&segment, dimensionality).await.unwrap(); + let distance_function = DistanceFunction::Euclidean; + let created_index = provider + .create( + &collection_id, + DEFAULT_HNSW_M, + DEFAULT_HNSW_EF_CONSTRUCTION, + DEFAULT_HNSW_EF_SEARCH, + dimensionality, + distance_function.clone(), + ) + .await + .unwrap(); let created_index_id = created_index.inner.read().id; let forked_index = provider - .fork(&created_index_id, &segment, dimensionality) + .fork( + &created_index_id, + &collection_id, + dimensionality, + distance_function, + ) .await .unwrap(); let forked_index_id = forked_index.inner.read().id; diff --git a/rust/index/src/types.rs b/rust/index/src/types.rs index ae27f698b0e..f34db6d83ac 100644 --- a/rust/index/src/types.rs +++ b/rust/index/src/types.rs @@ -1,6 +1,5 @@ use chroma_distance::{DistanceFunction, DistanceFunctionError}; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{MetadataValue, Segment}; use thiserror::Error; use uuid::Uuid; @@ -25,25 +24,10 @@ impl ChromaError for IndexConfigFromSegmentError { } impl IndexConfig { - pub fn from_segment( - segment: &Segment, - dimensionality: i32, - ) -> Result> { - let space = match segment.metadata { - Some(ref metadata) => match metadata.get("hnsw:space") { - Some(MetadataValue::Str(space)) => space, - _ => "l2", - }, - None => "l2", - }; - match DistanceFunction::try_from(space) { - Ok(distance_function) => Ok(IndexConfig { - dimensionality, - distance_function, - }), - Err(e) => Err(Box::new( - IndexConfigFromSegmentError::InvalidDistanceFunction(e), - )), + pub fn new(dimensionality: i32, distance_function: DistanceFunction) -> Self { + IndexConfig { + dimensionality, + distance_function, } } } diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index dc5293d22bf..e1dc9ed8b53 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -278,6 +278,23 @@ Metadata pub type Metadata = HashMap; pub type DeletedMetadata = HashSet; +pub fn get_metadata_value_as<'a, T>( + metadata: &'a Metadata, + key: &str, +) -> Result> +where + T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>, +{ + let res = match metadata.get(key) { + Some(value) => T::try_from(value), + None => return Err(Box::new(MetadataValueConversionError::InvalidValue)), + }; + match res { + Ok(value) => Ok(value), + Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)), + } +} + impl TryFrom for Metadata { type Error = MetadataValueConversionError; diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 7abfcd3e699..fbc19072f8c 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -26,7 +26,8 @@ use crate::execution::operators::record_segment_prefetch::{ }; use crate::log::log::PullLogsError; use crate::segment::distributed_hnsw_segment::{ - DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader, + distance_function_from_segment, DistributedHNSWSegmentFromSegmentError, + DistributedHNSWSegmentReader, }; use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}; use crate::system::{ComponentContext, ComponentHandle, System}; @@ -620,22 +621,22 @@ impl Component for HnswQueryOrchestrator { } }; - match IndexConfig::from_segment(&hnsw_segment, collection.dimension.unwrap()) { - Ok(index_config) => { - self.index_config = Some(index_config); - - // Normalize the query vectors if we are using the cosine similarity - if self.index_config.as_ref().unwrap().distance_function == DistanceFunction::Cosine - { - for query_vector in self.query_vectors.iter_mut() { - *query_vector = normalize(query_vector); - } - } - } + let distance_function = match distance_function_from_segment(&hnsw_segment) { + Ok(distance_function) => distance_function, Err(e) => { terminate_with_error(self.result_channel.take(), e, ctx); return; } + }; + self.index_config = Some(IndexConfig::new( + collection.dimension.unwrap(), + distance_function, + )); + // Normalize the query vectors if we are using the cosine similarity + if self.index_config.as_ref().unwrap().distance_function == DistanceFunction::Cosine { + for query_vector in self.query_vectors.iter_mut() { + *query_vector = normalize(query_vector); + } } self.record_segment = Some(record_segment); diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 1f2e48bb071..bff4a208b86 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -1,16 +1,16 @@ use super::record_segment::ApplyMaterializedLogError; use super::{SegmentFlusher, SegmentWriter}; use async_trait::async_trait; +use chroma_distance::{DistanceFunction, DistanceFunctionError}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::{ HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexProviderOpenError, HnswIndexRef, }; -use chroma_index::{ - HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig, IndexConfigFromSegmentError, - IndexUuid, -}; -use chroma_types::{MaterializedLogOperation, Segment, SegmentUuid}; +use chroma_index::{Index, IndexUuid}; +use chroma_index::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; +use chroma_types::SegmentUuid; +use chroma_types::{get_metadata_value_as, MaterializedLogOperation, MetadataValue, Segment}; use std::collections::HashMap; use std::fmt::Debug; use thiserror::Error; @@ -18,6 +18,12 @@ use uuid::Uuid; const HNSW_INDEX: &str = "hnsw_index"; +pub struct HnswIndexParamsFromSegment { + pub m: usize, + pub ef_construction: usize, + pub ef_search: usize, +} + #[derive(Clone)] pub(crate) struct DistributedHNSWSegmentWriter { index: HnswIndexRef, @@ -41,16 +47,14 @@ pub enum DistributedHNSWSegmentFromSegmentError { InvalidUUID, #[error("HNSW segment uninitialized")] Uninitialized, - #[error("Index configuration error")] - IndexConfigError(#[from] IndexConfigFromSegmentError), - #[error("HNSW index configuration error")] - HnswIndexConfigError(#[from] HnswIndexFromSegmentError), #[error("HNSW index provider open error")] HnswIndexProviderOpenError(#[from] HnswIndexProviderOpenError), #[error("HNSW index provider fork error")] HnswIndexProviderForkError(#[from] HnswIndexProviderForkError), #[error("HNSW index provider create error")] HnswIndexProviderCreateError(#[from] HnswIndexProviderCreateError), + #[error("Error extracting distance function")] + DistanceFunctionError(#[from] DistanceFunctionError), } impl ChromaError for DistributedHNSWSegmentFromSegmentError { @@ -59,15 +63,64 @@ impl ChromaError for DistributedHNSWSegmentFromSegmentError { DistributedHNSWSegmentFromSegmentError::NoHnswFileFound => ErrorCodes::NotFound, DistributedHNSWSegmentFromSegmentError::InvalidUUID => ErrorCodes::InvalidArgument, DistributedHNSWSegmentFromSegmentError::Uninitialized => ErrorCodes::InvalidArgument, - DistributedHNSWSegmentFromSegmentError::IndexConfigError(e) => e.code(), - DistributedHNSWSegmentFromSegmentError::HnswIndexConfigError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderOpenError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderForkError(e) => e.code(), DistributedHNSWSegmentFromSegmentError::HnswIndexProviderCreateError(e) => e.code(), + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(e) => e.code(), } } } +fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment { + let metadata = match &segment.metadata { + Some(metadata) => metadata, + None => { + return HnswIndexParamsFromSegment { + m: DEFAULT_HNSW_M, + ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, + ef_search: DEFAULT_HNSW_EF_SEARCH, + }; + } + }; + + let m = match get_metadata_value_as::(metadata, "hnsw:M") { + Ok(m) => m as usize, + Err(_) => DEFAULT_HNSW_M, + }; + let ef_construction = match get_metadata_value_as::(metadata, "hnsw:construction_ef") { + Ok(ef_construction) => ef_construction as usize, + Err(_) => DEFAULT_HNSW_EF_CONSTRUCTION, + }; + let ef_search = match get_metadata_value_as::(metadata, "hnsw:search_ef") { + Ok(ef_search) => ef_search as usize, + Err(_) => DEFAULT_HNSW_EF_SEARCH, + }; + + HnswIndexParamsFromSegment { + m, + ef_construction, + ef_search, + } +} + +pub fn distance_function_from_segment( + segment: &Segment, +) -> Result> { + let space = match segment.metadata { + Some(ref metadata) => match metadata.get("hnsw:space") { + Some(MetadataValue::Str(space)) => space, + _ => "l2", + }, + None => "l2", + }; + match DistanceFunction::try_from(space) { + Ok(distance_function) => Ok(distance_function), + Err(e) => Err(Box::new( + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(e), + )), + } +} + impl DistributedHNSWSegmentWriter { pub(crate) fn new( index: HnswIndexRef, @@ -87,25 +140,6 @@ impl DistributedHNSWSegmentWriter { hnsw_index_provider: HnswIndexProvider, ) -> Result, Box> { - let _index_config = match IndexConfig::from_segment(segment, dimensionality as i32) { - Ok(ic) => ic, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::IndexConfigError(*e), - )); - } - }; - let persist_path = &hnsw_index_provider.temporary_storage_path; - - let _hnsw_config = match HnswIndexConfig::from_segment(segment, persist_path) { - Ok(hc) => hc, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::HnswIndexConfigError(*e), - )); - } - }; - // TODO: this is hacky, we use the presence of files to determine if we need to load or create the index // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this @@ -139,8 +173,20 @@ impl DistributedHNSWSegmentWriter { }; let index_uuid = IndexUuid(index_uuid); + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; + let index = match hnsw_index_provider - .fork(&index_uuid, segment, dimensionality as i32) + .fork( + &index_uuid, + &segment.collection, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -157,8 +203,23 @@ impl DistributedHNSWSegmentWriter { segment.id, ))) } else { + let hnsw_params = hnsw_params_from_segment(segment); + + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; let index = match hnsw_index_provider - .create(segment, dimensionality as i32) + .create( + &segment.collection, + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -279,17 +340,6 @@ impl DistributedHNSWSegmentReader { hnsw_index_provider: HnswIndexProvider, ) -> Result, Box> { - let index_config = IndexConfig::from_segment(segment, dimensionality as i32); - let _index_config = match index_config { - Ok(ic) => ic, - Err(e) => { - return Err(Box::new( - DistributedHNSWSegmentFromSegmentError::IndexConfigError(*e), - )); - } - }; - let _persist_path = &hnsw_index_provider.temporary_storage_path; - // TODO: this is hacky, we use the presence of files to determine if we need to load or create the index // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this @@ -330,8 +380,19 @@ impl DistributedHNSWSegmentReader { { Some(index) => index, None => { + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(e); + } + }; match hnsw_index_provider - .open(&index_uuid, segment, dimensionality as i32) + .open( + &index_uuid, + &segment.collection, + dimensionality as i32, + distance_function, + ) .await { Ok(index) => index, @@ -365,3 +426,77 @@ impl DistributedHNSWSegmentReader { index.query(vector, k, allowed_ids, disallowd_ids) } } + +#[cfg(test)] +pub mod test { + use std::collections::HashMap; + + use chroma_index::{ + HnswIndexConfig, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M, + DEFAULT_MAX_ELEMENTS, + }; + use chroma_types::{CollectionUuid, MetadataValue, Segment, SegmentUuid}; + use tempfile::tempdir; + use uuid::Uuid; + + use crate::segment::distributed_hnsw_segment::hnsw_params_from_segment; + + #[test] + fn parameter_defaults() { + let persist_path = tempdir().unwrap().path().to_owned(); + + let segment = Segment { + id: SegmentUuid(Uuid::new_v4()), + r#type: chroma_types::SegmentType::HnswDistributed, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(HashMap::new()), + collection: CollectionUuid(Uuid::new_v4()), + file_path: HashMap::new(), + }; + + let hnsw_params = hnsw_params_from_segment(&segment); + let config = HnswIndexConfig::new( + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, + &persist_path, + ) + .expect("Error creating hnsw index config"); + + assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); + assert_eq!(config.m, DEFAULT_HNSW_M); + assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); + assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); + assert_eq!(config.random_seed, 0); + assert_eq!(config.persist_path, persist_path.to_str().unwrap()); + + // Try partial metadata + let mut metadata = HashMap::new(); + metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10_i64)); + + let segment = Segment { + id: SegmentUuid(Uuid::new_v4()), + r#type: chroma_types::SegmentType::HnswDistributed, + scope: chroma_types::SegmentScope::VECTOR, + metadata: Some(metadata), + collection: CollectionUuid(Uuid::new_v4()), + file_path: HashMap::new(), + }; + + let hnsw_params = hnsw_params_from_segment(&segment); + let config = HnswIndexConfig::new( + hnsw_params.m, + hnsw_params.ef_construction, + hnsw_params.ef_search, + &persist_path, + ) + .expect("Error creating hnsw index config"); + + assert_eq!(config.max_elements, DEFAULT_MAX_ELEMENTS); + assert_eq!(config.m, 10); + assert_eq!(config.ef_construction, DEFAULT_HNSW_EF_CONSTRUCTION); + assert_eq!(config.ef_search, DEFAULT_HNSW_EF_SEARCH); + assert_eq!(config.random_seed, 0); + assert_eq!(config.persist_path, persist_path.to_str().unwrap()); + } +}