
Commit

Cleanup + review comments
sanketkedia committed Dec 4, 2024
1 parent c886068 commit 5e097f4
Showing 7 changed files with 157 additions and 120 deletions.
228 changes: 135 additions & 93 deletions rust/index/src/spann/types.rs

Large diffs are not rendered by default.

11 changes: 0 additions & 11 deletions rust/index/src/utils.rs
@@ -16,8 +16,6 @@ pub(super) fn generate_random_data(n: usize, d: usize) -> Vec<f32> {
data
}

-const EPS: f32 = 1e-30;

pub fn merge_sorted_vecs_disjunction<T: Ord + Clone>(a: &[T], b: &[T]) -> Vec<T> {
let mut result = Vec::with_capacity(a.len() + b.len());
let mut a_idx = 0;
@@ -79,15 +77,6 @@ pub fn merge_sorted_vecs_conjunction<T: Ord + Clone>(a: &[T], b: &[T]) -> Vec<T>
result
}

-pub fn normalize(vector: &[f32]) -> Vec<f32> {
-let mut norm = 0.0;
-for x in vector {
-norm += x * x;
-}
-let norm = 1.0 / (norm.sqrt() + EPS);
-vector.iter().map(|x| x * norm).collect()
-}

#[cfg(test)]
mod tests {
#[test]
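Note: the helper removed here is an L2 normalization with a small epsilon guard against division by zero; the callers updated in this commit import normalize from chroma_distance instead. A minimal usage sketch, assuming chroma_distance::normalize keeps the signature of the deleted helper (fn normalize(vector: &[f32]) -> Vec<f32>):

    use chroma_distance::normalize;

    fn main() {
        // The L2 norm of [3, 4] is 5, so the normalized vector is ~[0.6, 0.8].
        let unit = normalize(&[3.0, 4.0]);
        assert!((unit[0] - 0.6).abs() < 1e-6);
        assert!((unit[1] - 0.8).abs() < 1e-6);
    }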
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operators/brute_force_knn.rs
@@ -3,9 +3,9 @@ use crate::segment::record_segment::RecordSegmentReader;
use crate::segment::{materialize_logs, LogMaterializerError};
use async_trait::async_trait;
use chroma_blockstore::provider::BlockfileProvider;
+use chroma_distance::normalize;
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
-use chroma_index::utils::normalize;
use chroma_types::Chunk;
use chroma_types::{LogRecord, MaterializedLogOperation, Segment};
use std::cmp::Ordering;
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operators/normalize_vectors.rs
@@ -1,6 +1,6 @@
use crate::execution::operator::Operator;
use async_trait::async_trait;
-use chroma_index::utils::normalize;
+use chroma_distance::normalize;

#[derive(Debug)]
pub struct NormalizeVectorOperator {}
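For context, the operator in this file applies normalization across a batch of vectors; only its import changes here. A rough sketch of that kind of batch step, using a hypothetical normalize_batch helper (the real Operator trait's input and output types are not shown in this diff):

    use chroma_distance::normalize;

    // Hypothetical helper: maps normalize over every embedding in a batch.
    fn normalize_batch(embeddings: &[Vec<f32>]) -> Vec<Vec<f32>> {
        embeddings.iter().map(|e| normalize(e)).collect()
    }

    fn main() {
        let out = normalize_batch(&[vec![3.0, 4.0], vec![0.0, 2.0]]);
        // Each output vector now has (approximately) unit L2 norm.
        let norm: f32 = out[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }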
2 changes: 1 addition & 1 deletion rust/worker/src/execution/orchestration/hnsw.rs
@@ -36,10 +36,10 @@ use crate::{
};
use async_trait::async_trait;
use chroma_blockstore::provider::BlockfileProvider;
+use chroma_distance::normalize;
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::hnsw_provider::HnswIndexProvider;
-use chroma_index::utils::normalize;
use chroma_index::IndexConfig;
use chroma_types::{Chunk, Collection, CollectionUuid, LogRecord, Segment, VectorQueryResult};
use std::collections::HashMap;
4 changes: 4 additions & 0 deletions rust/worker/src/segment/record_segment.rs
@@ -1,3 +1,4 @@
+use super::spann_segment::SpannSegmentWriterError;
use super::types::{MaterializedLogRecord, SegmentWriter};
use super::SegmentFlusher;
use async_trait::async_trait;
@@ -322,6 +323,8 @@ pub enum ApplyMaterializedLogError {
FullTextIndex(#[from] FullTextIndexError),
#[error("Error writing to hnsw index")]
HnswIndex(#[from] Box<dyn ChromaError>),
+#[error("Error applying materialized records to spann segment: {0}")]
+SpannSegmentError(#[from] SpannSegmentWriterError),
}

impl ChromaError for ApplyMaterializedLogError {
@@ -333,6 +336,7 @@ impl ChromaError for ApplyMaterializedLogError {
ApplyMaterializedLogError::Allocation => ErrorCodes::Internal,
ApplyMaterializedLogError::FullTextIndex(e) => e.code(),
ApplyMaterializedLogError::HnswIndex(_) => ErrorCodes::Internal,
+ApplyMaterializedLogError::SpannSegmentError(e) => e.code(),
}
}
}
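The new variant wires SpannSegmentWriterError into ApplyMaterializedLogError through thiserror's #[from], and the code() arm delegates to the wrapped error instead of returning a blanket Internal. A minimal sketch of that pattern with stand-in types (InnerError, OuterError, and ErrorCode are illustrative, not the crate's actual ChromaError / ErrorCodes definitions):

    use thiserror::Error;

    #[derive(Debug, PartialEq)]
    enum ErrorCode {
        Internal,
        InvalidArgument,
    }

    #[derive(Error, Debug)]
    #[error("inner failure")]
    struct InnerError;

    impl InnerError {
        fn code(&self) -> ErrorCode {
            ErrorCode::InvalidArgument
        }
    }

    #[derive(Error, Debug)]
    enum OuterError {
        // #[from] generates `impl From<InnerError> for OuterError`, so `?` on a
        // Result<_, InnerError> converts into this variant automatically.
        #[error("error applying materialized records: {0}")]
        Inner(#[from] InnerError),
        // Mirrors the existing Allocation => Internal arm in the diff above.
        #[error("allocation error")]
        Allocation,
    }

    impl OuterError {
        // Delegate the status code to the wrapped error, as the new
        // SpannSegmentError arm does with e.code().
        fn code(&self) -> ErrorCode {
            match self {
                OuterError::Inner(e) => e.code(),
                OuterError::Allocation => ErrorCode::Internal,
            }
        }
    }

    fn inner_step() -> Result<(), InnerError> {
        Err(InnerError)
    }

    fn outer_step() -> Result<(), OuterError> {
        inner_step()?; // converted via the generated From impl
        Ok(())
    }

    fn main() {
        let err = outer_step().unwrap_err();
        // The outer error reports the wrapped error's code, not a generic Internal.
        assert_eq!(err.code(), ErrorCode::InvalidArgument);
        assert_eq!(OuterError::Allocation.code(), ErrorCode::Internal);
    }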
28 changes: 15 additions & 13 deletions rust/worker/src/segment/spann_segment.rs
@@ -1,8 +1,9 @@
use std::collections::HashMap;

use chroma_blockstore::provider::BlockfileProvider;
+use chroma_distance::DistanceFunctionError;
use chroma_error::{ChromaError, ErrorCodes};
-use chroma_index::spann::types::SpannIndexFlusher;
+use chroma_index::spann::types::{SpannIndexFlusher, SpannIndexWriterError};
use chroma_index::IndexUuid;
use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
use chroma_types::SegmentUuid;
@@ -28,12 +29,13 @@ pub(crate) struct SpannSegmentWriter {
id: SegmentUuid,
}

+// TODO(Sanket): Better error composability here.
#[derive(Error, Debug)]
pub enum SpannSegmentWriterError {
#[error("Invalid argument")]
InvalidArgument,
-#[error("Segment metadata does not contain distance function")]
-DistanceFunctionNotFound,
+#[error("Segment metadata does not contain distance function {0}")]
+DistanceFunctionNotFound(#[from] DistanceFunctionError),
#[error("Error parsing index uuid from string")]
IndexIdParsingError,
#[error("Invalid file path for HNSW index")]
@@ -46,8 +48,8 @@ pub enum SpannSegmentWriterError {
MaxHeadIdInvalidFilePath,
#[error("Error constructing spann index writer")]
SpannSegmentWriterCreateError,
-#[error("Error adding record to spann index writer")]
-SpannSegmentWriterAddRecordError,
+#[error("Error adding record to spann index writer {0}")]
+SpannSegmentWriterAddRecordError(#[from] SpannIndexWriterError),
#[error("Error committing spann index writer")]
SpannSegmentWriterCommitError,
#[error("Error flushing spann index writer")]
@@ -61,7 +63,7 @@ impl ChromaError for SpannSegmentWriterError {
match self {
Self::InvalidArgument => ErrorCodes::InvalidArgument,
Self::IndexIdParsingError => ErrorCodes::Internal,
-Self::DistanceFunctionNotFound => ErrorCodes::Internal,
+Self::DistanceFunctionNotFound(e) => e.code(),
Self::HnswInvalidFilePath => ErrorCodes::Internal,
Self::VersionMapInvalidFilePath => ErrorCodes::Internal,
Self::PostingListInvalidFilePath => ErrorCodes::Internal,
@@ -70,7 +72,7 @@
Self::NotImplemented => ErrorCodes::Internal,
Self::SpannSegmentWriterCommitError => ErrorCodes::Internal,
Self::SpannSegmentWriterFlushError => ErrorCodes::Internal,
-Self::SpannSegmentWriterAddRecordError => ErrorCodes::Internal,
+Self::SpannSegmentWriterAddRecordError(e) => e.code(),
}
}
}
@@ -88,8 +90,8 @@ impl SpannSegmentWriter {
}
let distance_function = match distance_function_from_segment(segment) {
Ok(distance_function) => distance_function,
-Err(_) => {
-return Err(SpannSegmentWriterError::DistanceFunctionNotFound);
+Err(e) => {
+return Err(SpannSegmentWriterError::DistanceFunctionNotFound(*e));
}
};
let (hnsw_id, m, ef_construction, ef_search) = match segment.file_path.get(HNSW_PATH) {
@@ -202,25 +204,25 @@ impl SpannSegmentWriter {
self.index
.add(record.offset_id, record.merged_embeddings())
.await
-.map_err(|_| SpannSegmentWriterError::SpannSegmentWriterAddRecordError)
+.map_err(SpannSegmentWriterError::SpannSegmentWriterAddRecordError)
}
}

struct SpannSegmentFlusher {
index_flusher: SpannIndexFlusher,
}

-impl<'a> SegmentWriter<'a> for SpannSegmentWriter {
+impl<'referred_data> SegmentWriter<'referred_data> for SpannSegmentWriter {
async fn apply_materialized_log_chunk(
&self,
-records: chroma_types::Chunk<super::MaterializedLogRecord<'a>>,
+records: chroma_types::Chunk<super::MaterializedLogRecord<'referred_data>>,
) -> Result<(), ApplyMaterializedLogError> {
for (record, _) in records.iter() {
match record.final_operation {
MaterializedLogOperation::AddNew => {
self.add(record)
.await
-.map_err(|_| ApplyMaterializedLogError::BlockfileSet)?;
+.map_err(ApplyMaterializedLogError::SpannSegmentError)?;
}
// TODO(Sanket): Implement other operations.
_ => {
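The map_err changes above follow from the new error variants carrying their sources: a tuple-variant constructor such as SpannSegmentWriterError::SpannSegmentWriterAddRecordError is itself a function from the wrapped error to the enum, so it can be passed to map_err directly instead of a closure that throws the cause away. A small sketch with hypothetical stand-in types:

    use thiserror::Error;

    // Stand-in for SpannIndexWriterError.
    #[derive(Error, Debug)]
    #[error("index writer failure")]
    struct IndexWriterError;

    // Stand-in for SpannSegmentWriterError.
    #[derive(Error, Debug)]
    enum WriterError {
        #[error("error adding record to spann index writer {0}")]
        AddRecord(#[from] IndexWriterError),
    }

    fn add_to_index() -> Result<(), IndexWriterError> {
        Err(IndexWriterError)
    }

    fn add() -> Result<(), WriterError> {
        // Passing the variant constructor keeps the source error;
        // `.map_err(|_| ...)` onto a unit variant would have dropped it.
        add_to_index().map_err(WriterError::AddRecord)
    }

    fn main() {
        let err = add().unwrap_err();
        // Display includes the wrapped source via the {0} in the #[error] attribute.
        println!("{err}");
    }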
