From f3a6e09032b8027b1cd2b9e620a5dfb0ad2c04af Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Tue, 5 Nov 2024 11:08:38 -0800 Subject: [PATCH 01/10] Spann_segment_impl --- rust/blockstore/src/arrow/block/types.rs | 20 +++ rust/blockstore/src/arrow/blockfile.rs | 19 +++ rust/blockstore/src/arrow/sparse_index.rs | 4 + rust/blockstore/src/types/reader.rs | 7 ++ rust/index/src/hnsw.rs | 10 +- rust/index/src/hnsw_provider.rs | 17 +-- rust/index/src/lib.rs | 2 + rust/index/src/spann.rs | 1 + rust/index/src/spann/types.rs | 99 +++++++++++++++ rust/index/src/types.rs | 14 --- .../src/execution/orchestration/hnsw.rs | 4 +- .../src/segment/distributed_hnsw_segment.rs | 116 ++++++----------- rust/worker/src/segment/mod.rs | 2 + rust/worker/src/segment/spann_segment.rs | 118 ++++++++++++++++++ rust/worker/src/segment/utils.rs | 50 ++++++++ 15 files changed, 367 insertions(+), 116 deletions(-) create mode 100644 rust/index/src/spann.rs create mode 100644 rust/index/src/spann/types.rs create mode 100644 rust/worker/src/segment/spann_segment.rs create mode 100644 rust/worker/src/segment/utils.rs diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 6764591c86f..34a2f1ec453 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -346,6 +346,26 @@ impl Block { } } + pub fn get_all_data<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( + &'me self, + ) -> Vec<(&'me str, K, V)> { + let prefix_arr = self + .data + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let mut result = Vec::new(); + for i in 0..self.data.num_rows() { + result.push(( + prefix_arr.value(i), + K::get(self.data.column(1), i), + V::get(self.data.column(2), i), + )); + } + result + } + /// Get all the values for a given prefix & key range in the block /// ### Panics /// - If the underlying data types are not the same as the types specified in the function signature diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 44e2db95587..2e7c7b4900f 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -682,6 +682,25 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me true } + + pub async fn get_all_data(&'me self) -> Vec<(&'me str, K, V)> { + let block_ids = self.root.sparse_index.get_all_block_ids(); + let mut result = vec![]; + for block_id in block_ids { + let block = match self.get_block(block_id).await { + Ok(Some(block)) => block, + Ok(None) => { + continue; + } + Err(_) => { + continue; + } + }; + + result.extend(block.get_all_data()); + } + return result; + } } #[cfg(test)] diff --git a/rust/blockstore/src/arrow/sparse_index.rs b/rust/blockstore/src/arrow/sparse_index.rs index 3d4124ab447..3a7cac2c9ed 100644 --- a/rust/blockstore/src/arrow/sparse_index.rs +++ b/rust/blockstore/src/arrow/sparse_index.rs @@ -321,6 +321,10 @@ impl SparseIndexReader { get_target_block(search_key, forward).id } + pub(super) fn get_all_block_ids(&self) -> Vec { + self.data.forward.values().map(|v| v.id).collect() + } + /// Get all the block ids that contain keys in the given input search keys pub(super) fn get_all_target_block_ids(&self, mut search_keys: Vec) -> Vec { // Sort so that we can search in one iteration. diff --git a/rust/blockstore/src/types/reader.rs b/rust/blockstore/src/types/reader.rs index 43311d979d4..926bcad4b9d 100644 --- a/rust/blockstore/src/types/reader.rs +++ b/rust/blockstore/src/types/reader.rs @@ -131,4 +131,11 @@ impl< } } } + + pub async fn get_all_data(&'referred_data self) -> Vec<(&'referred_data str, K, V)> { + match self { + BlockfileReader::MemoryBlockfileReader(reader) => todo!(), + BlockfileReader::ArrowBlockfileReader(reader) => reader.get_all_data().await, + } + } } diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index e97b04a28fe..f95c19732ae 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -37,14 +37,12 @@ pub struct HnswIndexConfig { } #[derive(Error, Debug)] -pub enum HnswIndexFromSegmentError { +pub enum HnswIndexConfigError { #[error("Missing config `{0}`")] MissingConfig(String), - #[error("Invalid metadata value")] - MetadataValueError(#[from] MetadataValueConversionError), } -impl ChromaError for HnswIndexFromSegmentError { +impl ChromaError for HnswIndexConfigError { fn code(&self) -> ErrorCodes { ErrorCodes::InvalidArgument } @@ -56,11 +54,11 @@ impl HnswIndexConfig { ef_construction: usize, ef_search: usize, persist_path: &Path, - ) -> Result> { + ) -> Result> { let persist_path = match persist_path.to_str() { Some(persist_path) => persist_path, None => { - return Err(Box::new(HnswIndexFromSegmentError::MissingConfig( + return Err(Box::new(HnswIndexConfigError::MissingConfig( "persist_path".to_string(), ))) } diff --git a/rust/index/src/hnsw_provider.rs b/rust/index/src/hnsw_provider.rs index 9d40aa5da52..6fae49e2d29 100644 --- a/rust/index/src/hnsw_provider.rs +++ b/rust/index/src/hnsw_provider.rs @@ -547,12 +547,8 @@ impl HnswIndexProvider { #[derive(Error, Debug)] pub enum HnswIndexProviderOpenError { - #[error("Index configuration error")] - IndexConfigError(#[from] IndexConfigFromSegmentError), #[error("Hnsw index file error")] FileError(#[from] HnswIndexProviderFileError), - #[error("Hnsw config error")] - HnswConfigError(#[from] HnswIndexFromSegmentError), #[error("Index load error")] IndexLoadError(#[from] Box), #[error("Path: {0} could not be converted to string")] @@ -562,9 +558,7 @@ pub enum HnswIndexProviderOpenError { impl ChromaError for HnswIndexProviderOpenError { fn code(&self) -> ErrorCodes { match self { - HnswIndexProviderOpenError::IndexConfigError(e) => e.code(), HnswIndexProviderOpenError::FileError(_) => ErrorCodes::Internal, - HnswIndexProviderOpenError::HnswConfigError(e) => e.code(), HnswIndexProviderOpenError::IndexLoadError(e) => e.code(), HnswIndexProviderOpenError::PathToStringError(_) => ErrorCodes::InvalidArgument, } @@ -573,12 +567,8 @@ impl ChromaError for HnswIndexProviderOpenError { #[derive(Error, Debug)] pub enum HnswIndexProviderForkError { - #[error("Index configuration error")] - IndexConfigError(#[from] IndexConfigFromSegmentError), #[error("Hnsw index file error")] FileError(#[from] HnswIndexProviderFileError), - #[error("Hnsw config error")] - HnswConfigError(#[from] HnswIndexFromSegmentError), #[error("Index load error")] IndexLoadError(#[from] Box), #[error("Path: {0} could not be converted to string")] @@ -588,9 +578,7 @@ pub enum HnswIndexProviderForkError { impl ChromaError for HnswIndexProviderForkError { fn code(&self) -> ErrorCodes { match self { - HnswIndexProviderForkError::IndexConfigError(e) => e.code(), HnswIndexProviderForkError::FileError(_) => ErrorCodes::Internal, - HnswIndexProviderForkError::HnswConfigError(e) => e.code(), HnswIndexProviderForkError::IndexLoadError(e) => e.code(), HnswIndexProviderForkError::PathToStringError(_) => ErrorCodes::InvalidArgument, } @@ -599,12 +587,10 @@ impl ChromaError for HnswIndexProviderForkError { #[derive(Error, Debug)] pub enum HnswIndexProviderCreateError { - #[error("Index configuration error")] - IndexConfigError(#[from] IndexConfigFromSegmentError), #[error("Hnsw index file error")] FileError(#[from] HnswIndexProviderFileError), #[error("Hnsw config error")] - HnswConfigError(#[from] HnswIndexFromSegmentError), + HnswConfigError(#[from] HnswIndexConfigError), #[error("Index init error")] IndexInitError(#[from] Box), } @@ -612,7 +598,6 @@ pub enum HnswIndexProviderCreateError { impl ChromaError for HnswIndexProviderCreateError { fn code(&self) -> ErrorCodes { match self { - HnswIndexProviderCreateError::IndexConfigError(e) => e.code(), HnswIndexProviderCreateError::FileError(_) => ErrorCodes::Internal, HnswIndexProviderCreateError::HnswConfigError(e) => e.code(), HnswIndexProviderCreateError::IndexInitError(e) => e.code(), diff --git a/rust/index/src/lib.rs b/rust/index/src/lib.rs index 43ae6ba7288..ac9d9c585cd 100644 --- a/rust/index/src/lib.rs +++ b/rust/index/src/lib.rs @@ -3,6 +3,7 @@ pub mod fulltext; mod hnsw; pub mod hnsw_provider; pub mod metadata; +pub mod spann; mod types; pub mod utils; @@ -12,6 +13,7 @@ use chroma_cache::new_non_persistent_cache_for_test; use chroma_storage::test_storage; pub use hnsw::*; use hnsw_provider::HnswIndexProvider; +pub use spann::*; use tempfile::tempdir; pub use types::*; diff --git a/rust/index/src/spann.rs b/rust/index/src/spann.rs new file mode 100644 index 00000000000..dd198c6d01e --- /dev/null +++ b/rust/index/src/spann.rs @@ -0,0 +1 @@ +pub mod types; \ No newline at end of file diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs new file mode 100644 index 00000000000..dac81a7863b --- /dev/null +++ b/rust/index/src/spann/types.rs @@ -0,0 +1,99 @@ +use std::collections::HashMap; + +use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter}; +use chroma_distance::DistanceFunction; +use chroma_error::{ChromaError, ErrorCodes}; +use thiserror::Error; +use uuid::Uuid; + +use crate::hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef}; + +// TODO(Sanket): Add locking structures as necessary. +pub struct SpannIndexWriter { + // HNSW index and its provider for centroid search. + hnsw_index: HnswIndexRef, + hnsw_provider: HnswIndexProvider, + // Posting list of the centroids. + // The blockfile also contains next id for the head. + posting_list_writer: BlockfileWriter, + // Version number of each point. + versions_map: HashMap, +} + +#[derive(Error, Debug)] +pub enum SpannIndexWriterConstructionError { + #[error("HNSW index construction error")] + HnswIndexConstructionError, + #[error("Blockfile reader construction error")] + BlockfileReaderConstructionError, +} + +impl ChromaError for SpannIndexWriterConstructionError { + fn code(&self) -> ErrorCodes { + match self { + Self::HnswIndexConstructionError => ErrorCodes::Internal, + Self::BlockfileReaderConstructionError => ErrorCodes::Internal, + } + } +} + +impl SpannIndexWriter { + pub async fn hnsw_index_from_id( + hnsw_provider: &HnswIndexProvider, + id: &Uuid, + collection_id: &Uuid, + distance_function: DistanceFunction, + dimensionality: usize, + ) -> Result { + match hnsw_provider + .fork(id, collection_id, dimensionality as i32, distance_function) + .await + { + Ok(index) => Ok(index), + Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + } + } + + pub async fn create_hnsw_index( + hnsw_provider: &HnswIndexProvider, + collection_id: &Uuid, + distance_function: DistanceFunction, + dimensionality: usize, + hnsw_params: HnswIndexParams, + ) -> Result { + let persist_path = &hnsw_provider.temporary_storage_path; + match hnsw_provider + .create( + collection_id, + hnsw_params, + persist_path, + dimensionality as i32, + distance_function, + ) + .await + { + Ok(index) => Ok(index), + Err(_) => Err(SpannIndexWriterConstructionError::HnswIndexConstructionError), + } + } + + pub async fn load_versions_map( + blockfile_id: &Uuid, + blockfile_provider: &BlockfileProvider, + ) -> Result, SpannIndexWriterConstructionError> { + // Create a reader for the blockfile. Load all the data into the versions map. + let mut versions_map = HashMap::new(); + let reader = match blockfile_provider.open::(blockfile_id).await { + Ok(reader) => reader, + Err(_) => { + return Err(SpannIndexWriterConstructionError::BlockfileReaderConstructionError) + } + }; + // Load data using the reader. + let versions_data = reader.get_all_data().await; + versions_data.iter().for_each(|(_, key, value)| { + versions_map.insert(*key, *value); + }); + Ok(versions_map) + } +} diff --git a/rust/index/src/types.rs b/rust/index/src/types.rs index f34db6d83ac..e906ef84376 100644 --- a/rust/index/src/types.rs +++ b/rust/index/src/types.rs @@ -9,20 +9,6 @@ pub struct IndexConfig { pub distance_function: DistanceFunction, } -#[derive(Error, Debug)] -pub enum IndexConfigFromSegmentError { - #[error("Invalid distance function")] - InvalidDistanceFunction(#[from] DistanceFunctionError), -} - -impl ChromaError for IndexConfigFromSegmentError { - fn code(&self) -> ErrorCodes { - match self { - IndexConfigFromSegmentError::InvalidDistanceFunction(_) => ErrorCodes::InvalidArgument, - } - } -} - impl IndexConfig { pub fn new(dimensionality: i32, distance_function: DistanceFunction) -> Self { IndexConfig { diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index fbc19072f8c..7cc7c955815 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -26,9 +26,9 @@ use crate::execution::operators::record_segment_prefetch::{ }; use crate::log::log::PullLogsError; use crate::segment::distributed_hnsw_segment::{ - distance_function_from_segment, DistributedHNSWSegmentFromSegmentError, - DistributedHNSWSegmentReader, + DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader, }; +use crate::segment::utils::distance_function_from_segment; use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}; use crate::system::{ComponentContext, ComponentHandle, System}; use crate::{ diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index bff4a208b86..936e87d7cb2 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -1,4 +1,7 @@ +use crate::segment::utils::distance_function_from_segment; + use super::record_segment::ApplyMaterializedLogError; +use super::utils::hnsw_params_from_segment; use super::{SegmentFlusher, SegmentWriter}; use async_trait::async_trait; use chroma_distance::{DistanceFunction, DistanceFunctionError}; @@ -71,56 +74,6 @@ impl ChromaError for DistributedHNSWSegmentFromSegmentError { } } -fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment { - let metadata = match &segment.metadata { - Some(metadata) => metadata, - None => { - return HnswIndexParamsFromSegment { - m: DEFAULT_HNSW_M, - ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, - ef_search: DEFAULT_HNSW_EF_SEARCH, - }; - } - }; - - let m = match get_metadata_value_as::(metadata, "hnsw:M") { - Ok(m) => m as usize, - Err(_) => DEFAULT_HNSW_M, - }; - let ef_construction = match get_metadata_value_as::(metadata, "hnsw:construction_ef") { - Ok(ef_construction) => ef_construction as usize, - Err(_) => DEFAULT_HNSW_EF_CONSTRUCTION, - }; - let ef_search = match get_metadata_value_as::(metadata, "hnsw:search_ef") { - Ok(ef_search) => ef_search as usize, - Err(_) => DEFAULT_HNSW_EF_SEARCH, - }; - - HnswIndexParamsFromSegment { - m, - ef_construction, - ef_search, - } -} - -pub fn distance_function_from_segment( - segment: &Segment, -) -> Result> { - let space = match segment.metadata { - Some(ref metadata) => match metadata.get("hnsw:space") { - Some(MetadataValue::Str(space)) => space, - _ => "l2", - }, - None => "l2", - }; - match DistanceFunction::try_from(space) { - Ok(distance_function) => Ok(distance_function), - Err(e) => Err(Box::new( - DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(e), - )), - } -} - impl DistributedHNSWSegmentWriter { pub(crate) fn new( index: HnswIndexRef, @@ -176,7 +129,9 @@ impl DistributedHNSWSegmentWriter { let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, Err(e) => { - return Err(e); + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(*e), + )); } }; @@ -208,7 +163,9 @@ impl DistributedHNSWSegmentWriter { let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, Err(e) => { - return Err(e); + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(*e), + )); } }; let index = match hnsw_index_provider @@ -373,37 +330,40 @@ impl DistributedHNSWSegmentReader { }; let index_uuid = IndexUuid(index_uuid); - let index = - match hnsw_index_provider - .get(&index_uuid, &segment.collection) - .await - { - Some(index) => index, - None => { - let distance_function = match distance_function_from_segment(segment) { - Ok(distance_function) => distance_function, - Err(e) => { - return Err(e); - } - }; - match hnsw_index_provider - .open( - &index_uuid, - &segment.collection, - dimensionality as i32, - distance_function, - ) - .await - { - Ok(index) => index, - Err(e) => return Err(Box::new( + let index = match hnsw_index_provider + .get(&index_uuid, &segment.collection) + .await + { + Some(index) => index, + None => { + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::DistanceFunctionError(*e), + )); + } + }; + match hnsw_index_provider + .open( + &index_uuid, + &segment.collection, + dimensionality as i32, + distance_function, + ) + .await + { + Ok(index) => index, + Err(e) => { + return Err(Box::new( DistributedHNSWSegmentFromSegmentError::HnswIndexProviderOpenError( *e, ), - )), + )) } } - }; + } + }; Ok(Box::new(DistributedHNSWSegmentReader::new( index, segment.id, diff --git a/rust/worker/src/segment/mod.rs b/rust/worker/src/segment/mod.rs index 2837ac11126..14920256194 100644 --- a/rust/worker/src/segment/mod.rs +++ b/rust/worker/src/segment/mod.rs @@ -1,10 +1,12 @@ pub(crate) mod config; pub(crate) mod distributed_hnsw_segment; pub mod test; +pub(crate) mod utils; pub(crate) use types::*; // Required for benchmark pub mod metadata_segment; pub mod record_segment; +pub mod spann_segment; pub mod types; diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs new file mode 100644 index 00000000000..8fe18fb247b --- /dev/null +++ b/rust/worker/src/segment/spann_segment.rs @@ -0,0 +1,118 @@ +use chroma_blockstore::provider::BlockfileProvider; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; +use chroma_types::{Segment, SegmentType}; +use thiserror::Error; +use uuid::Uuid; + +use super::utils::{distance_function_from_segment, hnsw_params_from_segment}; + +pub(crate) struct SpannSegmentWriter { + index: SpannIndexWriter, + id: Uuid, +} + +#[derive(Error, Debug)] +pub enum SpannSegmentWriterError { + #[error("Invalid argument")] + InvalidArgument, + #[error("Distance function not found")] + DistanceFunctionNotFound, + #[error("Hnsw index id parsing error")] + IndexIdParsingError, + #[error("HNSW index construction error")] + HnswIndexConstructionError, +} + +impl ChromaError for SpannSegmentWriterError { + fn code(&self) -> ErrorCodes { + match self { + Self::InvalidArgument => ErrorCodes::InvalidArgument, + Self::IndexIdParsingError => ErrorCodes::Internal, + Self::HnswIndexConstructionError => ErrorCodes::Internal, + Self::DistanceFunctionNotFound => ErrorCodes::Internal, + } + } +} + +impl SpannSegmentWriter { + pub async fn from_segment( + segment: &Segment, + blockfile_provider: &BlockfileProvider, + hnsw_provider: HnswIndexProvider, + dimensionality: usize, + ) -> Result { + // TODO(Sanket): Introduce another segment type and propagate here. + if segment.r#type != SegmentType::HnswDistributed { + return Err(SpannSegmentWriterError::InvalidArgument); + } + match segment.file_path.get("hnsw_path") { + Some(hnsw_path) => match hnsw_path.first() { + Some(index_id) => { + let index_uuid = match Uuid::parse_str(index_id) { + Ok(uuid) => uuid, + Err(_) => { + return Err(SpannSegmentWriterError::IndexIdParsingError); + } + }; + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(SpannSegmentWriterError::DistanceFunctionNotFound); + } + }; + let hnsw_index = match SpannIndexWriter::hnsw_index_from_id( + &hnsw_provider, + &index_uuid, + &segment.collection, + distance_function, + dimensionality, + ) + .await + { + Ok(index) => index, + Err(_) => { + return Err(SpannSegmentWriterError::HnswIndexConstructionError); + } + }; + + // TODO(Sanket): Remove this. + return Err(SpannSegmentWriterError::InvalidArgument); + } + // TODO: Create index in this case also. + None => { + return Err(SpannSegmentWriterError::InvalidArgument); + } + }, + // TODO(Sanket): Create index in this case. + None => { + let hnsw_params = hnsw_params_from_segment(segment); + + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(SpannSegmentWriterError::DistanceFunctionNotFound); + } + }; + + let hnsw_index = match SpannIndexWriter::create_hnsw_index( + &hnsw_provider, + &segment.collection, + distance_function, + dimensionality, + hnsw_params, + ) + .await + { + Ok(index) => index, + Err(_) => { + return Err(SpannSegmentWriterError::HnswIndexConstructionError); + } + }; + + // First time creation of the segment. + return Err(SpannSegmentWriterError::InvalidArgument); + } + } + } +} diff --git a/rust/worker/src/segment/utils.rs b/rust/worker/src/segment/utils.rs new file mode 100644 index 00000000000..08b76b17ce6 --- /dev/null +++ b/rust/worker/src/segment/utils.rs @@ -0,0 +1,50 @@ +use chroma_distance::{DistanceFunction, DistanceFunctionError}; +use chroma_index::{ + hnsw_provider::HnswIndexParams, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, + DEFAULT_HNSW_M, +}; +use chroma_types::{get_metadata_value_as, MetadataValue, Segment}; + +pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { + let metadata = match &segment.metadata { + Some(metadata) => metadata, + None => { + return ( + DEFAULT_HNSW_M, + DEFAULT_HNSW_EF_CONSTRUCTION, + DEFAULT_HNSW_EF_SEARCH, + ); + } + }; + + let m = match get_metadata_value_as::(metadata, "hnsw:M") { + Ok(m) => m as usize, + Err(_) => DEFAULT_HNSW_M, + }; + let ef_construction = match get_metadata_value_as::(metadata, "hnsw:construction_ef") { + Ok(ef_construction) => ef_construction as usize, + Err(_) => DEFAULT_HNSW_EF_CONSTRUCTION, + }; + let ef_search = match get_metadata_value_as::(metadata, "hnsw:search_ef") { + Ok(ef_search) => ef_search as usize, + Err(_) => DEFAULT_HNSW_EF_SEARCH, + }; + + (m, ef_construction, ef_search) +} + +pub(crate) fn distance_function_from_segment( + segment: &Segment, +) -> Result> { + let space = match segment.metadata { + Some(ref metadata) => match metadata.get("hnsw:space") { + Some(MetadataValue::Str(space)) => space, + _ => "l2", + }, + None => "l2", + }; + match DistanceFunction::try_from(space) { + Ok(distance_function) => Ok(distance_function), + Err(e) => Err(Box::new(e)), + } +} From b9528d8f549f636f8b8341ee18f8e07883cc3667 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 6 Nov 2024 11:00:06 -0800 Subject: [PATCH 02/10] Add from_segment impl --- rust/blockstore/src/memory/storage.rs | 12 ++- rust/index/src/spann/types.rs | 41 +++++++ rust/types/src/segment.rs | 14 +-- rust/worker/src/segment/spann_segment.rs | 129 +++++++++++++++++++---- 4 files changed, 163 insertions(+), 33 deletions(-) diff --git a/rust/blockstore/src/memory/storage.rs b/rust/blockstore/src/memory/storage.rs index b2bb2818946..29e7fb251fd 100644 --- a/rust/blockstore/src/memory/storage.rs +++ b/rust/blockstore/src/memory/storage.rs @@ -1,6 +1,6 @@ use crate::key::{CompositeKey, KeyWrapper}; use chroma_error::ChromaError; -use chroma_types::DataRecord; +use chroma_types::{DataRecord, SpannPostingList}; use parking_lot::RwLock; use roaring::RoaringBitmap; use std::{ @@ -585,6 +585,16 @@ impl Writeable for &DataRecord<'_> { } } +impl Writeable for &SpannPostingList<'_> { + fn write_to_storage(prefix: &str, key: KeyWrapper, value: Self, storage: &StorageBuilder) { + todo!() + } + + fn remove_from_storage(prefix: &str, key: KeyWrapper, storage: &StorageBuilder) { + todo!() + } +} + impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> { fn read_from_storage( prefix: &str, diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index dac81a7863b..ae7dbdf874d 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; +use arrow::error; use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter}; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; +use chroma_types::SpannPostingList; use thiserror::Error; use uuid::Uuid; @@ -26,6 +28,8 @@ pub enum SpannIndexWriterConstructionError { HnswIndexConstructionError, #[error("Blockfile reader construction error")] BlockfileReaderConstructionError, + #[error("Blockfile writer construction error")] + BlockfileWriterConstructionError, } impl ChromaError for SpannIndexWriterConstructionError { @@ -33,11 +37,26 @@ impl ChromaError for SpannIndexWriterConstructionError { match self { Self::HnswIndexConstructionError => ErrorCodes::Internal, Self::BlockfileReaderConstructionError => ErrorCodes::Internal, + Self::BlockfileWriterConstructionError => ErrorCodes::Internal, } } } impl SpannIndexWriter { + pub fn new( + hnsw_index: HnswIndexRef, + hnsw_provider: HnswIndexProvider, + posting_list_writer: BlockfileWriter, + versions_map: HashMap, + ) -> Self { + SpannIndexWriter { + hnsw_index, + hnsw_provider, + posting_list_writer, + versions_map, + } + } + pub async fn hnsw_index_from_id( hnsw_provider: &HnswIndexProvider, id: &Uuid, @@ -96,4 +115,26 @@ impl SpannIndexWriter { }); Ok(versions_map) } + + pub async fn fork_postings_list( + blockfile_id: &Uuid, + blockfile_provider: &BlockfileProvider, + ) -> Result { + match blockfile_provider + .fork::>(blockfile_id) + .await + { + Ok(writer) => Ok(writer), + Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + } + } + + pub async fn create_posting_list( + blockfile_provider: &BlockfileProvider, + ) -> Result { + match blockfile_provider.create::>() { + Ok(writer) => Ok(writer), + Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), + } + } } diff --git a/rust/types/src/segment.rs b/rust/types/src/segment.rs index e1465c56783..29fba56af84 100644 --- a/rust/types/src/segment.rs +++ b/rust/types/src/segment.rs @@ -41,6 +41,7 @@ pub enum SegmentType { BlockfileMetadata, BlockfileRecord, Sqlite, + Spann, } impl From for String { @@ -52,6 +53,7 @@ impl From for String { SegmentType::BlockfileRecord => "urn:chroma:segment/record/blockfile".to_string(), SegmentType::Sqlite => "urn:chroma:segment/metadata/sqlite".to_string(), SegmentType::BlockfileMetadata => "urn:chroma:segment/metadata/blockfile".to_string(), + SegmentType::Spann => "urn:chroma:segment/vector/spann".to_string(), } } } @@ -65,6 +67,7 @@ impl TryFrom<&str> for SegmentType { "urn:chroma:segment/record/blockfile" => Ok(SegmentType::BlockfileRecord), "urn:chroma:segment/metadata/sqlite" => Ok(SegmentType::Sqlite), "urn:chroma:segment/metadata/blockfile" => Ok(SegmentType::BlockfileMetadata), + "urn:chroma:segment/vector/spann" => Ok(SegmentType::Spann), _ => Err(SegmentConversionError::InvalidSegmentType), } } @@ -130,16 +133,7 @@ impl TryFrom for Segment { Err(e) => return Err(SegmentConversionError::SegmentScopeConversionError(e)), }; - let segment_type = match proto_segment.r#type.as_str() { - "urn:chroma:segment/vector/hnsw-distributed" => SegmentType::HnswDistributed, - "urn:chroma:segment/record/blockfile" => SegmentType::BlockfileRecord, - "urn:chroma:segment/metadata/sqlite" => SegmentType::Sqlite, - "urn:chroma:segment/metadata/blockfile" => SegmentType::BlockfileMetadata, - _ => { - println!("Invalid segment type: {}", proto_segment.r#type); - return Err(SegmentConversionError::InvalidSegmentType); - } - }; + let segment_type: SegmentType = proto_segment.r#type.as_str().try_into()?; let mut file_paths = HashMap::new(); let drain = proto_segment.file_paths.drain(); diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 8fe18fb247b..1448d7af698 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,12 +1,19 @@ +use std::collections::HashMap; + +use arrow::error; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; -use chroma_types::{Segment, SegmentType}; +use chroma_types::{Segment, SegmentScope, SegmentType}; use thiserror::Error; use uuid::Uuid; use super::utils::{distance_function_from_segment, hnsw_params_from_segment}; +const HNSW_PATH: &str = "hnsw_path"; +const VERSION_MAP_PATH: &str = "version_map_path"; +const POSTING_LIST_PATH: &str = "posting_list_path"; + pub(crate) struct SpannSegmentWriter { index: SpannIndexWriter, id: Uuid, @@ -20,8 +27,20 @@ pub enum SpannSegmentWriterError { DistanceFunctionNotFound, #[error("Hnsw index id parsing error")] IndexIdParsingError, - #[error("HNSW index construction error")] - HnswIndexConstructionError, + #[error("HNSW index creation error")] + HnswIndexCreationError, + #[error("Hnsw Invalid file path")] + HnswInvalidFilePath, + #[error("Version map Invalid file path")] + VersionMapInvalidFilePath, + #[error("Failure in loading the versions map")] + VersionMapLoadError, + #[error("Failure in forking the posting list")] + PostingListForkError, + #[error("Postings list invalid file path")] + PostingListInvalidFilePath, + #[error("Posting list creation error")] + PostingListCreationError, } impl ChromaError for SpannSegmentWriterError { @@ -29,8 +48,14 @@ impl ChromaError for SpannSegmentWriterError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::HnswIndexConstructionError => ErrorCodes::Internal, + Self::HnswIndexCreationError => ErrorCodes::Internal, Self::DistanceFunctionNotFound => ErrorCodes::Internal, + Self::HnswInvalidFilePath => ErrorCodes::Internal, + Self::VersionMapInvalidFilePath => ErrorCodes::Internal, + Self::VersionMapLoadError => ErrorCodes::Internal, + Self::PostingListForkError => ErrorCodes::Internal, + Self::PostingListInvalidFilePath => ErrorCodes::Internal, + Self::PostingListCreationError => ErrorCodes::Internal, } } } @@ -42,11 +67,11 @@ impl SpannSegmentWriter { hnsw_provider: HnswIndexProvider, dimensionality: usize, ) -> Result { - // TODO(Sanket): Introduce another segment type and propagate here. - if segment.r#type != SegmentType::HnswDistributed { + if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR { return Err(SpannSegmentWriterError::InvalidArgument); } - match segment.file_path.get("hnsw_path") { + // Load HNSW index. + let hnsw_index = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { let index_uuid = match Uuid::parse_str(index_id) { @@ -61,7 +86,7 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::DistanceFunctionNotFound); } }; - let hnsw_index = match SpannIndexWriter::hnsw_index_from_id( + match SpannIndexWriter::hnsw_index_from_id( &hnsw_provider, &index_uuid, &segment.collection, @@ -72,19 +97,15 @@ impl SpannSegmentWriter { { Ok(index) => index, Err(_) => { - return Err(SpannSegmentWriterError::HnswIndexConstructionError); + return Err(SpannSegmentWriterError::HnswIndexCreationError); } - }; - - // TODO(Sanket): Remove this. - return Err(SpannSegmentWriterError::InvalidArgument); + } } - // TODO: Create index in this case also. None => { - return Err(SpannSegmentWriterError::InvalidArgument); + return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - // TODO(Sanket): Create index in this case. + // Create a new index. None => { let hnsw_params = hnsw_params_from_segment(segment); @@ -95,7 +116,7 @@ impl SpannSegmentWriter { } }; - let hnsw_index = match SpannIndexWriter::create_hnsw_index( + match SpannIndexWriter::create_hnsw_index( &hnsw_provider, &segment.collection, distance_function, @@ -106,13 +127,77 @@ impl SpannSegmentWriter { { Ok(index) => index, Err(_) => { - return Err(SpannSegmentWriterError::HnswIndexConstructionError); + return Err(SpannSegmentWriterError::HnswIndexCreationError); } - }; - - // First time creation of the segment. - return Err(SpannSegmentWriterError::InvalidArgument); + } + } + }; + // Load version map. Empty if file path is not set. + let mut version_map = HashMap::new(); + if let Some(version_map_path) = segment.file_path.get(VERSION_MAP_PATH) { + version_map = match version_map_path.first() { + Some(version_map_id) => { + let version_map_uuid = match Uuid::parse_str(version_map_id) { + Ok(uuid) => uuid, + Err(_) => { + return Err(SpannSegmentWriterError::IndexIdParsingError); + } + }; + match SpannIndexWriter::load_versions_map(&version_map_uuid, blockfile_provider) + .await + { + Ok(index) => index, + Err(_) => { + return Err(SpannSegmentWriterError::VersionMapLoadError); + } + } + } + None => { + return Err(SpannSegmentWriterError::VersionMapInvalidFilePath); + } } } + // Fork the posting list map. + let posting_list_writer = match segment.file_path.get(POSTING_LIST_PATH) { + Some(posting_list_path) => match posting_list_path.first() { + Some(posting_list_id) => { + let posting_list_uuid = match Uuid::parse_str(posting_list_id) { + Ok(uuid) => uuid, + Err(_) => { + return Err(SpannSegmentWriterError::IndexIdParsingError); + } + }; + match SpannIndexWriter::fork_postings_list( + &posting_list_uuid, + blockfile_provider, + ) + .await + { + Ok(writer) => writer, + Err(_) => { + return Err(SpannSegmentWriterError::PostingListForkError); + } + } + } + None => { + return Err(SpannSegmentWriterError::PostingListInvalidFilePath); + } + }, + // Create a new index. + None => match SpannIndexWriter::create_posting_list(blockfile_provider).await { + Ok(writer) => writer, + Err(_) => { + return Err(SpannSegmentWriterError::PostingListCreationError); + } + }, + }; + + let index_writer = + SpannIndexWriter::new(hnsw_index, hnsw_provider, posting_list_writer, version_map); + + Ok(SpannSegmentWriter { + index: index_writer, + id: segment.id, + }) } } From 68adb3b690062e1245a522b0fa0ce0e0ab6c6d3e Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 6 Nov 2024 13:44:43 -0800 Subject: [PATCH 03/10] Better abstraction --- rust/index/src/spann/types.rs | 67 +++++++++++- rust/worker/src/segment/spann_segment.rs | 134 +++++++---------------- 2 files changed, 102 insertions(+), 99 deletions(-) diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index ae7dbdf874d..c5b315c6a30 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -57,7 +57,7 @@ impl SpannIndexWriter { } } - pub async fn hnsw_index_from_id( + async fn hnsw_index_from_id( hnsw_provider: &HnswIndexProvider, id: &Uuid, collection_id: &Uuid, @@ -73,7 +73,7 @@ impl SpannIndexWriter { } } - pub async fn create_hnsw_index( + async fn create_hnsw_index( hnsw_provider: &HnswIndexProvider, collection_id: &Uuid, distance_function: DistanceFunction, @@ -96,7 +96,7 @@ impl SpannIndexWriter { } } - pub async fn load_versions_map( + async fn load_versions_map( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, ) -> Result, SpannIndexWriterConstructionError> { @@ -116,7 +116,7 @@ impl SpannIndexWriter { Ok(versions_map) } - pub async fn fork_postings_list( + async fn fork_postings_list( blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, ) -> Result { @@ -129,7 +129,7 @@ impl SpannIndexWriter { } } - pub async fn create_posting_list( + async fn create_posting_list( blockfile_provider: &BlockfileProvider, ) -> Result { match blockfile_provider.create::>() { @@ -137,4 +137,61 @@ impl SpannIndexWriter { Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), } } + + #[allow(clippy::too_many_arguments)] + pub async fn from_id( + hnsw_provider: &HnswIndexProvider, + hnsw_id: Option<&Uuid>, + versions_map_id: Option<&Uuid>, + posting_list_id: Option<&Uuid>, + hnsw_params: Option, + collection_id: &Uuid, + distance_function: DistanceFunction, + dimensionality: usize, + blockfile_provider: &BlockfileProvider, + ) -> Result { + // Create the HNSW index. + let hnsw_index = match hnsw_id { + Some(hnsw_id) => { + Self::hnsw_index_from_id( + hnsw_provider, + hnsw_id, + collection_id, + distance_function, + dimensionality, + ) + .await? + } + None => { + Self::create_hnsw_index( + hnsw_provider, + collection_id, + distance_function, + dimensionality, + hnsw_params.unwrap(), // Safe since caller should always provide this. + ) + .await? + } + }; + // Load the versions map. + let versions_map = match versions_map_id { + Some(versions_map_id) => { + Self::load_versions_map(versions_map_id, blockfile_provider).await? + } + None => HashMap::new(), + }; + // Fork the posting list writer. + let posting_list_writer = match posting_list_id { + Some(posting_list_id) => { + Self::fork_postings_list(posting_list_id, blockfile_provider).await? + } + None => Self::create_posting_list(blockfile_provider).await?, + }; + Ok(Self::new( + hnsw_index, + hnsw_provider.clone(), + posting_list_writer, + versions_map, + )) + } } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 1448d7af698..6e197d87171 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -27,20 +27,14 @@ pub enum SpannSegmentWriterError { DistanceFunctionNotFound, #[error("Hnsw index id parsing error")] IndexIdParsingError, - #[error("HNSW index creation error")] - HnswIndexCreationError, #[error("Hnsw Invalid file path")] HnswInvalidFilePath, #[error("Version map Invalid file path")] VersionMapInvalidFilePath, - #[error("Failure in loading the versions map")] - VersionMapLoadError, - #[error("Failure in forking the posting list")] - PostingListForkError, #[error("Postings list invalid file path")] PostingListInvalidFilePath, - #[error("Posting list creation error")] - PostingListCreationError, + #[error("Spann index creation error")] + SpannIndexWriterConstructionError, } impl ChromaError for SpannSegmentWriterError { @@ -48,14 +42,11 @@ impl ChromaError for SpannSegmentWriterError { match self { Self::InvalidArgument => ErrorCodes::InvalidArgument, Self::IndexIdParsingError => ErrorCodes::Internal, - Self::HnswIndexCreationError => ErrorCodes::Internal, Self::DistanceFunctionNotFound => ErrorCodes::Internal, Self::HnswInvalidFilePath => ErrorCodes::Internal, Self::VersionMapInvalidFilePath => ErrorCodes::Internal, - Self::VersionMapLoadError => ErrorCodes::Internal, - Self::PostingListForkError => ErrorCodes::Internal, Self::PostingListInvalidFilePath => ErrorCodes::Internal, - Self::PostingListCreationError => ErrorCodes::Internal, + Self::SpannIndexWriterConstructionError => ErrorCodes::Internal, } } } @@ -64,14 +55,20 @@ impl SpannSegmentWriter { pub async fn from_segment( segment: &Segment, blockfile_provider: &BlockfileProvider, - hnsw_provider: HnswIndexProvider, + hnsw_provider: &HnswIndexProvider, dimensionality: usize, ) -> Result { if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR { return Err(SpannSegmentWriterError::InvalidArgument); } // Load HNSW index. - let hnsw_index = match segment.file_path.get(HNSW_PATH) { + let distance_function = match distance_function_from_segment(segment) { + Ok(distance_function) => distance_function, + Err(e) => { + return Err(SpannSegmentWriterError::DistanceFunctionNotFound); + } + }; + let (hnsw_id, hnsw_params) = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { let index_uuid = match Uuid::parse_str(index_id) { @@ -80,62 +77,16 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - let distance_function = match distance_function_from_segment(segment) { - Ok(distance_function) => distance_function, - Err(e) => { - return Err(SpannSegmentWriterError::DistanceFunctionNotFound); - } - }; - match SpannIndexWriter::hnsw_index_from_id( - &hnsw_provider, - &index_uuid, - &segment.collection, - distance_function, - dimensionality, - ) - .await - { - Ok(index) => index, - Err(_) => { - return Err(SpannSegmentWriterError::HnswIndexCreationError); - } - } + (Some(index_uuid), Some(hnsw_params_from_segment(segment))) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - // Create a new index. - None => { - let hnsw_params = hnsw_params_from_segment(segment); - - let distance_function = match distance_function_from_segment(segment) { - Ok(distance_function) => distance_function, - Err(e) => { - return Err(SpannSegmentWriterError::DistanceFunctionNotFound); - } - }; - - match SpannIndexWriter::create_hnsw_index( - &hnsw_provider, - &segment.collection, - distance_function, - dimensionality, - hnsw_params, - ) - .await - { - Ok(index) => index, - Err(_) => { - return Err(SpannSegmentWriterError::HnswIndexCreationError); - } - } - } + None => (None, None), }; - // Load version map. Empty if file path is not set. - let mut version_map = HashMap::new(); - if let Some(version_map_path) = segment.file_path.get(VERSION_MAP_PATH) { - version_map = match version_map_path.first() { + let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) { + Some(version_map_path) => match version_map_path.first() { Some(version_map_id) => { let version_map_uuid = match Uuid::parse_str(version_map_id) { Ok(uuid) => uuid, @@ -143,22 +94,16 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - match SpannIndexWriter::load_versions_map(&version_map_uuid, blockfile_provider) - .await - { - Ok(index) => index, - Err(_) => { - return Err(SpannSegmentWriterError::VersionMapLoadError); - } - } + Some(version_map_uuid) } None => { return Err(SpannSegmentWriterError::VersionMapInvalidFilePath); } - } - } + }, + None => None, + }; // Fork the posting list map. - let posting_list_writer = match segment.file_path.get(POSTING_LIST_PATH) { + let posting_list_id = match segment.file_path.get(POSTING_LIST_PATH) { Some(posting_list_path) => match posting_list_path.first() { Some(posting_list_id) => { let posting_list_uuid = match Uuid::parse_str(posting_list_id) { @@ -167,33 +112,34 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - match SpannIndexWriter::fork_postings_list( - &posting_list_uuid, - blockfile_provider, - ) - .await - { - Ok(writer) => writer, - Err(_) => { - return Err(SpannSegmentWriterError::PostingListForkError); - } - } + Some(posting_list_uuid) } None => { return Err(SpannSegmentWriterError::PostingListInvalidFilePath); } }, // Create a new index. - None => match SpannIndexWriter::create_posting_list(blockfile_provider).await { - Ok(writer) => writer, - Err(_) => { - return Err(SpannSegmentWriterError::PostingListCreationError); - } - }, + None => None, }; - let index_writer = - SpannIndexWriter::new(hnsw_index, hnsw_provider, posting_list_writer, version_map); + let index_writer = match SpannIndexWriter::from_id( + hnsw_provider, + hnsw_id.as_ref(), + versions_map_id.as_ref(), + posting_list_id.as_ref(), + hnsw_params, + &segment.collection, + distance_function, + dimensionality, + blockfile_provider, + ) + .await + { + Ok(index_writer) => index_writer, + Err(_) => { + return Err(SpannSegmentWriterError::SpannIndexWriterConstructionError); + } + }; Ok(SpannSegmentWriter { index: index_writer, From 8f070a32c6bd89f2f31688def2376ab94ae4a52a Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 6 Nov 2024 14:01:02 -0800 Subject: [PATCH 04/10] Remove outdated comments --- rust/worker/src/segment/spann_segment.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 6e197d87171..2daaf8bd16c 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -61,7 +61,6 @@ impl SpannSegmentWriter { if segment.r#type != SegmentType::Spann || segment.scope != SegmentScope::VECTOR { return Err(SpannSegmentWriterError::InvalidArgument); } - // Load HNSW index. let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, Err(e) => { @@ -102,7 +101,6 @@ impl SpannSegmentWriter { }, None => None, }; - // Fork the posting list map. let posting_list_id = match segment.file_path.get(POSTING_LIST_PATH) { Some(posting_list_path) => match posting_list_path.first() { Some(posting_list_id) => { @@ -118,7 +116,6 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::PostingListInvalidFilePath); } }, - // Create a new index. None => None, }; From fea7e75f7ba648f859f772e56289f5114b933131 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sun, 10 Nov 2024 22:18:09 -0800 Subject: [PATCH 05/10] Fix build --- rust/index/src/hnsw_provider.rs | 9 +++-- rust/index/src/spann/types.rs | 33 ++++++++++++------- .../src/segment/distributed_hnsw_segment.rs | 2 +- rust/worker/src/segment/spann_segment.rs | 11 ++++--- 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/rust/index/src/hnsw_provider.rs b/rust/index/src/hnsw_provider.rs index 6fae49e2d29..d10ee77eaec 100644 --- a/rust/index/src/hnsw_provider.rs +++ b/rust/index/src/hnsw_provider.rs @@ -1,9 +1,8 @@ +use crate::{HnswIndexConfigError, PersistentIndex}; + use super::config::HnswProviderConfig; -use super::{ - HnswIndex, HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig, - IndexConfigFromSegmentError, IndexUuid, -}; -use crate::PersistentIndex; +use super::{HnswIndex, HnswIndexConfig, Index, IndexConfig, IndexUuid}; + use async_trait::async_trait; use chroma_cache::{Cache, Weighted}; use chroma_config::Configurable; diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index c5b315c6a30..f884ffb30cf 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1,14 +1,17 @@ use std::collections::HashMap; use arrow::error; -use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter}; +use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions}; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::SpannPostingList; +use chroma_types::{CollectionUuid, SpannPostingList}; use thiserror::Error; use uuid::Uuid; -use crate::hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef}; +use crate::{ + hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef}, + IndexUuid, +}; // TODO(Sanket): Add locking structures as necessary. pub struct SpannIndexWriter { @@ -59,8 +62,8 @@ impl SpannIndexWriter { async fn hnsw_index_from_id( hnsw_provider: &HnswIndexProvider, - id: &Uuid, - collection_id: &Uuid, + id: &IndexUuid, + collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, ) -> Result { @@ -75,7 +78,7 @@ impl SpannIndexWriter { async fn create_hnsw_index( hnsw_provider: &HnswIndexProvider, - collection_id: &Uuid, + collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, hnsw_params: HnswIndexParams, @@ -102,7 +105,7 @@ impl SpannIndexWriter { ) -> Result, SpannIndexWriterConstructionError> { // Create a reader for the blockfile. Load all the data into the versions map. let mut versions_map = HashMap::new(); - let reader = match blockfile_provider.open::(blockfile_id).await { + let reader = match blockfile_provider.read::(blockfile_id).await { Ok(reader) => reader, Err(_) => { return Err(SpannIndexWriterConstructionError::BlockfileReaderConstructionError) @@ -120,8 +123,11 @@ impl SpannIndexWriter { blockfile_id: &Uuid, blockfile_provider: &BlockfileProvider, ) -> Result { + let mut bf_options = BlockfileWriterOptions::new(); + bf_options = bf_options.unordered_mutations(); + bf_options = bf_options.fork(blockfile_id.clone()); match blockfile_provider - .fork::>(blockfile_id) + .write::>(bf_options) .await { Ok(writer) => Ok(writer), @@ -132,7 +138,12 @@ impl SpannIndexWriter { async fn create_posting_list( blockfile_provider: &BlockfileProvider, ) -> Result { - match blockfile_provider.create::>() { + let mut bf_options = BlockfileWriterOptions::new(); + bf_options = bf_options.unordered_mutations(); + match blockfile_provider + .write::>(bf_options) + .await + { Ok(writer) => Ok(writer), Err(_) => Err(SpannIndexWriterConstructionError::BlockfileWriterConstructionError), } @@ -141,11 +152,11 @@ impl SpannIndexWriter { #[allow(clippy::too_many_arguments)] pub async fn from_id( hnsw_provider: &HnswIndexProvider, - hnsw_id: Option<&Uuid>, + hnsw_id: Option<&IndexUuid>, versions_map_id: Option<&Uuid>, posting_list_id: Option<&Uuid>, hnsw_params: Option, - collection_id: &Uuid, + collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, blockfile_provider: &BlockfileProvider, diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 936e87d7cb2..77ef086d69b 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -10,7 +10,7 @@ use chroma_index::hnsw_provider::{ HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexProviderOpenError, HnswIndexRef, }; -use chroma_index::{Index, IndexUuid}; +use chroma_index::{HnswIndexConfig, Index, IndexConfig, IndexUuid}; use chroma_index::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; use chroma_types::SegmentUuid; use chroma_types::{get_metadata_value_as, MaterializedLogOperation, MetadataValue, Segment}; diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 2daaf8bd16c..414dcfbf0b2 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use arrow::error; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; -use chroma_types::{Segment, SegmentScope, SegmentType}; +use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter, IndexUuid}; +use chroma_types::{Segment, SegmentScope, SegmentType, SegmentUuid}; use thiserror::Error; use uuid::Uuid; @@ -16,7 +16,7 @@ const POSTING_LIST_PATH: &str = "posting_list_path"; pub(crate) struct SpannSegmentWriter { index: SpannIndexWriter, - id: Uuid, + id: SegmentUuid, } #[derive(Error, Debug)] @@ -76,7 +76,10 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; - (Some(index_uuid), Some(hnsw_params_from_segment(segment))) + ( + Some(IndexUuid(index_uuid)), + Some(hnsw_params_from_segment(segment)), + ) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); From 9b2d34d921ffa37e476c907540bdd0f30478605f Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Sun, 10 Nov 2024 22:56:06 -0800 Subject: [PATCH 06/10] Linter --- rust/blockstore/src/arrow/blockfile.rs | 2 +- rust/blockstore/src/memory/storage.rs | 4 ++-- rust/blockstore/src/types/reader.rs | 2 +- rust/index/src/hnsw.rs | 1 - rust/index/src/lib.rs | 1 + rust/index/src/spann/types.rs | 4 ++-- rust/index/src/types.rs | 5 ++--- rust/worker/src/segment/distributed_hnsw_segment.rs | 7 +++---- rust/worker/src/segment/spann_segment.rs | 10 ++++++---- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 2e7c7b4900f..b7d0c76c821 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -699,7 +699,7 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me result.extend(block.get_all_data()); } - return result; + result } } diff --git a/rust/blockstore/src/memory/storage.rs b/rust/blockstore/src/memory/storage.rs index 29e7fb251fd..829740e2dea 100644 --- a/rust/blockstore/src/memory/storage.rs +++ b/rust/blockstore/src/memory/storage.rs @@ -586,11 +586,11 @@ impl Writeable for &DataRecord<'_> { } impl Writeable for &SpannPostingList<'_> { - fn write_to_storage(prefix: &str, key: KeyWrapper, value: Self, storage: &StorageBuilder) { + fn write_to_storage(_: &str, _: KeyWrapper, _: Self, _: &StorageBuilder) { todo!() } - fn remove_from_storage(prefix: &str, key: KeyWrapper, storage: &StorageBuilder) { + fn remove_from_storage(_: &str, _: KeyWrapper, _: &StorageBuilder) { todo!() } } diff --git a/rust/blockstore/src/types/reader.rs b/rust/blockstore/src/types/reader.rs index 926bcad4b9d..7b8e70c2bf0 100644 --- a/rust/blockstore/src/types/reader.rs +++ b/rust/blockstore/src/types/reader.rs @@ -134,7 +134,7 @@ impl< pub async fn get_all_data(&'referred_data self) -> Vec<(&'referred_data str, K, V)> { match self { - BlockfileReader::MemoryBlockfileReader(reader) => todo!(), + BlockfileReader::MemoryBlockfileReader(_) => todo!(), BlockfileReader::ArrowBlockfileReader(reader) => reader.get_all_data().await, } } diff --git a/rust/index/src/hnsw.rs b/rust/index/src/hnsw.rs index f95c19732ae..e67433e4eb1 100644 --- a/rust/index/src/hnsw.rs +++ b/rust/index/src/hnsw.rs @@ -1,6 +1,5 @@ use super::{Index, IndexConfig, IndexUuid, PersistentIndex}; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::MetadataValueConversionError; use std::ffi::CString; use std::ffi::{c_char, c_int}; use std::path::Path; diff --git a/rust/index/src/lib.rs b/rust/index/src/lib.rs index ac9d9c585cd..5c9b21e95a0 100644 --- a/rust/index/src/lib.rs +++ b/rust/index/src/lib.rs @@ -13,6 +13,7 @@ use chroma_cache::new_non_persistent_cache_for_test; use chroma_storage::test_storage; pub use hnsw::*; use hnsw_provider::HnswIndexProvider; +#[allow(unused_imports)] pub use spann::*; use tempfile::tempdir; pub use types::*; diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index f884ffb30cf..c8ab2f76d3f 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use arrow::error; use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions}; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; @@ -14,6 +13,7 @@ use crate::{ }; // TODO(Sanket): Add locking structures as necessary. +#[allow(dead_code)] pub struct SpannIndexWriter { // HNSW index and its provider for centroid search. hnsw_index: HnswIndexRef, @@ -125,7 +125,7 @@ impl SpannIndexWriter { ) -> Result { let mut bf_options = BlockfileWriterOptions::new(); bf_options = bf_options.unordered_mutations(); - bf_options = bf_options.fork(blockfile_id.clone()); + bf_options = bf_options.fork(*blockfile_id); match blockfile_provider .write::>(bf_options) .await diff --git a/rust/index/src/types.rs b/rust/index/src/types.rs index e906ef84376..bcccaed2744 100644 --- a/rust/index/src/types.rs +++ b/rust/index/src/types.rs @@ -1,6 +1,5 @@ -use chroma_distance::{DistanceFunction, DistanceFunctionError}; -use chroma_error::{ChromaError, ErrorCodes}; -use thiserror::Error; +use chroma_distance::DistanceFunction; +use chroma_error::ChromaError; use uuid::Uuid; #[derive(Clone, Debug)] diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 77ef086d69b..e22807897f9 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -4,16 +4,15 @@ use super::record_segment::ApplyMaterializedLogError; use super::utils::hnsw_params_from_segment; use super::{SegmentFlusher, SegmentWriter}; use async_trait::async_trait; -use chroma_distance::{DistanceFunction, DistanceFunctionError}; +use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::{ HnswIndexProvider, HnswIndexProviderCreateError, HnswIndexProviderForkError, HnswIndexProviderOpenError, HnswIndexRef, }; -use chroma_index::{HnswIndexConfig, Index, IndexConfig, IndexUuid}; -use chroma_index::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; +use chroma_index::{Index, IndexUuid}; use chroma_types::SegmentUuid; -use chroma_types::{get_metadata_value_as, MaterializedLogOperation, MetadataValue, Segment}; +use chroma_types::{MaterializedLogOperation, Segment}; use std::collections::HashMap; use std::fmt::Debug; use thiserror::Error; diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 414dcfbf0b2..54c42385de5 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,6 +1,3 @@ -use std::collections::HashMap; - -use arrow::error; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter, IndexUuid}; @@ -10,10 +7,14 @@ use uuid::Uuid; use super::utils::{distance_function_from_segment, hnsw_params_from_segment}; +#[allow(dead_code)] const HNSW_PATH: &str = "hnsw_path"; +#[allow(dead_code)] const VERSION_MAP_PATH: &str = "version_map_path"; +#[allow(dead_code)] const POSTING_LIST_PATH: &str = "posting_list_path"; +#[allow(dead_code)] pub(crate) struct SpannSegmentWriter { index: SpannIndexWriter, id: SegmentUuid, @@ -52,6 +53,7 @@ impl ChromaError for SpannSegmentWriterError { } impl SpannSegmentWriter { + #[allow(dead_code)] pub async fn from_segment( segment: &Segment, blockfile_provider: &BlockfileProvider, @@ -63,7 +65,7 @@ impl SpannSegmentWriter { } let distance_function = match distance_function_from_segment(segment) { Ok(distance_function) => distance_function, - Err(e) => { + Err(_) => { return Err(SpannSegmentWriterError::DistanceFunctionNotFound); } }; From 753f218f17f5109e6f913276c882ae21ab54fd88 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Mon, 11 Nov 2024 13:41:10 -0800 Subject: [PATCH 07/10] Fix build --- rust/index/src/spann/types.rs | 20 ++++++++++++------- rust/worker/src/segment/spann_segment.rs | 13 ++++++++---- rust/worker/src/segment/utils.rs | 25 +++++++++++++----------- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index c8ab2f76d3f..ed5a529e274 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -8,7 +8,7 @@ use thiserror::Error; use uuid::Uuid; use crate::{ - hnsw_provider::{HnswIndexParams, HnswIndexProvider, HnswIndexRef}, + hnsw_provider::{HnswIndexProvider, HnswIndexRef}, IndexUuid, }; @@ -81,14 +81,16 @@ impl SpannIndexWriter { collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, - hnsw_params: HnswIndexParams, + m: usize, + ef_construction: usize, + ef_search: usize, ) -> Result { - let persist_path = &hnsw_provider.temporary_storage_path; match hnsw_provider .create( collection_id, - hnsw_params, - persist_path, + m, + ef_construction, + ef_search, dimensionality as i32, distance_function, ) @@ -155,7 +157,9 @@ impl SpannIndexWriter { hnsw_id: Option<&IndexUuid>, versions_map_id: Option<&Uuid>, posting_list_id: Option<&Uuid>, - hnsw_params: Option, + m: Option, + ef_construction: Option, + ef_search: Option, collection_id: &CollectionUuid, distance_function: DistanceFunction, dimensionality: usize, @@ -179,7 +183,9 @@ impl SpannIndexWriter { collection_id, distance_function, dimensionality, - hnsw_params.unwrap(), // Safe since caller should always provide this. + m.unwrap(), // Safe since caller should always provide this. + ef_construction.unwrap(), // Safe since caller should always provide this. + ef_search.unwrap(), // Safe since caller should always provide this. ) .await? } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index 54c42385de5..f15bee2953e 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -69,7 +69,7 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::DistanceFunctionNotFound); } }; - let (hnsw_id, hnsw_params) = match segment.file_path.get(HNSW_PATH) { + let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) { Some(hnsw_path) => match hnsw_path.first() { Some(index_id) => { let index_uuid = match Uuid::parse_str(index_id) { @@ -78,16 +78,19 @@ impl SpannSegmentWriter { return Err(SpannSegmentWriterError::IndexIdParsingError); } }; + let hnsw_params = hnsw_params_from_segment(segment); ( Some(IndexUuid(index_uuid)), - Some(hnsw_params_from_segment(segment)), + Some(hnsw_params.m), + Some(hnsw_params.ef_construction), + Some(hnsw_params.ef_search), ) } None => { return Err(SpannSegmentWriterError::HnswInvalidFilePath); } }, - None => (None, None), + None => (None, None, None, None), }; let versions_map_id = match segment.file_path.get(VERSION_MAP_PATH) { Some(version_map_path) => match version_map_path.first() { @@ -129,7 +132,9 @@ impl SpannSegmentWriter { hnsw_id.as_ref(), versions_map_id.as_ref(), posting_list_id.as_ref(), - hnsw_params, + m, + ef_construction, + ef_search, &segment.collection, distance_function, dimensionality, diff --git a/rust/worker/src/segment/utils.rs b/rust/worker/src/segment/utils.rs index 08b76b17ce6..b5305eecdfb 100644 --- a/rust/worker/src/segment/utils.rs +++ b/rust/worker/src/segment/utils.rs @@ -1,19 +1,18 @@ use chroma_distance::{DistanceFunction, DistanceFunctionError}; -use chroma_index::{ - hnsw_provider::HnswIndexParams, DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, - DEFAULT_HNSW_M, -}; +use chroma_index::{DEFAULT_HNSW_EF_CONSTRUCTION, DEFAULT_HNSW_EF_SEARCH, DEFAULT_HNSW_M}; use chroma_types::{get_metadata_value_as, MetadataValue, Segment}; -pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { +use super::distributed_hnsw_segment::HnswIndexParamsFromSegment; + +pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParamsFromSegment { let metadata = match &segment.metadata { Some(metadata) => metadata, None => { - return ( - DEFAULT_HNSW_M, - DEFAULT_HNSW_EF_CONSTRUCTION, - DEFAULT_HNSW_EF_SEARCH, - ); + return HnswIndexParamsFromSegment { + m: DEFAULT_HNSW_M, + ef_construction: DEFAULT_HNSW_EF_CONSTRUCTION, + ef_search: DEFAULT_HNSW_EF_SEARCH, + }; } }; @@ -30,7 +29,11 @@ pub(super) fn hnsw_params_from_segment(segment: &Segment) -> HnswIndexParams { Err(_) => DEFAULT_HNSW_EF_SEARCH, }; - (m, ef_construction, ef_search) + HnswIndexParamsFromSegment { + m, + ef_construction, + ef_search, + } } pub(crate) fn distance_function_from_segment( From 01839c1e702aa467e606f83abebb49cceb7af146 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 13 Nov 2024 23:04:16 -0800 Subject: [PATCH 08/10] review comments --- rust/blockstore/src/arrow/block/types.rs | 20 -------------------- rust/blockstore/src/arrow/blockfile.rs | 19 ------------------- rust/blockstore/src/arrow/sparse_index.rs | 4 ---- rust/blockstore/src/types/reader.rs | 7 ------- rust/index/src/spann/types.rs | 10 ++++++++-- 5 files changed, 8 insertions(+), 52 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 34a2f1ec453..6764591c86f 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -346,26 +346,6 @@ impl Block { } } - pub fn get_all_data<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - ) -> Vec<(&'me str, K, V)> { - let prefix_arr = self - .data - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - let mut result = Vec::new(); - for i in 0..self.data.num_rows() { - result.push(( - prefix_arr.value(i), - K::get(self.data.column(1), i), - V::get(self.data.column(2), i), - )); - } - result - } - /// Get all the values for a given prefix & key range in the block /// ### Panics /// - If the underlying data types are not the same as the types specified in the function signature diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index b7d0c76c821..44e2db95587 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -682,25 +682,6 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me true } - - pub async fn get_all_data(&'me self) -> Vec<(&'me str, K, V)> { - let block_ids = self.root.sparse_index.get_all_block_ids(); - let mut result = vec![]; - for block_id in block_ids { - let block = match self.get_block(block_id).await { - Ok(Some(block)) => block, - Ok(None) => { - continue; - } - Err(_) => { - continue; - } - }; - - result.extend(block.get_all_data()); - } - result - } } #[cfg(test)] diff --git a/rust/blockstore/src/arrow/sparse_index.rs b/rust/blockstore/src/arrow/sparse_index.rs index 3a7cac2c9ed..3d4124ab447 100644 --- a/rust/blockstore/src/arrow/sparse_index.rs +++ b/rust/blockstore/src/arrow/sparse_index.rs @@ -321,10 +321,6 @@ impl SparseIndexReader { get_target_block(search_key, forward).id } - pub(super) fn get_all_block_ids(&self) -> Vec { - self.data.forward.values().map(|v| v.id).collect() - } - /// Get all the block ids that contain keys in the given input search keys pub(super) fn get_all_target_block_ids(&self, mut search_keys: Vec) -> Vec { // Sort so that we can search in one iteration. diff --git a/rust/blockstore/src/types/reader.rs b/rust/blockstore/src/types/reader.rs index 7b8e70c2bf0..43311d979d4 100644 --- a/rust/blockstore/src/types/reader.rs +++ b/rust/blockstore/src/types/reader.rs @@ -131,11 +131,4 @@ impl< } } } - - pub async fn get_all_data(&'referred_data self) -> Vec<(&'referred_data str, K, V)> { - match self { - BlockfileReader::MemoryBlockfileReader(_) => todo!(), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_all_data().await, - } - } } diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index ed5a529e274..c78c64c36c0 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -33,6 +33,8 @@ pub enum SpannIndexWriterConstructionError { BlockfileReaderConstructionError, #[error("Blockfile writer construction error")] BlockfileWriterConstructionError, + #[error("Error loading version data from blockfile")] + BlockfileVersionDataLoadError, } impl ChromaError for SpannIndexWriterConstructionError { @@ -41,6 +43,7 @@ impl ChromaError for SpannIndexWriterConstructionError { Self::HnswIndexConstructionError => ErrorCodes::Internal, Self::BlockfileReaderConstructionError => ErrorCodes::Internal, Self::BlockfileWriterConstructionError => ErrorCodes::Internal, + Self::BlockfileVersionDataLoadError => ErrorCodes::Internal, } } } @@ -114,8 +117,11 @@ impl SpannIndexWriter { } }; // Load data using the reader. - let versions_data = reader.get_all_data().await; - versions_data.iter().for_each(|(_, key, value)| { + let versions_data = reader + .get_range(.., ..) + .await + .map_err(|_| SpannIndexWriterConstructionError::BlockfileVersionDataLoadError)?; + versions_data.iter().for_each(|(key, value)| { versions_map.insert(*key, *value); }); Ok(versions_map) From 1c000823a84d02403c01ca94a05305d5417508a9 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 27 Nov 2024 16:38:05 -0800 Subject: [PATCH 09/10] Build errors --- rust/worker/src/execution/orchestration/knn.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/rust/worker/src/execution/orchestration/knn.rs b/rust/worker/src/execution/orchestration/knn.rs index e34da719cb9..a9e843be613 100644 --- a/rust/worker/src/execution/orchestration/knn.rs +++ b/rust/worker/src/execution/orchestration/knn.rs @@ -31,9 +31,9 @@ use crate::{ orchestration::common::terminate_with_error, }, segment::distributed_hnsw_segment::{ - distance_function_from_segment, DistributedHNSWSegmentFromSegmentError, - DistributedHNSWSegmentReader, + DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader, }, + segment::utils::distance_function_from_segment, system::{ChannelError, Component, ComponentContext, ComponentHandle, Handler, System}, }; @@ -186,6 +186,8 @@ pub enum KnnError { Panic(String), #[error("Error receiving final result: {0}")] Result(#[from] RecvError), + #[error("Invalid distance function")] + InvalidDistanceFunction, } impl ChromaError for KnnError { @@ -203,6 +205,7 @@ impl ChromaError for KnnError { KnnError::NoCollectionDimension => ErrorCodes::InvalidArgument, KnnError::Panic(_) => ErrorCodes::Aborted, KnnError::Result(_) => ErrorCodes::Internal, + KnnError::InvalidDistanceFunction => ErrorCodes::InvalidArgument, } } } @@ -421,8 +424,8 @@ impl Handler> for KnnFilterOrchestrator { }; let distance_function = match distance_function_from_segment(&segments.vector_segment) { Ok(distance_function) => distance_function, - Err(err) => { - self.terminate_with_error(ctx, *err); + Err(_) => { + self.terminate_with_error(ctx, KnnError::InvalidDistanceFunction); return; } }; From ceb3e43c4f35984c06e35162898152c38e5ac763 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Wed, 4 Dec 2024 10:09:19 -0800 Subject: [PATCH 10/10] Fix linter errors --- rust/load/src/config.rs | 2 +- rust/worker/src/segment/metadata_segment.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/load/src/config.rs b/rust/load/src/config.rs index 575fe09a783..8a62558bcf5 100644 --- a/rust/load/src/config.rs +++ b/rust/load/src/config.rs @@ -19,7 +19,7 @@ impl RootConfig { // Unfortunately, figment doesn't support environment variables with underscores. So we have to map and replace them. // Excluding our own environment variables, which are prefixed with CHROMA_. let mut f = figment::Figment::from( - Env::prefixed("CHROMA_").map(|k| return k.as_str().replace("__", ".").into()), + Env::prefixed("CHROMA_").map(|k| k.as_str().replace("__", ".").into()), ); if std::path::Path::new(path).exists() { f = figment::Figment::from(Yaml::file(path)).merge(f); diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs index c5410479cf9..e2f144c85bc 100644 --- a/rust/worker/src/segment/metadata_segment.rs +++ b/rust/worker/src/segment/metadata_segment.rs @@ -1050,7 +1050,7 @@ impl MetadataSegmentReader<'_> { fn process_where_clause<'me>( &'me self, where_clause: &'me Where, - ) -> BoxFuture, MetadataIndexError>> { + ) -> BoxFuture<'me, Result, MetadataIndexError>> { async move { let provider = MetadataProvider::from_metadata_segment_reader(self); let result = where_clause