diff --git a/rust/worker/src/execution/operators/spann_bf_pl.rs b/rust/worker/src/execution/operators/spann_bf_pl.rs index 780d2dbfaddc..5fea57e0235f 100644 --- a/rust/worker/src/execution/operators/spann_bf_pl.rs +++ b/rust/worker/src/execution/operators/spann_bf_pl.rs @@ -1,32 +1,28 @@ +use std::collections::BinaryHeap; + use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; +use chroma_index::spann::types::SpannPosting; use chroma_types::SignedRoaringBitmap; use thiserror::Error; use tonic::async_trait; -use crate::{ - execution::operator::Operator, - segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext}, -}; +use crate::execution::operator::Operator; use super::knn::RecordDistance; #[derive(Debug)] pub struct SpannBfPlInput { - // Needed for checking if a particular version is outdated. - reader_context: SpannSegmentReaderContext, // Posting list data. - doc_offset_ids: Vec, - doc_versions: Vec, - doc_embeddings: Vec, + posting_list: Vec, // Number of results to return. k: usize, // Bitmap of records to include/exclude. filter: SignedRoaringBitmap, // Distance function. distance_function: DistanceFunction, - // Dimension of the embeddings. - dimension: usize, + // Query embedding. + query: Vec, } #[derive(Debug)] @@ -35,19 +31,11 @@ pub struct SpannBfPlOutput { } #[derive(Error, Debug)] -pub enum SpannBfPlError { - #[error("Error creating spann segment reader")] - SpannSegmentReaderCreationError, - #[error("Error querying reader")] - SpannSegmentReaderError, -} +pub enum SpannBfPlError {} impl ChromaError for SpannBfPlError { fn code(&self) -> ErrorCodes { - match self { - Self::SpannSegmentReaderCreationError => ErrorCodes::Internal, - Self::SpannSegmentReaderError => ErrorCodes::Internal, - } + ErrorCodes::Internal } } @@ -65,17 +53,33 @@ impl Operator for SpannBfPlOperator { type Error = SpannBfPlError; async fn run(&self, input: &SpannBfPlInput) -> Result { - let spann_reader = SpannSegmentReader::from_segment( - &input.reader_context.segment, - &input.reader_context.blockfile_provider, - &input.reader_context.hnsw_provider, - input.reader_context.dimension, - ) - .await - .map_err(|_| SpannBfPlError::SpannSegmentReaderCreationError)?; - + let mut max_heap = BinaryHeap::with_capacity(input.k); + for posting in input.posting_list.iter() { + let skip_entry = match &input.filter { + SignedRoaringBitmap::Include(rbm) => !rbm.contains(posting.doc_offset_id), + SignedRoaringBitmap::Exclude(rbm) => rbm.contains(posting.doc_offset_id), + }; + if skip_entry { + continue; + } + let dist = input + .distance_function + .distance(&posting.doc_embedding, &input.query); + let record = RecordDistance { + offset_id: posting.doc_offset_id, + measure: dist, + }; + if max_heap.len() < input.k { + max_heap.push(record); + } else if let Some(furthest_distance) = max_heap.peek() { + if &record < furthest_distance { + max_heap.pop(); + max_heap.push(record); + } + } + } Ok(SpannBfPlOutput { - records: Vec::new(), + records: max_heap.into_sorted_vec(), }) } }