diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index 5c17fe87c89a..918b8d4953a7 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -9,6 +9,7 @@ pub(super) mod partition; pub(super) mod pull_log; pub(super) mod record_segment_prefetch; pub(super) mod register; +pub(super) mod spann_bf_pl; pub(super) mod spann_centers_search; pub(super) mod spann_fetch_pl; pub(super) mod write_segments; diff --git a/rust/worker/src/execution/operators/spann_bf_pl.rs b/rust/worker/src/execution/operators/spann_bf_pl.rs new file mode 100644 index 000000000000..780d2dbfaddc --- /dev/null +++ b/rust/worker/src/execution/operators/spann_bf_pl.rs @@ -0,0 +1,81 @@ +use chroma_distance::DistanceFunction; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_types::SignedRoaringBitmap; +use thiserror::Error; +use tonic::async_trait; + +use crate::{ + execution::operator::Operator, + segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext}, +}; + +use super::knn::RecordDistance; + +#[derive(Debug)] +pub struct SpannBfPlInput { + // Needed for checking if a particular version is outdated. + reader_context: SpannSegmentReaderContext, + // Posting list data. + doc_offset_ids: Vec, + doc_versions: Vec, + doc_embeddings: Vec, + // Number of results to return. + k: usize, + // Bitmap of records to include/exclude. + filter: SignedRoaringBitmap, + // Distance function. + distance_function: DistanceFunction, + // Dimension of the embeddings. + dimension: usize, +} + +#[derive(Debug)] +pub struct SpannBfPlOutput { + records: Vec, +} + +#[derive(Error, Debug)] +pub enum SpannBfPlError { + #[error("Error creating spann segment reader")] + SpannSegmentReaderCreationError, + #[error("Error querying reader")] + SpannSegmentReaderError, +} + +impl ChromaError for SpannBfPlError { + fn code(&self) -> ErrorCodes { + match self { + Self::SpannSegmentReaderCreationError => ErrorCodes::Internal, + Self::SpannSegmentReaderError => ErrorCodes::Internal, + } + } +} + +#[derive(Debug)] +pub struct SpannBfPlOperator {} + +impl SpannBfPlOperator { + pub fn new() -> Box { + Box::new(SpannBfPlOperator {}) + } +} + +#[async_trait] +impl Operator for SpannBfPlOperator { + type Error = SpannBfPlError; + + async fn run(&self, input: &SpannBfPlInput) -> Result { + let spann_reader = SpannSegmentReader::from_segment( + &input.reader_context.segment, + &input.reader_context.blockfile_provider, + &input.reader_context.hnsw_provider, + input.reader_context.dimension, + ) + .await + .map_err(|_| SpannBfPlError::SpannSegmentReaderCreationError)?; + + Ok(SpannBfPlOutput { + records: Vec::new(), + }) + } +}