From 64d0835ccb57eef84defd456f9be21a80e0cc122 Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Mon, 16 Dec 2024 16:49:36 -0800 Subject: [PATCH] Benchmark knn orchestrator --- rust/benchmark/src/datasets/sift.rs | 4 +- rust/worker/Cargo.toml | 4 + rust/worker/benches/get.rs | 56 +++---- rust/worker/benches/load.rs | 2 +- rust/worker/benches/query.rs | 218 ++++++++++++++++++++++++++++ 5 files changed, 254 insertions(+), 30 deletions(-) create mode 100644 rust/worker/benches/query.rs diff --git a/rust/benchmark/src/datasets/sift.rs b/rust/benchmark/src/datasets/sift.rs index d7cb7bcaa4ff..d18a6cdcd8a6 100644 --- a/rust/benchmark/src/datasets/sift.rs +++ b/rust/benchmark/src/datasets/sift.rs @@ -151,7 +151,7 @@ impl Sift1MData { .collect()) } - pub async fn query(&mut self) -> Result, Vec)>> { + pub async fn query(&mut self) -> Result, Vec)>> { let mut query_bytes = Vec::new(); self.query.read_to_end(&mut query_bytes).await?; let (_, embeddings_bytes) = query_bytes.split_at(size_of::()); @@ -165,7 +165,7 @@ impl Sift1MData { let (_, embeddings_bytes) = query_bytes.split_at(size_of::()); let ground_u32s: Vec<_> = embeddings_bytes .chunks(size_of::()) - .map(|c| u32::from_le_bytes(c.try_into().unwrap()) as usize) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())) .collect(); Ok(embedding_f32s .chunks(Self::dimension()) diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index 9ac85f39395c..800093d306bd 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -84,6 +84,10 @@ harness = false name = "limit" harness = false +[[bench]] +name = "query" +harness = false + [[bench]] name = "spann" harness = false diff --git a/rust/worker/benches/get.rs b/rust/worker/benches/get.rs index 85a9c83db814..281318983097 100644 --- a/rust/worker/benches/get.rs +++ b/rust/worker/benches/get.rs @@ -1,10 +1,11 @@ +#[allow(dead_code)] mod load; use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread}; use chroma_config::Configurable; use criterion::{criterion_group, criterion_main, Criterion}; use load::{ - all_projection, always_true_where_for_modulo_metadata, empty_fetch_log, offset_limit, + all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments, trivial_filter, trivial_limit, trivial_projection, }; use worker::{ @@ -23,7 +24,7 @@ fn trivial_get( GetOrchestrator::new( blockfile_provider, dispatcher_handle, - 100, + 1000, test_segments.into(), empty_fetch_log(collection_uuid), trivial_filter(), @@ -41,10 +42,10 @@ fn get_filter( GetOrchestrator::new( blockfile_provider, dispatcher_handle, - 100, + 1000, test_segments.into(), empty_fetch_log(collection_uuid), - always_true_where_for_modulo_metadata(), + always_true_filter_for_modulo_metadata(), trivial_limit(), trivial_projection(), ) @@ -59,10 +60,10 @@ fn get_filter_limit( GetOrchestrator::new( blockfile_provider, dispatcher_handle, - 100, + 1000, test_segments.into(), empty_fetch_log(collection_uuid), - always_true_where_for_modulo_metadata(), + always_true_filter_for_modulo_metadata(), offset_limit(), trivial_projection(), ) @@ -77,15 +78,31 @@ fn get_filter_limit_projection( GetOrchestrator::new( blockfile_provider, dispatcher_handle, - 100, + 1000, test_segments.into(), empty_fetch_log(collection_uuid), - always_true_where_for_modulo_metadata(), + always_true_filter_for_modulo_metadata(), offset_limit(), all_projection(), ) } +async fn bench_routine(input: (System, GetOrchestrator, Vec)) { + let (system, orchestrator, expected_ids) = input; + let output = orchestrator + .run(system) + .await + .expect("Orchestrator should not fail"); + assert_eq!( + output + .records + .into_iter() + .map(|record| record.id) + .collect::>(), + expected_ids + ); +} + fn bench_get(criterion: &mut Criterion) { let runtime = tokio_multi_thread(); let test_segments = runtime.block_on(sift1m_segments()); @@ -128,48 +145,33 @@ fn bench_get(criterion: &mut Criterion) { ) }; - let routine = |(system, orchestrator, expected): (System, GetOrchestrator, Vec)| async move { - let output = orchestrator - .run(system) - .await - .expect("Orchestrator should not fail"); - assert_eq!( - output - .records - .into_iter() - .map(|record| record.id) - .collect::>(), - expected - ); - }; - bench_run( "test-trivial-get", criterion, &runtime, trivial_get_setup, - routine, + bench_routine, ); bench_run( "test-get-filter", criterion, &runtime, get_filter_setup, - routine, + bench_routine, ); bench_run( "test-get-filter-limit", criterion, &runtime, get_filter_limit_setup, - routine, + bench_routine, ); bench_run( "test-get-filter-limit-projection", criterion, &runtime, get_filter_limit_projection_setup, - routine, + bench_routine, ); } criterion_group!(benches, bench_get); diff --git a/rust/worker/benches/load.rs b/rust/worker/benches/load.rs index 6000acadd150..50c1bdfeb500 100644 --- a/rust/worker/benches/load.rs +++ b/rust/worker/benches/load.rs @@ -73,7 +73,7 @@ pub fn trivial_filter() -> FilterOperator { } } -pub fn always_true_where_for_modulo_metadata() -> FilterOperator { +pub fn always_true_filter_for_modulo_metadata() -> FilterOperator { FilterOperator { query_ids: None, where_clause: Some(Where::conjunction(vec![ diff --git a/rust/worker/benches/query.rs b/rust/worker/benches/query.rs new file mode 100644 index 000000000000..2d4bc4fe6719 --- /dev/null +++ b/rust/worker/benches/query.rs @@ -0,0 +1,218 @@ +#[allow(dead_code)] +mod load; + +use std::collections::HashSet; + +use chroma_benchmark::{ + benchmark::{bench_run, tokio_multi_thread}, + datasets::sift::Sift1MData, +}; +use chroma_config::Configurable; +use criterion::{criterion_group, criterion_main, Criterion}; +use futures::{stream, StreamExt, TryStreamExt}; +use load::{ + all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments, + trivial_filter, +}; +use rand::{seq::SliceRandom, thread_rng}; +use worker::{ + config::RootConfig, + execution::{ + dispatcher::Dispatcher, + operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator}, + orchestration::{ + knn::KnnOrchestrator, + knn_filter::{KnnFilterOrchestrator, KnnFilterOutput}, + }, + }, + segment::test::TestSegment, + system::{ComponentHandle, System}, +}; + +fn trivial_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + trivial_filter(), + ) +} + +fn always_true_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + ) +} + +fn knn( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, + knn_filter_output: KnnFilterOutput, + query: Vec, +) -> KnnOrchestrator { + KnnOrchestrator::new( + test_segments.blockfile_provider.clone(), + dispatcher_handle.clone(), + 1000, + knn_filter_output.clone(), + KnnOperator { + embedding: query, + fetch: Sift1MData::k() as u32, + }, + KnnProjectionOperator { + projection: all_projection(), + distance: true, + }, + ) +} + +async fn bench_routine( + input: ( + System, + KnnFilterOrchestrator, + impl Fn(KnnFilterOutput) -> Vec<(KnnOrchestrator, Vec)>, + ), +) { + let (system, knn_filter, knn_constructor) = input; + let knn_filter_output = knn_filter + .run(system.clone()) + .await + .expect("Orchestrator should not fail"); + let (knns, expected): (Vec<_>, Vec<_>) = knn_constructor(knn_filter_output).into_iter().unzip(); + let results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone()))) + .buffered(32) + .try_collect::>() + .await + .expect("Orchestrators should not fail"); + results + .into_iter() + .map(|result| { + result + .records + .into_iter() + .map(|record| record.record.id.parse()) + // .collect::>() + .collect::, _>>() + .expect("Record id should be parsable to u32") + }) + .zip(expected) + .for_each(|(got, expected)| { + let expected_set: HashSet<_> = HashSet::from_iter(expected); + let recall = got + .into_iter() + .filter(|id| expected_set.contains(id)) + .count() as f64 + / expected_set.len() as f64; + assert!(recall > 0.9); + }); +} + +fn bench_query(criterion: &mut Criterion) { + let runtime = tokio_multi_thread(); + let test_segments = runtime.block_on(sift1m_segments()); + + let config = RootConfig::default(); + let system = System::default(); + let dispatcher = runtime + .block_on(Dispatcher::try_from_config( + &config.query_service.dispatcher, + )) + .expect("Should be able to initialize dispatcher"); + let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) }); + + let mut sift1m = runtime + .block_on(Sift1MData::init()) + .expect("Should be able to download Sift1M data"); + let mut sift1m_queries = runtime + .block_on(sift1m.query()) + .expect("Should be able to load Sift1M queries"); + + sift1m_queries.as_mut_slice().shuffle(&mut thread_rng()); + + let trivial_knn_setup = || { + ( + system.clone(), + trivial_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + let filtered_knn_setup = || { + ( + system.clone(), + always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + bench_run( + "test-trivial-knn", + criterion, + &runtime, + trivial_knn_setup, + bench_routine, + ); + + bench_run( + "test-filtered-knn", + criterion, + &runtime, + filtered_knn_setup, + bench_routine, + ); +} +criterion_group!(benches, bench_query); +criterion_main!(benches);