Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TST] Benchmark query node #3320

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Benchmark knn orchestrator
  • Loading branch information
Sicheng Pan authored and Sicheng-Pan committed Dec 17, 2024
commit 34121efebd8da2b498d22780f5b15c37a45f84f2
4 changes: 2 additions & 2 deletions rust/benchmark/src/datasets/sift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl Sift1MData {
.collect())
}

pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<usize>)>> {
pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<u32>)>> {
let mut query_bytes = Vec::new();
self.query.read_to_end(&mut query_bytes).await?;
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
Expand All @@ -165,7 +165,7 @@ impl Sift1MData {
let (_, embeddings_bytes) = query_bytes.split_at(size_of::<u32>());
let ground_u32s: Vec<_> = embeddings_bytes
.chunks(size_of::<u32>())
.map(|c| u32::from_le_bytes(c.try_into().unwrap()) as usize)
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
.collect();
Ok(embedding_f32s
.chunks(Self::dimension())
Expand Down
4 changes: 4 additions & 0 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ harness = false
name = "limit"
harness = false

[[bench]]
name = "query"
harness = false

[[bench]]
name = "spann"
harness = false
56 changes: 29 additions & 27 deletions rust/worker/benches/get.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#[allow(dead_code)]
mod load;

use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
use chroma_config::Configurable;
use criterion::{criterion_group, criterion_main, Criterion};
use load::{
all_projection, always_true_where_for_modulo_metadata, empty_fetch_log, offset_limit,
all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit,
sift1m_segments, trivial_filter, trivial_limit, trivial_projection,
};
use worker::{
Expand All @@ -23,7 +24,7 @@ fn trivial_get(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
trivial_filter(),
Expand All @@ -41,10 +42,10 @@ fn get_filter(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
trivial_limit(),
trivial_projection(),
)
Expand All @@ -59,10 +60,10 @@ fn get_filter_limit(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
offset_limit(),
trivial_projection(),
)
Expand All @@ -77,15 +78,31 @@ fn get_filter_limit_projection(
GetOrchestrator::new(
blockfile_provider,
dispatcher_handle,
100,
1000,
test_segments.into(),
empty_fetch_log(collection_uuid),
always_true_where_for_modulo_metadata(),
always_true_filter_for_modulo_metadata(),
offset_limit(),
all_projection(),
)
}

/// Benchmark body for the get benchmarks.
///
/// `input` bundles the system the orchestrator runs on, the orchestrator under test,
/// and the record ids it is expected to return (in order). Panics if the orchestrator
/// fails or if the returned ids differ from the expected ones.
async fn bench_routine(input: (System, GetOrchestrator, Vec<String>)) {
    let (sys, get_orchestrator, want_ids) = input;
    let run_result = get_orchestrator.run(sys).await;
    let output = run_result.expect("Orchestrator should not fail");
    // Keep only the ids of the returned records, preserving their order.
    let got_ids: Vec<_> = output.records.into_iter().map(|rec| rec.id).collect();
    assert_eq!(got_ids, want_ids);
}

fn bench_get(criterion: &mut Criterion) {
let runtime = tokio_multi_thread();
let test_segments = runtime.block_on(sift1m_segments());
Expand Down Expand Up @@ -128,48 +145,33 @@ fn bench_get(criterion: &mut Criterion) {
)
};

let routine = |(system, orchestrator, expected): (System, GetOrchestrator, Vec<String>)| async move {
let output = orchestrator
.run(system)
.await
.expect("Orchestrator should not fail");
assert_eq!(
output
.records
.into_iter()
.map(|record| record.id)
.collect::<Vec<_>>(),
expected
);
};

bench_run(
"test-trivial-get",
criterion,
&runtime,
trivial_get_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter",
criterion,
&runtime,
get_filter_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter-limit",
criterion,
&runtime,
get_filter_limit_setup,
routine,
bench_routine,
);
bench_run(
"test-get-filter-limit-projection",
criterion,
&runtime,
get_filter_limit_projection_setup,
routine,
bench_routine,
);
}
criterion_group!(benches, bench_get);
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/benches/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub fn trivial_filter() -> FilterOperator {
}
}

pub fn always_true_where_for_modulo_metadata() -> FilterOperator {
pub fn always_true_filter_for_modulo_metadata() -> FilterOperator {
FilterOperator {
query_ids: None,
where_clause: Some(Where::conjunction(vec![
Expand Down
218 changes: 218 additions & 0 deletions rust/worker/benches/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#[allow(dead_code)]
mod load;

use std::collections::HashSet;

use chroma_benchmark::{
benchmark::{bench_run, tokio_multi_thread},
datasets::sift::Sift1MData,
};
use chroma_config::Configurable;
use criterion::{criterion_group, criterion_main, Criterion};
use futures::{stream, StreamExt, TryStreamExt};
use load::{
all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments,
trivial_filter,
};
use rand::{seq::SliceRandom, thread_rng};
use worker::{
config::RootConfig,
execution::{
dispatcher::Dispatcher,
operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator},
orchestration::{
knn::KnnOrchestrator,
knn_filter::{KnnFilterOrchestrator, KnnFilterOutput},
},
},
segment::test::TestSegment,
system::{ComponentHandle, System},
};

/// Builds a `KnnFilterOrchestrator` over the test segments that applies the filter
/// returned by `trivial_filter()`.
///
/// Everything that must be read from `test_segments` is taken out first, because the
/// segments themselves are consumed by the `into()` conversion handed to the orchestrator.
fn trivial_knn_filter(
    test_segments: TestSegment,
    dispatcher_handle: ComponentHandle<Dispatcher>,
) -> KnnFilterOrchestrator {
    let collection_id = test_segments.collection.collection_id;
    let blockfiles = test_segments.blockfile_provider.clone();
    let hnsw = test_segments.hnsw_provider.clone();
    let fetch_log = empty_fetch_log(collection_id);
    KnnFilterOrchestrator::new(
        blockfiles,
        dispatcher_handle,
        hnsw,
        1000,
        test_segments.into(),
        fetch_log,
        trivial_filter(),
    )
}

/// Builds a `KnnFilterOrchestrator` over the test segments that applies the filter
/// returned by `always_true_filter_for_modulo_metadata()`.
fn always_true_knn_filter(
    test_segments: TestSegment,
    dispatcher_handle: ComponentHandle<Dispatcher>,
) -> KnnFilterOrchestrator {
    // The fetch log only needs the collection id, so build it before the segments
    // are consumed by `into()` below.
    let fetch_log = empty_fetch_log(test_segments.collection.collection_id);
    KnnFilterOrchestrator::new(
        test_segments.blockfile_provider.clone(),
        dispatcher_handle,
        test_segments.hnsw_provider.clone(),
        1000,
        test_segments.into(),
        fetch_log,
        always_true_filter_for_modulo_metadata(),
    )
}

/// Builds a `KnnOrchestrator` that runs a single nearest-neighbour query on top of an
/// already computed filter stage.
///
/// * `test_segments` - supplies the blockfile provider given to the orchestrator.
/// * `dispatcher_handle` - handle of the dispatcher the orchestrator sends work to.
/// * `knn_filter_output` - output of a previously run `KnnFilterOrchestrator`.
/// * `query` - the query embedding.
///
/// The query fetches `Sift1MData::k()` neighbours and projects them with
/// `all_projection()`, including distances.
fn knn(
    test_segments: TestSegment,
    dispatcher_handle: ComponentHandle<Dispatcher>,
    knn_filter_output: KnnFilterOutput,
    query: Vec<f32>,
) -> KnnOrchestrator {
    KnnOrchestrator::new(
        // Only the provider is needed from the segments; it is cloned rather than
        // moved out of the struct.
        test_segments.blockfile_provider.clone(),
        // Both of these arguments are owned by this function, so they are moved
        // straight into the orchestrator instead of being cloned first.
        dispatcher_handle,
        1000,
        knn_filter_output,
        KnnOperator {
            embedding: query,
            fetch: Sift1MData::k() as u32,
        },
        KnnProjectionOperator {
            projection: all_projection(),
            distance: true,
        },
    )
}

/// Benchmark body for the KNN query benchmarks.
///
/// `input` bundles:
/// * the system the orchestrators run on,
/// * the filter orchestrator that is run first, and
/// * a constructor that turns the filter output into the KNN orchestrators to run,
///   each paired with the ids it is expected to return.
///
/// The KNN orchestrators are driven with up to 32 in flight at a time. For every
/// query the recall against its expected ids is computed and must exceed 0.9.
///
/// # Panics
///
/// Panics if any orchestrator fails, if a returned record id does not parse as `u32`,
/// or if the recall of any query is not above 0.9.
async fn bench_routine(
    input: (
        System,
        KnnFilterOrchestrator,
        impl Fn(KnnFilterOutput) -> Vec<(KnnOrchestrator, Vec<u32>)>,
    ),
) {
    let (system, knn_filter, knn_constructor) = input;
    let knn_filter_output = knn_filter
        .run(system.clone())
        .await
        .expect("Orchestrator should not fail");
    let (knns, expected): (Vec<_>, Vec<_>) = knn_constructor(knn_filter_output).into_iter().unzip();
    // `buffered` yields outputs in submission order, so `results` lines up with `expected`.
    let results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone())))
        .buffered(32)
        .try_collect::<Vec<_>>()
        .await
        .expect("Orchestrators should not fail");
    for (result, expected) in results.into_iter().zip(expected) {
        let got = result
            .records
            .into_iter()
            .map(|record| record.record.id.parse())
            .collect::<Result<Vec<u32>, _>>()
            .expect("Record id should be parsable to u32");
        let expected_set: HashSet<u32> = expected.into_iter().collect();
        // Recall = fraction of the expected ids that were actually returned.
        let hits = got
            .into_iter()
            .filter(|id| expected_set.contains(id))
            .count();
        let recall = hits as f64 / expected_set.len() as f64;
        assert!(recall > 0.9, "recall {} is not above 0.9", recall);
    }
}

/// Criterion entry point for the KNN query benchmarks.
///
/// Loads the Sift1M test segments and queries once, starts a dispatcher on a default
/// `System`, shuffles the queries, and then registers two benchmarks through
/// `bench_run`: one whose filter stage uses `trivial_filter()` and one whose filter
/// stage uses `always_true_filter_for_modulo_metadata()`. Each setup closure produces
/// the input tuple consumed by `bench_routine`.
///
/// # Panics
///
/// Panics if the dispatcher cannot be initialised or the Sift1M data cannot be loaded.
fn bench_query(criterion: &mut Criterion) {
    // Number of Sift1M queries issued by each setup's KNN constructor.
    const QUERIES_PER_RUN: usize = 4;

    let runtime = tokio_multi_thread();
    let test_segments = runtime.block_on(sift1m_segments());

    let config = RootConfig::default();
    let system = System::default();
    let dispatcher = runtime
        .block_on(Dispatcher::try_from_config(
            &config.query_service.dispatcher,
        ))
        .expect("Should be able to initialize dispatcher");
    let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) });

    let mut sift1m = runtime
        .block_on(Sift1MData::init())
        .expect("Should be able to download Sift1M data");
    let mut sift1m_queries = runtime
        .block_on(sift1m.query())
        .expect("Should be able to load Sift1M queries");

    // Randomise which queries end up in the first `QUERIES_PER_RUN` slots.
    sift1m_queries.as_mut_slice().shuffle(&mut thread_rng());

    let trivial_knn_setup = || {
        (
            system.clone(),
            trivial_knn_filter(test_segments.clone(), dispatcher_handle.clone()),
            |knn_filter_output: KnnFilterOutput| {
                sift1m_queries
                    .iter()
                    .take(QUERIES_PER_RUN)
                    .map(|(query, expected)| {
                        (
                            knn(
                                test_segments.clone(),
                                dispatcher_handle.clone(),
                                knn_filter_output.clone(),
                                query.clone(),
                            ),
                            expected.clone(),
                        )
                    })
                    .collect()
            },
        )
    };

    let filtered_knn_setup = || {
        (
            system.clone(),
            always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone()),
            |knn_filter_output: KnnFilterOutput| {
                sift1m_queries
                    .iter()
                    .take(QUERIES_PER_RUN)
                    .map(|(query, expected)| {
                        (
                            knn(
                                test_segments.clone(),
                                dispatcher_handle.clone(),
                                knn_filter_output.clone(),
                                query.clone(),
                            ),
                            expected.clone(),
                        )
                    })
                    .collect()
            },
        )
    };

    bench_run(
        "test-trivial-knn",
        criterion,
        &runtime,
        trivial_knn_setup,
        bench_routine,
    );

    bench_run(
        "test-filtered-knn",
        criterion,
        &runtime,
        filtered_knn_setup,
        bench_routine,
    );
}
// Register `bench_query` as the only target of the criterion group `benches`.
criterion_group!(benches, bench_query);
// Expand to the `main` entry point that runs the `benches` group; the `query` bench
// target is declared with `harness = false`, so no default test harness provides one.
criterion_main!(benches);