diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index 5c17fe87c89..918b8d4953a 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -9,6 +9,7 @@ pub(super) mod partition; pub(super) mod pull_log; pub(super) mod record_segment_prefetch; pub(super) mod register; +pub(super) mod spann_bf_pl; pub(super) mod spann_centers_search; pub(super) mod spann_fetch_pl; pub(super) mod write_segments; diff --git a/rust/worker/src/execution/operators/spann_bf_pl.rs b/rust/worker/src/execution/operators/spann_bf_pl.rs new file mode 100644 index 00000000000..6a2f997e23f --- /dev/null +++ b/rust/worker/src/execution/operators/spann_bf_pl.rs @@ -0,0 +1,88 @@ +use std::collections::BinaryHeap; + +use chroma_distance::DistanceFunction; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_index::spann::types::SpannPosting; +use chroma_types::SignedRoaringBitmap; +use thiserror::Error; +use tonic::async_trait; + +use crate::execution::operator::Operator; + +use super::knn::RecordDistance; + +#[derive(Debug)] +pub struct SpannBfPlInput { + // TODO(Sanket): We might benefit from a flat structure which might be more cache friendly. + // Posting list data. + posting_list: Vec, + // Number of results to return. + k: usize, + // Bitmap of records to include/exclude. + filter: SignedRoaringBitmap, + // Distance function. + distance_function: DistanceFunction, + // Query embedding. + query: Vec, +} + +#[allow(dead_code)] +#[derive(Debug)] +pub struct SpannBfPlOutput { + records: Vec, +} + +#[derive(Error, Debug)] +pub enum SpannBfPlError {} + +impl ChromaError for SpannBfPlError { + fn code(&self) -> ErrorCodes { + ErrorCodes::Internal + } +} + +#[derive(Debug)] +pub struct SpannBfPlOperator {} + +#[allow(dead_code)] +impl SpannBfPlOperator { + pub fn new() -> Box { + Box::new(SpannBfPlOperator {}) + } +} + +#[async_trait] +impl Operator for SpannBfPlOperator { + type Error = SpannBfPlError; + + async fn run(&self, input: &SpannBfPlInput) -> Result { + let mut max_heap = BinaryHeap::with_capacity(input.k); + for posting in input.posting_list.iter() { + let skip_entry = match &input.filter { + SignedRoaringBitmap::Include(rbm) => !rbm.contains(posting.doc_offset_id), + SignedRoaringBitmap::Exclude(rbm) => rbm.contains(posting.doc_offset_id), + }; + if skip_entry { + continue; + } + let dist = input + .distance_function + .distance(&posting.doc_embedding, &input.query); + let record = RecordDistance { + offset_id: posting.doc_offset_id, + measure: dist, + }; + if max_heap.len() < input.k { + max_heap.push(record); + } else if let Some(furthest_distance) = max_heap.peek() { + if &record < furthest_distance { + max_heap.pop(); + max_heap.push(record); + } + } + } + Ok(SpannBfPlOutput { + records: max_heap.into_sorted_vec(), + }) + } +}