Skip to content

Commit

Permalink
brute_force_posting_list_operator
Browse files (browse the repository at this point in the history)
  • Loading branch information
sanketkedia committed Dec 5, 2024
1 parent 71b6381 commit e6c478a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub(super) mod partition;
pub(super) mod pull_log;
pub(super) mod record_segment_prefetch;
pub(super) mod register;
pub(super) mod spann_bf_pl;
pub(super) mod spann_centers_search;
pub(super) mod spann_fetch_pl;
pub(super) mod write_segments;
Expand Down
81 changes: 81 additions & 0 deletions rust/worker/src/execution/operators/spann_bf_pl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::SignedRoaringBitmap;
use thiserror::Error;
use tonic::async_trait;

use crate::{
execution::operator::Operator,
segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext},
};

use super::knn::RecordDistance;

/// Input consumed by `SpannBfPlOperator::run`.
///
/// NOTE(review): the `run` implementation in this file only reads
/// `reader_context`; the remaining fields are carried but not yet used there.
#[derive(Debug)]
pub struct SpannBfPlInput {
/// Needed for checking if a particular version is outdated.
/// `run` builds a `SpannSegmentReader` from its `segment`,
/// `blockfile_provider`, `hnsw_provider` and `dimension` members.
reader_context: SpannSegmentReaderContext,
/// Posting list data: offset ids of the documents in the posting list.
doc_offset_ids: Vec<u32>,
/// Posting list data: document versions.
/// Presumably one entry per element of `doc_offset_ids` — TODO confirm.
doc_versions: Vec<u32>,
/// Posting list data: document embeddings.
/// Presumably stored flattened, `dimension` floats per document — TODO confirm.
doc_embeddings: Vec<f32>,
/// Number of results to return.
k: usize,
/// Bitmap of records to include/exclude.
filter: SignedRoaringBitmap,
/// Distance function.
distance_function: DistanceFunction,
/// Dimension of the embeddings.
dimension: usize,
}

/// Output produced by `SpannBfPlOperator::run`.
#[derive(Debug)]
pub struct SpannBfPlOutput {
/// Records paired with their distances.
/// NOTE(review): the `run` implementation in this file always returns this
/// vector empty.
records: Vec<RecordDistance>,
}

/// Errors reported by `SpannBfPlOperator`.
///
/// Neither variant carries the underlying cause; both map to
/// `ErrorCodes::Internal` in the `ChromaError` implementation.
#[derive(Error, Debug)]
pub enum SpannBfPlError {
/// Returned by `run` when `SpannSegmentReader::from_segment` fails.
#[error("Error creating spann segment reader")]
SpannSegmentReaderCreationError,
/// NOTE(review): not constructed anywhere in the visible code; presumably
/// reserved for failures while querying the reader — confirm.
#[error("Error querying reader")]
SpannSegmentReaderError,
}

impl ChromaError for SpannBfPlError {
    /// Reports the error code for this error.
    ///
    /// Every variant is classified as `ErrorCodes::Internal`. The match stays
    /// exhaustive (no wildcard) so a newly added variant must be classified
    /// explicitly.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::SpannSegmentReaderCreationError | Self::SpannSegmentReaderError => {
                ErrorCodes::Internal
            }
        }
    }
}

/// Stateless operator (it has no fields) whose `Operator` implementation takes
/// a `SpannBfPlInput` and yields a `SpannBfPlOutput`.
#[derive(Debug)]
pub struct SpannBfPlOperator {}

impl SpannBfPlOperator {
    /// Creates a new operator and returns it boxed.
    pub fn new() -> Box<Self> {
        let operator = Self {};
        Box::new(operator)
    }
}

#[async_trait]
impl Operator<SpannBfPlInput, SpannBfPlOutput> for SpannBfPlOperator {
    type Error = SpannBfPlError;

    /// Runs the operator on `input`.
    ///
    /// Builds a `SpannSegmentReader` from `input.reader_context`. If that
    /// fails, the underlying error is discarded and
    /// `SpannBfPlError::SpannSegmentReaderCreationError` is returned.
    ///
    /// On success the result currently holds an empty `records` vector: the
    /// reader is constructed but not used afterwards, and the posting-list
    /// fields of `input` are not read.
    async fn run(&self, input: &SpannBfPlInput) -> Result<SpannBfPlOutput, SpannBfPlError> {
        let ctx = &input.reader_context;
        let reader_result = SpannSegmentReader::from_segment(
            &ctx.segment,
            &ctx.blockfile_provider,
            &ctx.hnsw_provider,
            ctx.dimension,
        )
        .await;

        // The reader is not used yet; it is still built so that a creation
        // failure surfaces as an error.
        let _spann_reader = match reader_result {
            Ok(reader) => reader,
            Err(_) => return Err(SpannBfPlError::SpannSegmentReaderCreationError),
        };

        let records = Vec::new();
        Ok(SpannBfPlOutput { records })
    }
}

0 comments on commit e6c478a

Please sign in to comment.