From 22dea404ad7dbf48dc6a9e4b4483330bbf15f6e3 Mon Sep 17 00:00:00 2001
From: Sanket Kedia
Date: Wed, 6 Nov 2024 13:13:39 -0800
Subject: [PATCH] implement apply_materialized

---
 rust/index/src/spann/types.rs            | 18 +++++++---
 rust/worker/src/segment/spann_segment.rs | 43 ++++++++++++++++++++++--
 2 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs
index f884ffb30cf5..ce1f3165f5e8 100644
--- a/rust/index/src/spann/types.rs
+++ b/rust/index/src/spann/types.rs
@@ -1,10 +1,12 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use arrow::error;
 use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions};
 use chroma_distance::DistanceFunction;
 use chroma_error::{ChromaError, ErrorCodes};
-use chroma_types::{CollectionUuid, SpannPostingList};
+use chroma_types::CollectionUuid;
+use chroma_types::SpannPostingList;
+use parking_lot::RwLock;
 use thiserror::Error;
 use uuid::Uuid;
 
@@ -22,7 +24,8 @@ pub struct SpannIndexWriter {
     // The blockfile also contains next id for the head.
     posting_list_writer: BlockfileWriter,
     // Version number of each point.
-    versions_map: HashMap<u32, u32>,
+    // TODO(Sanket): Finer grained locking for this map in future.
+    versions_map: Arc<RwLock<HashMap<u32, u32>>>,
 }
 
 #[derive(Error, Debug)]
@@ -56,7 +59,7 @@ impl SpannIndexWriter {
             hnsw_index,
             hnsw_provider,
             posting_list_writer,
-            versions_map,
+            versions_map: Arc::new(RwLock::new(versions_map)),
         }
     }
 
@@ -205,4 +208,11 @@ impl SpannIndexWriter {
             versions_map,
         ))
     }
+
+    /// Record a freshly added point in the versions map.
+    /// 0 means deleted. Version counting starts from 1.
+    pub fn add_new_record_to_versions_map(&self, id: u32) {
+        self.versions_map.write().insert(id, 1);
+    }
+
+    /// Insert the new point into the posting lists. Not implemented yet.
+    pub async fn add_new_record_to_postings_list(&self, id: u32, embeddings: &[f32]) {}
 }
diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs
index 414dcfbf0b23..e018e3e85461 100644
--- a/rust/worker/src/segment/spann_segment.rs
+++ b/rust/worker/src/segment/spann_segment.rs
@@ -3,12 +3,18 @@ use std::collections::HashMap;
 use arrow::error;
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_error::{ChromaError, ErrorCodes};
-use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter, IndexUuid};
-use chroma_types::{Segment, SegmentScope, SegmentType, SegmentUuid};
+use chroma_index::IndexUuid;
+use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
+use chroma_types::SegmentUuid;
+use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType};
 use thiserror::Error;
 use uuid::Uuid;
 
-use super::utils::{distance_function_from_segment, hnsw_params_from_segment};
+use super::{
+    record_segment::ApplyMaterializedLogError,
+    utils::{distance_function_from_segment, hnsw_params_from_segment},
+    MaterializedLogRecord, SegmentFlusher, SegmentWriter,
+};
 
 const HNSW_PATH: &str = "hnsw_path";
 const VERSION_MAP_PATH: &str = "version_map_path";
@@ -146,4 +152,35 @@ impl SpannSegmentWriter {
             id: segment.id,
         })
     }
+
+    /// Apply a single materialized AddNew record to the index.
+    async fn add(&self, record: &MaterializedLogRecord<'_>) {
+        // Initialize the record with a version.
+        self.index.add_new_record_to_versions_map(record.offset_id);
+        // NOTE: the posting-list insert is async; it must be awaited or the
+        // lazily-evaluated future would never run.
+        self.index
+            .add_new_record_to_postings_list(record.offset_id, record.merged_embeddings())
+            .await;
+    }
+}
+
+impl<'a> SegmentWriter<'a> for SpannSegmentWriter {
+    async fn apply_materialized_log_chunk(
+        &self,
+        records: chroma_types::Chunk<MaterializedLogRecord<'a>>,
+    ) -> Result<(), ApplyMaterializedLogError> {
+        for (record, _) in records.iter() {
+            match record.final_operation {
+                MaterializedLogOperation::AddNew => {
+                    self.add(record).await;
+                }
+                // TODO(Sanket): Implement other operations.
+                _ => {
+                    todo!()
+                }
+            }
+        }
+        Ok(())
+    }
+
+    async fn commit(self) -> Result<impl SegmentFlusher, Box<dyn ChromaError>> {
+        todo!()
+    }
 }