From 9df14a71b12c815634afefd1c1985f1b11160e6a Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Tue, 17 Dec 2024 13:32:16 -0800 Subject: [PATCH] Fix dataset decoding --- rust/benchmark/src/datasets/sift.rs | 83 +++++++++++++++++---------- rust/worker/benches/get.rs | 67 ++++++++++++++++------ rust/worker/benches/load.rs | 22 +++++++ rust/worker/benches/query.rs | 89 +++++++++++++++++++---------- 4 files changed, 183 insertions(+), 78 deletions(-) diff --git a/rust/benchmark/src/datasets/sift.rs b/rust/benchmark/src/datasets/sift.rs index d18a6cdcd8a6..d55fe1eaa04b 100644 --- a/rust/benchmark/src/datasets/sift.rs +++ b/rust/benchmark/src/datasets/sift.rs @@ -20,7 +20,7 @@ pub struct Sift1MData { impl Sift1MData { pub async fn init() -> Result { let base = get_or_populate_cached_dataset_file( - "gist", + "sift1m", "base.fvecs", None, |mut writer| async move { @@ -45,7 +45,7 @@ impl Sift1MData { }, ).await?; let query = get_or_populate_cached_dataset_file( - "gist", + "sift1m", "query.fvecs", None, |mut writer| async move { @@ -70,8 +70,8 @@ impl Sift1MData { }, ).await?; let ground = get_or_populate_cached_dataset_file( - "gist", - "groundtruth.fvecs", + "sift1m", + "groundtruth.ivecs", None, |mut writer| async move { let client = reqwest::Client::new(); @@ -134,43 +134,64 @@ impl Sift1MData { return Ok(Vec::new()); } - let start = SeekFrom::Start( - (size_of::() + lower_bound * Self::dimension() * size_of::()) as u64, - ); + let vector_size = size_of::() + Self::dimension() * size_of::(); + + let start = SeekFrom::Start((lower_bound * vector_size) as u64); self.base.seek(start).await?; let batch_size = upper_bound - lower_bound; - let mut base_bytes = vec![0; batch_size * Self::dimension() * size_of::()]; + let mut base_bytes = vec![0; batch_size * vector_size]; self.base.read_exact(&mut base_bytes).await?; - let embedding_f32s: Vec<_> = base_bytes - .chunks(size_of::()) - .map(|c| f32::from_le_bytes(c.try_into().unwrap())) - .collect(); - Ok(embedding_f32s - .chunks(Self::dimension()) - .map(|embedding| embedding.to_vec()) - .collect()) + read_raw_vec(&base_bytes, |bytes| { + Ok(f32::from_le_bytes(bytes.try_into()?)) + }) } pub async fn query(&mut self) -> Result, Vec)>> { let mut query_bytes = Vec::new(); self.query.read_to_end(&mut query_bytes).await?; - let (_, embeddings_bytes) = query_bytes.split_at(size_of::()); - let embedding_f32s: Vec<_> = embeddings_bytes - .chunks(size_of::()) - .map(|c| f32::from_le_bytes(c.try_into().unwrap())) - .collect(); + let queries = read_raw_vec(&query_bytes, |bytes| { + Ok(f32::from_le_bytes(bytes.try_into()?)) + })?; let mut ground_bytes = Vec::new(); self.ground.read_to_end(&mut ground_bytes).await?; - let (_, embeddings_bytes) = query_bytes.split_at(size_of::()); - let ground_u32s: Vec<_> = embeddings_bytes - .chunks(size_of::()) - .map(|c| u32::from_le_bytes(c.try_into().unwrap())) - .collect(); - Ok(embedding_f32s - .chunks(Self::dimension()) - .zip(ground_u32s.chunks(Self::k())) - .map(|(embedding, ground)| (embedding.to_vec(), ground.to_vec())) - .collect()) + let grounds = read_raw_vec(&ground_bytes, |bytes| { + Ok(u32::from_le_bytes(bytes.try_into()?)) + })?; + if queries.len() != grounds.len() { + return Err(anyhow!( + "Queries and grounds count mismatch: {} != {}", + queries.len(), + grounds.len() + )); + } + Ok(queries.into_iter().zip(grounds).collect()) + } +} + +fn read_raw_vec( + raw_bytes: &[u8], + convert_from_bytes: impl Fn(&[u8]) -> Result, +) -> Result>> { + let mut result = Vec::new(); + let mut bytes = raw_bytes; + while !bytes.is_empty() { + let (dimension_bytes, rem_bytes) = bytes.split_at(size_of::()); + let dimension = u32::from_le_bytes(dimension_bytes.try_into()?); + let (embedding_bytes, rem_bytes) = rem_bytes.split_at(dimension as usize * size_of::()); + let embedding = embedding_bytes + .chunks(size_of::()) + .map(&convert_from_bytes) + .collect::>>()?; + if embedding.len() != dimension as usize { + return Err(anyhow!( + "Embedding dimension mismatch: {} != {}", + embedding.len(), + dimension + )); + } + result.push(embedding); + bytes = rem_bytes; } + Ok(result) } diff --git a/rust/worker/benches/get.rs b/rust/worker/benches/get.rs index 281318983097..04290c38500d 100644 --- a/rust/worker/benches/get.rs +++ b/rust/worker/benches/get.rs @@ -5,8 +5,9 @@ use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread}; use chroma_config::Configurable; use criterion::{criterion_group, criterion_main, Criterion}; use load::{ - all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, - sift1m_segments, trivial_filter, trivial_limit, trivial_projection, + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments, + trivial_filter, trivial_limit, trivial_projection, }; use worker::{ config::RootConfig, @@ -33,7 +34,25 @@ fn trivial_get( ) } -fn get_filter( +fn get_false_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_true_filter( test_segments: TestSegment, dispatcher_handle: ComponentHandle, ) -> GetOrchestrator { @@ -51,7 +70,7 @@ fn get_filter( ) } -fn get_filter_limit( +fn get_true_filter_limit( test_segments: TestSegment, dispatcher_handle: ComponentHandle, ) -> GetOrchestrator { @@ -69,7 +88,7 @@ fn get_filter_limit( ) } -fn get_filter_limit_projection( +fn get_true_filter_limit_projection( test_segments: TestSegment, dispatcher_handle: ComponentHandle, ) -> GetOrchestrator { @@ -123,24 +142,31 @@ fn bench_get(criterion: &mut Criterion) { (0..100).map(|id| id.to_string()).collect(), ) }; - let get_filter_setup = || { + let get_false_filter_setup = || { + ( + system.clone(), + get_false_filter(test_segments.clone(), dispatcher_handle.clone()), + Vec::new(), + ) + }; + let get_true_filter_setup = || { ( system.clone(), - get_filter(test_segments.clone(), dispatcher_handle.clone()), + get_true_filter(test_segments.clone(), dispatcher_handle.clone()), (0..100).map(|id| id.to_string()).collect(), ) }; - let get_filter_limit_setup = || { + let get_true_filter_limit_setup = || { ( system.clone(), - get_filter_limit(test_segments.clone(), dispatcher_handle.clone()), + get_true_filter_limit(test_segments.clone(), dispatcher_handle.clone()), (100..200).map(|id| id.to_string()).collect(), ) }; - let get_filter_limit_projection_setup = || { + let get_true_filter_limit_projection_setup = || { ( system.clone(), - get_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()), + get_true_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()), (100..200).map(|id| id.to_string()).collect(), ) }; @@ -153,24 +179,31 @@ fn bench_get(criterion: &mut Criterion) { bench_routine, ); bench_run( - "test-get-filter", + "test-get-false-filter", + criterion, + &runtime, + get_false_filter_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter", criterion, &runtime, - get_filter_setup, + get_true_filter_setup, bench_routine, ); bench_run( - "test-get-filter-limit", + "test-get-true-filter-limit", criterion, &runtime, - get_filter_limit_setup, + get_true_filter_limit_setup, bench_routine, ); bench_run( - "test-get-filter-limit-projection", + "test-get-true-filter-limit-projection", criterion, &runtime, - get_filter_limit_projection_setup, + get_true_filter_limit_projection_setup, bench_routine, ); } diff --git a/rust/worker/benches/load.rs b/rust/worker/benches/load.rs index 50c1bdfeb500..aef7f8e05511 100644 --- a/rust/worker/benches/load.rs +++ b/rust/worker/benches/load.rs @@ -73,6 +73,28 @@ pub fn trivial_filter() -> FilterOperator { } } +pub fn always_false_filter_for_modulo_metadata() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: Some(Where::disjunction(vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "is_even".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Bool(vec![false, true]), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "modulo_3".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Int(vec![0, 1, 2]), + ), + }), + ])), + } +} + pub fn always_true_filter_for_modulo_metadata() -> FilterOperator { FilterOperator { query_ids: None, diff --git a/rust/worker/benches/query.rs b/rust/worker/benches/query.rs index 2d4bc4fe6719..589cf27ba343 100644 --- a/rust/worker/benches/query.rs +++ b/rust/worker/benches/query.rs @@ -1,8 +1,6 @@ #[allow(dead_code)] mod load; -use std::collections::HashSet; - use chroma_benchmark::{ benchmark::{bench_run, tokio_multi_thread}, datasets::sift::Sift1MData, @@ -11,8 +9,8 @@ use chroma_config::Configurable; use criterion::{criterion_group, criterion_main, Criterion}; use futures::{stream, StreamExt, TryStreamExt}; use load::{ - all_projection, always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments, - trivial_filter, + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments, trivial_filter, }; use rand::{seq::SliceRandom, thread_rng}; use worker::{ @@ -65,6 +63,24 @@ fn always_true_knn_filter( ) } +fn always_false_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + ) +} + fn knn( test_segments: TestSegment, dispatcher_handle: ComponentHandle, @@ -99,33 +115,14 @@ async fn bench_routine( .run(system.clone()) .await .expect("Orchestrator should not fail"); - let (knns, expected): (Vec<_>, Vec<_>) = knn_constructor(knn_filter_output).into_iter().unzip(); - let results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone()))) + let (knns, _expected): (Vec<_>, Vec<_>) = + knn_constructor(knn_filter_output).into_iter().unzip(); + let _results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone()))) .buffered(32) .try_collect::>() .await .expect("Orchestrators should not fail"); - results - .into_iter() - .map(|result| { - result - .records - .into_iter() - .map(|record| record.record.id.parse()) - // .collect::>() - .collect::, _>>() - .expect("Record id should be parsable to u32") - }) - .zip(expected) - .for_each(|(got, expected)| { - let expected_set: HashSet<_> = HashSet::from_iter(expected); - let recall = got - .into_iter() - .filter(|id| expected_set.contains(id)) - .count() as f64 - / expected_set.len() as f64; - assert!(recall > 0.9); - }); + // TODO: verify recall } fn bench_query(criterion: &mut Criterion) { @@ -174,7 +171,7 @@ fn bench_query(criterion: &mut Criterion) { ) }; - let filtered_knn_setup = || { + let true_filter_knn_setup = || { ( system.clone(), always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), @@ -198,6 +195,30 @@ fn bench_query(criterion: &mut Criterion) { ) }; + let false_filter_knn_setup = || { + ( + system.clone(), + always_false_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, _)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + Vec::new(), + ) + }) + .collect() + }, + ) + }; + bench_run( "test-trivial-knn", criterion, @@ -207,10 +228,18 @@ fn bench_query(criterion: &mut Criterion) { ); bench_run( - "test-filtered-knn", + "test-true-filter-knn", + criterion, + &runtime, + true_filter_knn_setup, + bench_routine, + ); + + bench_run( + "test-false-filter-knn", criterion, &runtime, - filtered_knn_setup, + false_filter_knn_setup, bench_routine, ); }