Skip to content

Commit

Permalink
implement apply_materialized
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 11, 2024
1 parent 8e8f005 commit 22dea40
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
18 changes: 14 additions & 4 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

use arrow::error;
use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriter, BlockfileWriterOptions};
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::{CollectionUuid, SpannPostingList};
use chroma_types::CollectionUuid;
use chroma_types::SpannPostingList;
use parking_lot::RwLock;
use thiserror::Error;
use uuid::Uuid;

Expand All @@ -22,7 +24,8 @@ pub struct SpannIndexWriter {
// The blockfile also contains next id for the head.
posting_list_writer: BlockfileWriter,
// Version number of each point.
versions_map: HashMap<u32, u32>,
// TODO(Sanket): Finer grained locking for this map in future.
versions_map: Arc<RwLock<HashMap<u32, u32>>>,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -56,7 +59,7 @@ impl SpannIndexWriter {
hnsw_index,
hnsw_provider,
posting_list_writer,
versions_map,
versions_map: Arc::new(RwLock::new(versions_map)),
}
}

Expand Down Expand Up @@ -205,4 +208,11 @@ impl SpannIndexWriter {
versions_map,
))
}

/// Registers `id` in the versions map with an initial version of 1.
///
/// Version 0 is reserved to mean "deleted", so counting starts from 1.
/// An entry that already exists for `id` is overwritten.
pub fn add_versions_map(&self, id: u32) {
    // Take the write lock once; it is released when `versions` goes out of scope.
    let mut versions = self.versions_map.write();
    versions.insert(id, 1);
}

/// Placeholder for inserting a new record's embedding into the posting lists.
///
/// The body is currently empty, so awaiting this does nothing; `id` and
/// `embeddings` are accepted but not yet used.
pub async fn add_new_record_to_postings_list(&self, id: u32, embeddings: &[f32]) {}
}
43 changes: 40 additions & 3 deletions rust/worker/src/segment/spann_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@ use std::collections::HashMap;
use arrow::error;
use chroma_blockstore::provider::BlockfileProvider;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter, IndexUuid};
use chroma_types::{Segment, SegmentScope, SegmentType, SegmentUuid};
use chroma_index::IndexUuid;
use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
use chroma_types::SegmentUuid;
use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType};
use thiserror::Error;
use uuid::Uuid;

use super::utils::{distance_function_from_segment, hnsw_params_from_segment};
use super::{
record_segment::ApplyMaterializedLogError,
utils::{distance_function_from_segment, hnsw_params_from_segment},
MaterializedLogRecord, SegmentFlusher, SegmentWriter,
};

const HNSW_PATH: &str = "hnsw_path";
const VERSION_MAP_PATH: &str = "version_map_path";
Expand Down Expand Up @@ -146,4 +152,35 @@ impl SpannSegmentWriter {
id: segment.id,
})
}

/// Adds a brand-new record to the SPANN index.
///
/// Registers `record.offset_id` in the index's versions map, then passes the
/// value returned by `record.merged_embeddings()` to the posting-list
/// insertion routine and waits for it to finish.
async fn add(&self, record: &MaterializedLogRecord<'_>) {
    // Initialize the record with a version.
    self.index.add_versions_map(record.offset_id);
    // `add_new_record_to_postings_list` is an `async fn`: the future it returns
    // is lazy and does nothing unless it is awaited.
    self.index
        .add_new_record_to_postings_list(record.offset_id, record.merged_embeddings())
        .await;
}
}

/// [`SegmentWriter`] implementation for the SPANN segment writer.
impl<'a> SegmentWriter<'a> for SpannSegmentWriter {
    /// Applies a chunk of materialized log records to the index.
    ///
    /// Only `MaterializedLogOperation::AddNew` is handled so far; every other
    /// operation reaches `todo!()` and panics.
    async fn apply_materialized_log_chunk(
        &self,
        records: chroma_types::Chunk<super::MaterializedLogRecord<'a>>,
    ) -> Result<(), ApplyMaterializedLogError> {
        // The second element yielded alongside each record is not needed here.
        for (record, _) in records.iter() {
            match record.final_operation {
                MaterializedLogOperation::AddNew => self.add(record).await,
                // TODO(Sanket): Implement other operations.
                _ => todo!(),
            }
        }
        Ok(())
    }

    /// Not implemented yet; calling this panics via `todo!()`.
    async fn commit(self) -> Result<impl SegmentFlusher, Box<dyn ChromaError>> {
        todo!()
    }
}

0 comments on commit 22dea40

Please sign in to comment.