Skip to content

Commit

Permalink
[ENH] Brute force distances for a posting list operator (#3215)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - None
 - New functionality
   -  Iterates over the posting list, keeps only the entries permitted by the allowed set (the include/exclude filter), and returns the top k results sorted by distance.

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia authored Dec 5, 2024
1 parent 37a6103 commit 46acfc3
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub(super) mod partition;
pub(super) mod pull_log;
pub(super) mod record_segment_prefetch;
pub(super) mod register;
pub(super) mod spann_bf_pl;
pub(super) mod spann_centers_search;
pub(super) mod spann_fetch_pl;
pub(super) mod write_segments;
Expand Down
88 changes: 88 additions & 0 deletions rust/worker/src/execution/operators/spann_bf_pl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::collections::BinaryHeap;

use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::spann::types::SpannPosting;
use chroma_types::SignedRoaringBitmap;
use thiserror::Error;
use tonic::async_trait;

use crate::execution::operator::Operator;

use super::knn::RecordDistance;

/// Input to [`SpannBfPlOperator`]: a posting list to scan, the query to compare
/// against, and the parameters that control filtering and result size.
#[derive(Debug)]
pub struct SpannBfPlInput {
// TODO(Sanket): We might benefit from a flat structure which might be more cache friendly.
/// Posting list data. The operator reads each entry's `doc_offset_id` and
/// `doc_embedding`.
posting_list: Vec<SpannPosting>,
/// Number of results to return (upper bound on the output length).
k: usize,
/// Bitmap of records to include/exclude. With `Include`, only offset ids
/// present in the bitmap are considered; with `Exclude`, offset ids present
/// in the bitmap are skipped.
filter: SignedRoaringBitmap,
/// Distance function used to compare each posting's embedding to `query`.
distance_function: DistanceFunction,
/// Query embedding.
query: Vec<f32>,
}

/// Output of [`SpannBfPlOperator`].
// NOTE(review): `dead_code` is presumably allowed because nothing reads
// `records` yet — confirm and drop the attribute once a consumer exists.
#[allow(dead_code)]
#[derive(Debug)]
pub struct SpannBfPlOutput {
/// The retained records (at most `k`), in ascending order as defined by
/// `RecordDistance`'s `Ord` implementation.
records: Vec<RecordDistance>,
}

/// Error type of the brute-force posting list operator.
///
/// The enum currently declares no variants, so no value of this type can be
/// constructed.
#[derive(Error, Debug)]
pub enum SpannBfPlError {}

impl ChromaError for SpannBfPlError {
/// Reports every error of this type as `ErrorCodes::Internal`.
fn code(&self) -> ErrorCodes {
ErrorCodes::Internal
}
}

/// Operator that brute-forces distances over a posting list.
///
/// The type carries no fields; all data is supplied through the input passed
/// to `run`.
#[derive(Debug)]
pub struct SpannBfPlOperator {}

#[allow(dead_code)]
impl SpannBfPlOperator {
    /// Creates a new, heap-allocated operator instance.
    pub fn new() -> Box<Self> {
        let operator = Self {};
        Box::new(operator)
    }
}

#[async_trait]
impl Operator<SpannBfPlInput, SpannBfPlOutput> for SpannBfPlOperator {
    type Error = SpannBfPlError;

    /// Scans the posting list and returns up to `input.k` records with the
    /// smallest distance to `input.query`.
    ///
    /// Postings rejected by `input.filter` are skipped before any distance is
    /// computed. The surviving records are returned in ascending order as
    /// defined by `RecordDistance`'s ordering.
    async fn run(&self, input: &SpannBfPlInput) -> Result<SpannBfPlOutput, SpannBfPlError> {
        // Lazily yield only the postings admitted by the include/exclude bitmap.
        let admitted = input
            .posting_list
            .iter()
            .filter(|posting| match &input.filter {
                SignedRoaringBitmap::Include(allowed) => allowed.contains(posting.doc_offset_id),
                SignedRoaringBitmap::Exclude(blocked) => !blocked.contains(posting.doc_offset_id),
            });

        // Max-heap bounded to `k` entries: its top is always the worst record
        // kept so far, which makes it the one to evict when a better one shows up.
        let mut nearest = BinaryHeap::with_capacity(input.k);
        for posting in admitted {
            let candidate = RecordDistance {
                offset_id: posting.doc_offset_id,
                measure: input
                    .distance_function
                    .distance(&posting.doc_embedding, &input.query),
            };

            // Still below capacity: keep the candidate unconditionally.
            if nearest.len() < input.k {
                nearest.push(candidate);
                continue;
            }

            // At capacity: only a strictly better candidate displaces the top.
            // An empty heap here means `k == 0`, so nothing is ever kept.
            let evicts_worst = match nearest.peek() {
                Some(worst) => candidate < *worst,
                None => false,
            };
            if evicts_worst {
                nearest.pop();
                nearest.push(candidate);
            }
        }

        let records = nearest.into_sorted_vec();
        Ok(SpannBfPlOutput { records })
    }
}

0 comments on commit 46acfc3

Please sign in to comment.