diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index c8ab2f76d3f..ed5a529e274 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -8,7 +8,7 @@ use thiserror::Error; use uuid::Uuid; use crate::{ - hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef}, + hnsw_provider::{HnswIndexProvider, HnswIndexRef}, IndexUuid, }; @@ -81,14 +81,16 @@ impl SpannIndexWriter { collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, - hnsw_params: HnswIndexParams, + m: usize, + ef_construction: usize, + ef_search: usize, ) -> Result { - let persist_path = &hnsw_provider.temporary_storage_path; match hnsw_provider .create( collection_id, - hnsw_params, - persist_path, + m, + ef_construction, + ef_search, dimensionality as i32, distance_function, ) @@ -155,7 +157,9 @@ impl SpannIndexWriter { hnsw_id: Option<&IndexUuid>, versions_map_id: Option<&Uuid>, posting_list_id: Option<&Uuid>, - hnsw_params: Option, + m: Option, + ef_construction: Option, + ef_search: Option, collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, @@ -179,7 +183,9 @@ impl SpannIndexWriter { collection_id, distance_function, dimensionality, - hnsw_params.unwrap(), // Safe since caller should always provide this. + m.unwrap(), // Safe since caller should always provide this. + ef_construction.unwrap(), // Safe since caller should always provide this. + ef_search.unwrap(), // Safe since caller should always provide this. ) .await? } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 54c42385de5..f15bee2953e 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -69,7 +69,7 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::DistanceFunctionNotFound); } }; - let (hnsw_id, hnsw_params) = match segment.file_path.get(HNSW_PATH) { + let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { let index_uuid = match Uuid::parse_str(index_id) { @@ -78,16 +78,19 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; + let hnsw_params = hnsw_params_from_segment(segment); ( Some(IndexUuid(index_uuid)), - Some(hnsw_params_from_segment(segment)), + Some(hnsw_params.m), + Some(hnsw_params.ef_construction), + Some(hnsw_params.ef_search), ) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - None => (None, None), + None => (None, None, None, None), }; let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) { Some(version_map_path) => match version_map_path.first() { @@ -129,7 +132,9 @@ impl SpannSegmentWriter { hnsw_id.as_ref(), versions_map_id.as_ref(), posting_list_id.as_ref(), - hnsw_params, + m, + ef_construction, + ef_search, &segment.collection, distance_function, dimensionality, diff --git a/rust/worker/src/segment/utils.rs b/rust/worker/src/segment/utils.rs index 08b76b17ce6..b5305eecdfb 100644 --- a/rust/worker/src/segment/utils.rs +++ b/rust/worker/src/segment/utils.rs @@ -1,19 +1,18 @@ use chroma_distance::{DistanceFunction, DistanceFunctionError}; -use chroma_index::{ - hnsw_provider::HnswIndexParams, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, - DEFAULT_HNSW_M, -}; +use chroma_index::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; use chroma_types::{get_metadata_value_as, MetadataValue, Segment}; -pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { +use super::distributed_hnsw_segment::HnswIndexParamsFromSegment; + +pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment { let metadata = match &segment.metadata { Some(metadata) => metadata, None => { - return ( - DEFAULT_HNSW_M, - DEFAULT_HNSW_EF_CONSTRUCTION, - DEFAULT_HNSW_EF_SEARCH, - ); + return HnswIndexParamsFromSegment { + m: DEFAULT_HNSW_M, + ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, + ef_search: DEFAULT_HNSW_EF_SEARCH, + }; } }; @@ -30,7 +29,11 @@ pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { Err(_) => DEFAULT_HNSW_EF_SEARCH, }; - (m, ef_construction, ef_search) + HnswIndexParamsFromSegment { + m, + ef_construction, + ef_search, + } } pub(crate) fn distance_function_from_segment(