Skip to content

Commit

Permalink
Benchmark knn orchestrator
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Dec 17, 2024
1 parent e015e31 commit 64d0835
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 30 deletions.
4 changes: 2 additions & 2 deletions rust/benchmark/src/datasets/sift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl Sift1MData {
.collect())
}

pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<usize>)>> {
pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<u32>)>> {
let mut query_bytes = Vec::new();
self.query.read_to_end(&mut query_bytes).await?;
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
Expand All @@ -165,7 +165,7 @@ impl Sift1MData {
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
let ground_u32s: Vec<_> = embeddings_bytes
.chunks(size_of::<u32>())
.map(|c| u32::from_le_bytes(c.try_into().unwrap()) as usize)
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
.collect();
Ok(embedding_f32s
.chunks(Self::dimension())
Expand Down
4 changes: 4 additions & 0 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ harness = false
name = "limit"
harness = false

[[bench]]
name = "query"
harness = false

[[bench]]
name = "spann"
harness = false
56 changes: 29 additions & 27 deletions rust/worker/benches/get.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#[allow(dead_code)]
mod load;

use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
use chroma_config::Configurable;
use criterion::{criterion_group, criterion_main, Criterion};
use load::{
all_projection, always_true_where_for_modulo_metadata, empty_fetch_log, offset_limit,
all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit,
sift1m_segments, trivial_filter, trivial_limit, trivial_projection,
};
use worker::{
Expand All @@ -23,7 +24,7 @@ fn trivial_get(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
trivial_filter(),
Expand All @@ -41,10 +42,10 @@ fn get_filter(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
trivial_limit(),
trivial_projection(),
)
Expand All @@ -59,10 +60,10 @@ fn get_filter_limit(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
offset_limit(),
trivial_projection(),
)
Expand All @@ -77,15 +78,31 @@ fn get_filter_limit_projection(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
offset_limit(),
all_projection(),
)
}

async fn bench_routine(input: (System, GetOrchestrator, Vec<String>)) {
let (system, orchestrator, expected_ids) = input;
let output = orchestrator
.run(system)
.await
.expect("Orchestrator should not fail");
assert_eq!(
output
.records
.into_iter()
.map(|record| record.id)
.collect::<Vec<_>>(),
expected_ids
);
}

fn bench_get(criterion: &mut Criterion) {
let runtime = tokio_multi_thread();
let test_segments = runtime.block_on(sift1m_segments());
Expand Down Expand Up @@ -128,48 +145,33 @@ fn bench_get(criterion: &mut Criterion) {
)
};

let routine = |(system, orchestrator, expected): (System, GetOrchestrator, Vec<String>)| async move {
let output = orchestrator
.run(system)
.await
.expect("Orchestrator should not fail");
assert_eq!(
output
.records
.into_iter()
.map(|record| record.id)
.collect::<Vec<_>>(),
expected
);
};

bench_run(
"test-trivial-get",
criterion,
&runtime,
trivial_get_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter",
criterion,
&runtime,
get_filter_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter-limit",
criterion,
&runtime,
get_filter_limit_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter-limit-projection",
criterion,
&runtime,
get_filter_limit_projection_setup,
routine,
bench_routine,
);
}
criterion_group!(benches, bench_get);
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/benches/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub fn trivial_filter() -> FilterOperator {
}
}

pub fn always_true_where_for_modulo_metadata() -> FilterOperator {
pub fn always_true_filter_for_modulo_metadata() -> FilterOperator {
FilterOperator {
query_ids: None,
where_clause: Some(Where::conjunction(vec![
Expand Down
218 changes: 218 additions & 0 deletions rust/worker/benches/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#[allow(dead_code)]
mod load;

use std::collections::HashSet;

use chroma_benchmark::{
benchmark::{bench_run, tokio_multi_thread},
datasets::sift::Sift1MData,
};
use chroma_config::Configurable;
use criterion::{criterion_group, criterion_main, Criterion};
use futures::{stream, StreamExt, TryStreamExt};
use load::{
all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments,
trivial_filter,
};
use rand::{seq::SliceRandom, thread_rng};
use worker::{
config::RootConfig,
execution::{
dispatcher::Dispatcher,
operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator},
orchestration::{
knn::KnnOrchestrator,
knn_filter::{KnnFilterOrchestrator, KnnFilterOutput},
},
},
segment::test::TestSegment,
system::{ComponentHandle, System},
};

fn trivial_knn_filter(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> KnnFilterOrchestrator {
let blockfile_provider = test_segments.blockfile_provider.clone();
let hnsw_provider = test_segments.hnsw_provider.clone();
let collection_uuid = test_segments.collection.collection_id;
KnnFilterOrchestrator::new(
blockfile_provider,
dispatcher_handle,
hnsw_provider,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
trivial_filter(),
)
}

fn always_true_knn_filter(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
) -> KnnFilterOrchestrator {
let blockfile_provider = test_segments.blockfile_provider.clone();
let hnsw_provider = test_segments.hnsw_provider.clone();
let collection_uuid = test_segments.collection.collection_id;
KnnFilterOrchestrator::new(
blockfile_provider,
dispatcher_handle,
hnsw_provider,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_filter_for_modulo_metadata(),
)
}

fn knn(
test_segments: TestSegment,
dispatcher_handle: ComponentHandle<Dispatcher>,
knn_filter_output: KnnFilterOutput,
query: Vec<f32>,
) -> KnnOrchestrator {
KnnOrchestrator::new(
test_segments.blockfile_provider.clone(),
dispatcher_handle.clone(),
1000,
knn_filter_output.clone(),
KnnOperator {
embedding: query,
fetch: Sift1MData::k() as u32,
},
KnnProjectionOperator {
projection: all_projection(),
distance: true,
},
)
}

async fn bench_routine(
input: (
System,
KnnFilterOrchestrator,
impl Fn(KnnFilterOutput) -> Vec<(KnnOrchestrator, Vec<u32>)>,
),
) {
let (system, knn_filter, knn_constructor) = input;
let knn_filter_output = knn_filter
.run(system.clone())
.await
.expect("Orchestrator should not fail");
let (knns, expected): (Vec<_>, Vec<_>) = knn_constructor(knn_filter_output).into_iter().unzip();
let results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone())))
.buffered(32)
.try_collect::<Vec<_>>()
.await
.expect("Orchestrators should not fail");
results
.into_iter()
.map(|result| {
result
.records
.into_iter()
.map(|record| record.record.id.parse())
// .collect::<Result<_, _>>()
.collect::<Result<Vec<u32>, _>>()
.expect("Record id should be parsable to u32")
})
.zip(expected)
.for_each(|(got, expected)| {
let expected_set: HashSet<_> = HashSet::from_iter(expected);
let recall = got
.into_iter()
.filter(|id| expected_set.contains(id))
.count() as f64
/ expected_set.len() as f64;
assert!(recall > 0.9);
});
}

fn bench_query(criterion: &mut Criterion) {
let runtime = tokio_multi_thread();
let test_segments = runtime.block_on(sift1m_segments());

let config = RootConfig::default();
let system = System::default();
let dispatcher = runtime
.block_on(Dispatcher::try_from_config(
&config.query_service.dispatcher,
))
.expect("Should be able to initialize dispatcher");
let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) });

let mut sift1m = runtime
.block_on(Sift1MData::init())
.expect("Should be able to download Sift1M data");
let mut sift1m_queries = runtime
.block_on(sift1m.query())
.expect("Should be able to load Sift1M queries");

sift1m_queries.as_mut_slice().shuffle(&mut thread_rng());

let trivial_knn_setup = || {
(
system.clone(),
trivial_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()),
|knn_filter_output: KnnFilterOutput| {
sift1m_queries
.iter()
.take(4)
.map(|(query, expected)| {
(
knn(
test_segments.clone(),
dispatcher_handle.clone(),
knn_filter_output.clone(),
query.clone(),
),
expected.clone(),
)
})
.collect()
},
)
};

let filtered_knn_setup = || {
(
system.clone(),
always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()),
|knn_filter_output: KnnFilterOutput| {
sift1m_queries
.iter()
.take(4)
.map(|(query, expected)| {
(
knn(
test_segments.clone(),
dispatcher_handle.clone(),
knn_filter_output.clone(),
query.clone(),
),
expected.clone(),
)
})
.collect()
},
)
};

bench_run(
"test-trivial-knn",
criterion,
&runtime,
trivial_knn_setup,
bench_routine,
);

bench_run(
"test-filtered-knn",
criterion,
&runtime,
filtered_knn_setup,
bench_routine,
);
}
criterion_group!(benches, bench_query);
criterion_main!(benches);

0 comments on commit 64d0835

Please sign in to comment.