diff --git a/Cargo.lock b/Cargo.lock
index 9e862373306..cbbf913ada7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6792,6 +6792,7 @@ dependencies = [
  "figment",
  "flatbuffers",
  "futures",
+ "indicatif",
  "k8s-openapi",
  "kube",
  "murmur3",
diff --git a/rust/benchmark/src/datasets/gist.rs b/rust/benchmark/src/datasets/gist.rs
index d511f298d79..ad26934e50b 100644
--- a/rust/benchmark/src/datasets/gist.rs
+++ b/rust/benchmark/src/datasets/gist.rs
@@ -1,6 +1,6 @@
 use std::path::PathBuf;
 
-use anyhow::Ok;
+use anyhow::Result;
 use tokio::io::{AsyncReadExt, BufReader};
 
 use super::{
@@ -20,7 +20,7 @@ impl RecordDataset for GistDataset {
     const DISPLAY_NAME: &'static str = "Gist";
     const NAME: &'static str = "gist";
 
-    async fn init() -> anyhow::Result<Self> {
+    async fn init() -> Result<Self> {
         // TODO(Sanket): Download file if it doesn't exist.
         // move file from downloads to cached path.
         let current_path = "/Users/sanketkedia/Downloads/siftsmall/siftsmall_base.fvecs";
diff --git a/rust/benchmark/src/datasets/mod.rs b/rust/benchmark/src/datasets/mod.rs
index 60d5e35e486..3ceef402298 100644
--- a/rust/benchmark/src/datasets/mod.rs
+++ b/rust/benchmark/src/datasets/mod.rs
@@ -4,4 +4,5 @@ pub mod util;
 pub mod gist;
 pub mod ms_marco_queries;
 pub mod scidocs;
+pub mod sift;
 pub mod wikipedia;
diff --git a/rust/benchmark/src/datasets/sift.rs b/rust/benchmark/src/datasets/sift.rs
new file mode 100644
index 00000000000..d55fe1eaa04
--- /dev/null
+++ b/rust/benchmark/src/datasets/sift.rs
@@ -0,0 +1,197 @@
+use std::{
+    io::SeekFrom,
+    ops::{Bound, RangeBounds},
+};
+
+use anyhow::{anyhow, Ok, Result};
+use tokio::{
+    fs::File,
+    io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader},
+};
+
+use super::util::get_or_populate_cached_dataset_file;
+
+pub struct Sift1MData {
+    pub base: BufReader<File>,
+    pub query: BufReader<File>,
+    pub ground: BufReader<File>,
+}
+
+impl Sift1MData {
+    pub async fn init() -> Result<Self> {
+        let base = get_or_populate_cached_dataset_file(
+            "sift1m",
+            "base.fvecs",
+            None,
+            |mut writer| async move {
+                let client = reqwest::Client::new();
+                let response = client
+                    .get(
+                        "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_base.fvecs",
+                    )
+                    .send()
+                    .await?;
+
+                if !response.status().is_success() {
+                    return Err(anyhow!(
+                        "Failed to download Sift1M base data, got status code {}",
+                        response.status()
+                    ));
+                }
+
+                writer.write_all(&response.bytes().await?).await?;
+
+                Ok(())
+            },
+        ).await?;
+        let query = get_or_populate_cached_dataset_file(
+            "sift1m",
+            "query.fvecs",
+            None,
+            |mut writer| async move {
+                let client = reqwest::Client::new();
+                let response = client
+                    .get(
+                        "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_query.fvecs",
+                    )
+                    .send()
+                    .await?;
+
+                if !response.status().is_success() {
+                    return Err(anyhow!(
+                        "Failed to download Sift1M query data, got status code {}",
+                        response.status()
+                    ));
+                }
+
+                writer.write_all(&response.bytes().await?).await?;
+
+                Ok(())
+            },
+        ).await?;
+        let ground = get_or_populate_cached_dataset_file(
+            "sift1m",
+            "groundtruth.ivecs",
+            None,
+            |mut writer| async move {
+                let client = reqwest::Client::new();
+                let response = client
+                    .get(
+                        "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_groundtruth.ivecs",
+                    )
+                    .send()
+                    .await?;
+
+                if !response.status().is_success() {
+                    return Err(anyhow!(
+                        "Failed to download Sift1M ground data, got status code {}",
+                        response.status()
+                    ));
+                }
+
+                writer.write_all(&response.bytes().await?).await?;
+
+                Ok(())
+            },
+        ).await?;
+        Ok(Self {
+            base: BufReader::new(File::open(base).await?),
+            query: BufReader::new(File::open(query).await?),
+            ground: BufReader::new(File::open(ground).await?),
+        })
+    }
+
+    pub fn collection_size() -> usize {
+        1000000
+    }
+
+    pub fn query_size() -> usize {
+        10000
+    }
+
+    pub fn dimension() -> usize {
+        128
+    }
+
+    pub fn k() -> usize {
+        100
+    }
+
+    pub async fn data_range(&mut self, range: impl RangeBounds<usize>) -> Result<Vec<Vec<f32>>> {
+        let lower_bound = match range.start_bound() {
+            Bound::Included(include) => *include,
+            Bound::Excluded(exclude) => exclude + 1,
+            Bound::Unbounded => 0,
+        };
+        let upper_bound = match range.end_bound() {
+            Bound::Included(include) => include + 1,
+            Bound::Excluded(exclude) => *exclude,
+            Bound::Unbounded => usize::MAX,
+        }
+        .min(Self::collection_size());
+
+        if lower_bound >= upper_bound {
+            return Ok(Vec::new());
+        }
+
+        let vector_size = size_of::<u32>() + Self::dimension() * size_of::<f32>();
+
+        let start = SeekFrom::Start((lower_bound * vector_size) as u64);
+        self.base.seek(start).await?;
+        let batch_size = upper_bound - lower_bound;
+        let mut base_bytes = vec![0; batch_size * vector_size];
+        self.base.read_exact(&mut base_bytes).await?;
+        read_raw_vec(&base_bytes, |bytes| {
+            Ok(f32::from_le_bytes(bytes.try_into()?))
+        })
+    }
+
+    pub async fn query(&mut self) -> Result<Vec<(Vec<f32>, Vec<u32>)>> {
+        let mut query_bytes = Vec::new();
+        self.query.read_to_end(&mut query_bytes).await?;
+        let queries = read_raw_vec(&query_bytes, |bytes| {
+            Ok(f32::from_le_bytes(bytes.try_into()?))
+        })?;
+
+        let mut ground_bytes = Vec::new();
+        self.ground.read_to_end(&mut ground_bytes).await?;
+        let grounds = read_raw_vec(&ground_bytes, |bytes| {
+            Ok(u32::from_le_bytes(bytes.try_into()?))
+        })?;
+        if queries.len() != grounds.len() {
+            return Err(anyhow!(
+                "Queries and grounds count mismatch: {} != {}",
+                queries.len(),
+                grounds.len()
+            ));
+        }
+        Ok(queries.into_iter().zip(grounds).collect())
+    }
+}
+
+fn read_raw_vec<T>(
+    raw_bytes: &[u8],
+    convert_from_bytes: impl Fn(&[u8]) -> Result<T>,
+) -> Result<Vec<Vec<T>>> {
+    let mut result = Vec::new();
+    let mut bytes = raw_bytes;
+    while !bytes.is_empty() {
+        let (dimension_bytes, rem_bytes) = bytes.split_at(size_of::<u32>());
+        let dimension = u32::from_le_bytes(dimension_bytes.try_into()?);
+        let (embedding_bytes, rem_bytes) = rem_bytes.split_at(dimension as usize * size_of::<T>());
+        let embedding = embedding_bytes
+            .chunks(size_of::<T>())
+            .map(&convert_from_bytes)
+            .collect::<Result<Vec<T>>>()?;
+        if embedding.len() != dimension as usize {
+            return Err(anyhow!(
+                "Embedding dimension mismatch: {} != {}",
+                embedding.len(),
+                dimension
+            ));
+        }
+        result.push(embedding);
+        bytes = rem_bytes;
+    }
+    Ok(result)
+}
diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml
index d0541f65869..800093d306b 100644
--- a/rust/worker/Cargo.toml
+++ b/rust/worker/Cargo.toml
@@ -61,13 +61,14 @@ chroma-distance = { workspace = true }
 random-port = "0.1.1"
 serial_test = "3.2.0"
 
+criterion = { workspace = true }
+indicatif = { workspace = true }
+proptest = { workspace = true }
+proptest-state-machine = { workspace = true }
+shuttle = { workspace = true }
 rand = { workspace = true }
 rand_xorshift = { workspace = true }
 tempfile = { workspace = true }
-shuttle = { workspace = true }
-proptest = { workspace = true }
-proptest-state-machine = { workspace = true }
-criterion = { workspace = true }
 
 chroma-benchmark = { workspace = true }
 
@@ -75,10 +76,18 @@ chroma-benchmark = { workspace = true }
 name = "filter"
 harness = false
 
+[[bench]]
+name = "get"
+harness = false
+
 [[bench]]
 name = "limit"
 harness = false
 
+[[bench]]
+name = "query" +harness = false + [[bench]] name = "spann" harness = false diff --git a/rust/worker/benches/filter.rs b/rust/worker/benches/filter.rs index 4accb804818..57f54644b26 100644 --- a/rust/worker/benches/filter.rs +++ b/rust/worker/benches/filter.rs @@ -9,7 +9,7 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use worker::execution::operator::Operator; use worker::execution::operators::filter::{FilterInput, FilterOperator}; -use worker::log::test::{upsert_generator, LogGenerator}; +use worker::log::test::upsert_generator; use worker::segment::test::TestSegment; fn baseline_where_clauses() -> Vec<(&'static str, Option)> { @@ -71,14 +71,13 @@ fn baseline_where_clauses() -> Vec<(&'static str, Option)> { fn bench_filter(criterion: &mut Criterion) { let runtime = tokio_multi_thread(); - let logen = LogGenerator { - generator: upsert_generator, - }; for record_count in [1000, 10000, 100000] { let test_segment = runtime.block_on(async { let mut segment = TestSegment::default(); - segment.populate_with_generator(record_count, &logen).await; + segment + .populate_with_generator(record_count, upsert_generator) + .await; segment }); diff --git a/rust/worker/benches/get.rs b/rust/worker/benches/get.rs new file mode 100644 index 00000000000..04290c38500 --- /dev/null +++ b/rust/worker/benches/get.rs @@ -0,0 +1,211 @@ +#[allow(dead_code)] +mod load; + +use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread}; +use chroma_config::Configurable; +use criterion::{criterion_group, criterion_main, Criterion}; +use load::{ + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments, + trivial_filter, trivial_limit, trivial_projection, +}; +use worker::{ + config::RootConfig, + execution::{dispatcher::Dispatcher, orchestration::get::GetOrchestrator}, + segment::test::TestSegment, + system::{ComponentHandle, System}, +}; + +fn trivial_get( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + trivial_filter(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_false_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_true_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_true_filter_limit( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = 
test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + offset_limit(), + trivial_projection(), + ) +} + +fn get_true_filter_limit_projection( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + offset_limit(), + all_projection(), + ) +} + +async fn bench_routine(input: (System, GetOrchestrator, Vec)) { + let (system, orchestrator, expected_ids) = input; + let output = orchestrator + .run(system) + .await + .expect("Orchestrator should not fail"); + assert_eq!( + output + .records + .into_iter() + .map(|record| record.id) + .collect::>(), + expected_ids + ); +} + +fn bench_get(criterion: &mut Criterion) { + let runtime = tokio_multi_thread(); + let test_segments = runtime.block_on(sift1m_segments()); + + let config = RootConfig::default(); + let system = System::default(); + let dispatcher = runtime + .block_on(Dispatcher::try_from_config( + &config.query_service.dispatcher, + )) + .expect("Should be able to initialize dispatcher"); + let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) }); + + let trivial_get_setup = || { + ( + system.clone(), + trivial_get(test_segments.clone(), dispatcher_handle.clone()), + (0..100).map(|id| id.to_string()).collect(), + ) + }; + let get_false_filter_setup = || { + ( + system.clone(), + get_false_filter(test_segments.clone(), dispatcher_handle.clone()), + Vec::new(), + ) + }; + let get_true_filter_setup = || { + ( + system.clone(), + get_true_filter(test_segments.clone(), dispatcher_handle.clone()), + (0..100).map(|id| id.to_string()).collect(), + ) + }; + let get_true_filter_limit_setup = || { + ( + system.clone(), + get_true_filter_limit(test_segments.clone(), dispatcher_handle.clone()), + (100..200).map(|id| id.to_string()).collect(), + ) + }; + let get_true_filter_limit_projection_setup = || { + ( + system.clone(), + get_true_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()), + (100..200).map(|id| id.to_string()).collect(), + ) + }; + + bench_run( + "test-trivial-get", + criterion, + &runtime, + trivial_get_setup, + bench_routine, + ); + bench_run( + "test-get-false-filter", + criterion, + &runtime, + get_false_filter_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter", + criterion, + &runtime, + get_true_filter_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter-limit", + criterion, + &runtime, + get_true_filter_limit_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter-limit-projection", + criterion, + &runtime, + get_true_filter_limit_projection_setup, + bench_routine, + ); +} +criterion_group!(benches, bench_get); +criterion_main!(benches); diff --git a/rust/worker/benches/limit.rs b/rust/worker/benches/limit.rs index a19e481712c..7a13c174acc 100644 --- a/rust/worker/benches/limit.rs +++ b/rust/worker/benches/limit.rs @@ -4,21 +4,20 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use worker::execution::operator::Operator; use 
worker::execution::operators::limit::{LimitInput, LimitOperator}; -use worker::log::test::{upsert_generator, LogGenerator}; +use worker::log::test::upsert_generator; use worker::segment::test::TestSegment; const FETCH: usize = 100; fn bench_limit(criterion: &mut Criterion) { let runtime = tokio_multi_thread(); - let logen = LogGenerator { - generator: upsert_generator, - }; for record_count in [1000, 10000, 100000] { let test_segment = runtime.block_on(async { let mut segment = TestSegment::default(); - segment.populate_with_generator(record_count, &logen).await; + segment + .populate_with_generator(record_count, upsert_generator) + .await; segment }); diff --git a/rust/worker/benches/load.rs b/rust/worker/benches/load.rs new file mode 100644 index 00000000000..aef7f8e0551 --- /dev/null +++ b/rust/worker/benches/load.rs @@ -0,0 +1,148 @@ +use chroma_benchmark::datasets::sift::Sift1MData; +use chroma_types::{ + Chunk, CollectionUuid, DirectWhereComparison, LogRecord, MetadataSetValue, Operation, + OperationRecord, SetOperator, Where, WhereComparison, +}; +use indicatif::ProgressIterator; +use worker::{ + execution::operators::{ + fetch_log::FetchLogOperator, filter::FilterOperator, limit::LimitOperator, + projection::ProjectionOperator, + }, + log::{ + log::{InMemoryLog, Log}, + test::modulo_metadata, + }, + segment::test::TestSegment, +}; + +const DATA_CHUNK_SIZE: usize = 10000; + +pub async fn sift1m_segments() -> TestSegment { + let mut segments = TestSegment::default(); + let mut sift1m = Sift1MData::init() + .await + .expect("Should be able to download Sift1M data"); + + for chunk_start in (0..Sift1MData::collection_size()) + .step_by(DATA_CHUNK_SIZE) + .progress() + .with_message("Loading Sift1M Data") + { + let embedding_chunk = sift1m + .data_range(chunk_start..(chunk_start + DATA_CHUNK_SIZE)) + .await + .expect("Should be able to decode data chunk"); + + let log_records = embedding_chunk + .into_iter() + .enumerate() + .map(|(index, embedding)| LogRecord { + log_offset: (chunk_start + index) as i64, + record: OperationRecord { + id: (chunk_start + index).to_string(), + embedding: Some(embedding), + encoding: None, + metadata: Some(modulo_metadata(chunk_start + index)), + document: None, + operation: Operation::Add, + }, + }) + .collect::>(); + segments + .compact_log(Chunk::new(log_records.into()), chunk_start) + .await; + } + segments +} + +pub fn empty_fetch_log(collection_uuid: CollectionUuid) -> FetchLogOperator { + FetchLogOperator { + log_client: Log::InMemory(InMemoryLog::default()).into(), + batch_size: 100, + start_log_offset_id: 0, + maximum_fetch_count: Some(0), + collection_uuid, + } +} + +pub fn trivial_filter() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: None, + } +} + +pub fn always_false_filter_for_modulo_metadata() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: Some(Where::disjunction(vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "is_even".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Bool(vec![false, true]), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "modulo_3".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Int(vec![0, 1, 2]), + ), + }), + ])), + } +} + +pub fn always_true_filter_for_modulo_metadata() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: Some(Where::conjunction(vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: 
"is_even".to_string(), + comparison: WhereComparison::Set( + SetOperator::In, + MetadataSetValue::Bool(vec![false, true]), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "modulo_3".to_string(), + comparison: WhereComparison::Set( + SetOperator::In, + MetadataSetValue::Int(vec![0, 1, 2]), + ), + }), + ])), + } +} + +pub fn trivial_limit() -> LimitOperator { + LimitOperator { + skip: 0, + fetch: Some(100), + } +} + +pub fn offset_limit() -> LimitOperator { + LimitOperator { + skip: 100, + fetch: Some(100), + } +} + +pub fn trivial_projection() -> ProjectionOperator { + ProjectionOperator { + document: false, + embedding: false, + metadata: false, + } +} + +pub fn all_projection() -> ProjectionOperator { + ProjectionOperator { + document: true, + embedding: true, + metadata: true, + } +} diff --git a/rust/worker/benches/query.rs b/rust/worker/benches/query.rs new file mode 100644 index 00000000000..589cf27ba34 --- /dev/null +++ b/rust/worker/benches/query.rs @@ -0,0 +1,247 @@ +#[allow(dead_code)] +mod load; + +use chroma_benchmark::{ + benchmark::{bench_run, tokio_multi_thread}, + datasets::sift::Sift1MData, +}; +use chroma_config::Configurable; +use criterion::{criterion_group, criterion_main, Criterion}; +use futures::{stream, StreamExt, TryStreamExt}; +use load::{ + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments, trivial_filter, +}; +use rand::{seq::SliceRandom, thread_rng}; +use worker::{ + config::RootConfig, + execution::{ + dispatcher::Dispatcher, + operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator}, + orchestration::{ + knn::KnnOrchestrator, + knn_filter::{KnnFilterOrchestrator, KnnFilterOutput}, + }, + }, + segment::test::TestSegment, + system::{ComponentHandle, System}, +}; + +fn trivial_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + trivial_filter(), + ) +} + +fn always_true_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + ) +} + +fn always_false_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + ) +} + +fn knn( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, + knn_filter_output: KnnFilterOutput, + query: Vec, +) -> KnnOrchestrator { 
+ KnnOrchestrator::new( + test_segments.blockfile_provider.clone(), + dispatcher_handle.clone(), + 1000, + knn_filter_output.clone(), + KnnOperator { + embedding: query, + fetch: Sift1MData::k() as u32, + }, + KnnProjectionOperator { + projection: all_projection(), + distance: true, + }, + ) +} + +async fn bench_routine( + input: ( + System, + KnnFilterOrchestrator, + impl Fn(KnnFilterOutput) -> Vec<(KnnOrchestrator, Vec)>, + ), +) { + let (system, knn_filter, knn_constructor) = input; + let knn_filter_output = knn_filter + .run(system.clone()) + .await + .expect("Orchestrator should not fail"); + let (knns, _expected): (Vec<_>, Vec<_>) = + knn_constructor(knn_filter_output).into_iter().unzip(); + let _results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone()))) + .buffered(32) + .try_collect::>() + .await + .expect("Orchestrators should not fail"); + // TODO: verify recall +} + +fn bench_query(criterion: &mut Criterion) { + let runtime = tokio_multi_thread(); + let test_segments = runtime.block_on(sift1m_segments()); + + let config = RootConfig::default(); + let system = System::default(); + let dispatcher = runtime + .block_on(Dispatcher::try_from_config( + &config.query_service.dispatcher, + )) + .expect("Should be able to initialize dispatcher"); + let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) }); + + let mut sift1m = runtime + .block_on(Sift1MData::init()) + .expect("Should be able to download Sift1M data"); + let mut sift1m_queries = runtime + .block_on(sift1m.query()) + .expect("Should be able to load Sift1M queries"); + + sift1m_queries.as_mut_slice().shuffle(&mut thread_rng()); + + let trivial_knn_setup = || { + ( + system.clone(), + trivial_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + let true_filter_knn_setup = || { + ( + system.clone(), + always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + let false_filter_knn_setup = || { + ( + system.clone(), + always_false_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, _)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + Vec::new(), + ) + }) + .collect() + }, + ) + }; + + bench_run( + "test-trivial-knn", + criterion, + &runtime, + trivial_knn_setup, + bench_routine, + ); + + bench_run( + "test-true-filter-knn", + criterion, + &runtime, + true_filter_knn_setup, + bench_routine, + ); + + bench_run( + "test-false-filter-knn", + criterion, + &runtime, + false_filter_knn_setup, + bench_routine, + ); +} +criterion_group!(benches, bench_query); +criterion_main!(benches); diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 79107db3b1d..63b9b524e8b 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -12,7 +12,7 @@ const DEFAULT_CONFIG_PATH: &str = 
"./chroma_config.yaml"; /// variables take precedence over values in the YAML file. /// By default, it is read from the current working directory, /// with the filename chroma_config.yaml. -pub(crate) struct RootConfig { +pub struct RootConfig { // The root config object wraps the worker config object so that // we can share the same config file between multiple services. pub query_service: QueryServiceConfig, @@ -78,6 +78,12 @@ impl RootConfig { } } +impl Default for RootConfig { + fn default() -> Self { + Self::load() + } +} + #[derive(Deserialize)] /// # Description /// The primary config for the worker service. @@ -89,7 +95,7 @@ impl RootConfig { /// For example, to set my_ip, you would set CHROMA_WORKER__MY_IP. /// Each submodule that needs to be configured from the config object should implement the Configurable trait and /// have its own field in this struct for its Config struct. -pub(crate) struct QueryServiceConfig { +pub struct QueryServiceConfig { pub(crate) service_name: String, pub(crate) otel_endpoint: String, #[allow(dead_code)] @@ -102,7 +108,7 @@ pub(crate) struct QueryServiceConfig { pub(crate) sysdb: crate::sysdb::config::SysDbConfig, pub(crate) storage: chroma_storage::config::StorageConfig, pub(crate) log: crate::log::config::LogConfig, - pub(crate) dispatcher: crate::execution::config::DispatcherConfig, + pub dispatcher: crate::execution::config::DispatcherConfig, pub(crate) blockfile_provider: chroma_blockstore::config::BlockfileProviderConfig, pub(crate) hnsw_provider: chroma_index::config::HnswProviderConfig, } @@ -118,7 +124,7 @@ pub(crate) struct QueryServiceConfig { /// For example, to set my_ip, you would set CHROMA_COMPACTOR__MY_IP. /// Each submodule that needs to be configured from the config object should implement the Configurable trait and /// have its own field in this struct for its Config struct. -pub(crate) struct CompactionServiceConfig { +pub struct CompactionServiceConfig { pub(crate) service_name: String, pub(crate) otel_endpoint: String, pub(crate) my_member_id: String, diff --git a/rust/worker/src/execution/config.rs b/rust/worker/src/execution/config.rs index d8550dc41bc..ada5bf1b6ad 100644 --- a/rust/worker/src/execution/config.rs +++ b/rust/worker/src/execution/config.rs @@ -1,7 +1,7 @@ use serde::Deserialize; #[derive(Deserialize)] -pub(crate) struct DispatcherConfig { +pub struct DispatcherConfig { pub(crate) num_worker_threads: usize, pub(crate) dispatcher_queue_size: usize, pub(crate) worker_queue_size: usize, diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 3ce1b57444f..5f29cc5761e 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -51,7 +51,7 @@ use tracing::{trace_span, Instrument, Span}; coarser work-stealing, and other optimizations. 
*/ #[derive(Debug)] -pub(crate) struct Dispatcher { +pub struct Dispatcher { task_queue: Vec, waiters: Vec, n_worker_threads: usize, diff --git a/rust/worker/src/execution/mod.rs b/rust/worker/src/execution/mod.rs index 5c6ca5567b7..91a2f089e27 100644 --- a/rust/worker/src/execution/mod.rs +++ b/rust/worker/src/execution/mod.rs @@ -1,8 +1,8 @@ pub(crate) mod config; -pub(crate) mod dispatcher; -pub(crate) mod orchestration; mod worker_thread; // Required for benchmark +pub mod dispatcher; pub mod operator; pub mod operators; +pub mod orchestration; diff --git a/rust/worker/src/execution/operators/fetch_log.rs b/rust/worker/src/execution/operators/fetch_log.rs index 22e3198929d..4feb421da3d 100644 --- a/rust/worker/src/execution/operators/fetch_log.rs +++ b/rust/worker/src/execution/operators/fetch_log.rs @@ -32,7 +32,7 @@ use crate::{ /// It should be run at the start of an orchestrator to get the latest data of a collection #[derive(Clone, Debug)] pub struct FetchLogOperator { - pub(crate) log_client: Box, + pub log_client: Box, pub batch_size: u32, pub start_log_offset_id: u32, pub maximum_fetch_count: Option, @@ -126,20 +126,20 @@ mod tests { fn setup_in_memory_log() -> (CollectionUuid, Box) { let collection_id = CollectionUuid::new(); let mut in_memory_log = InMemoryLog::new(); - let generator = LogGenerator { - generator: upsert_generator, - }; - generator.generate_vec(0..10).into_iter().for_each(|log| { - in_memory_log.add_log( - collection_id, - InternalLogRecord { + upsert_generator + .generate_vec(0..10) + .into_iter() + .for_each(|log| { + in_memory_log.add_log( collection_id, - log_offset: log.log_offset, - log_ts: log.log_offset, - record: log, - }, - ) - }); + InternalLogRecord { + collection_id, + log_offset: log.log_offset, + log_ts: log.log_offset, + record: log, + }, + ) + }); (collection_id, Box::new(Log::InMemory(in_memory_log))) } diff --git a/rust/worker/src/execution/operators/filter.rs b/rust/worker/src/execution/operators/filter.rs index 635636b1137..8296a82802c 100644 --- a/rust/worker/src/execution/operators/filter.rs +++ b/rust/worker/src/execution/operators/filter.rs @@ -515,12 +515,11 @@ mod tests { /// - Compacted: Delete [1..=10] deletion, add [11..=50] async fn setup_filter_input() -> FilterInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: add_delete_generator, - }; - test_segment.populate_with_generator(60, &generator).await; + test_segment + .populate_with_generator(60, add_delete_generator) + .await; FilterInput { - logs: generator.generate_chunk(61..=120), + logs: add_delete_generator.generate_chunk(61..=120), blockfile_provider: test_segment.blockfile_provider, metadata_segment: test_segment.metadata_segment, record_segment: test_segment.record_segment, diff --git a/rust/worker/src/execution/operators/knn_log.rs b/rust/worker/src/execution/operators/knn_log.rs index aee11654529..1db07266b2a 100644 --- a/rust/worker/src/execution/operators/knn_log.rs +++ b/rust/worker/src/execution/operators/knn_log.rs @@ -143,11 +143,8 @@ mod tests { log_offset_ids: SignedRoaringBitmap, ) -> KnnLogInput { let test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; KnnLogInput { - logs: generator.generate_chunk(1..=100), + logs: upsert_generator.generate_chunk(1..=100), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, distance_function: metric, diff --git a/rust/worker/src/execution/operators/knn_projection.rs 
b/rust/worker/src/execution/operators/knn_projection.rs index ee39aaf59ed..0b3cd9dc2fc 100644 --- a/rust/worker/src/execution/operators/knn_projection.rs +++ b/rust/worker/src/execution/operators/knn_projection.rs @@ -145,12 +145,11 @@ mod tests { record_distances: Vec, ) -> KnnProjectionInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; KnnProjectionInput { - logs: generator.generate_chunk(81..=120), + logs: upsert_generator.generate_chunk(81..=120), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, record_distances, diff --git a/rust/worker/src/execution/operators/limit.rs b/rust/worker/src/execution/operators/limit.rs index b5709dcddbd..7fa04855943 100644 --- a/rust/worker/src/execution/operators/limit.rs +++ b/rust/worker/src/execution/operators/limit.rs @@ -304,12 +304,11 @@ mod tests { compact_offset_ids: SignedRoaringBitmap, ) -> LimitInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; LimitInput { - logs: generator.generate_chunk(31..=60), + logs: upsert_generator.generate_chunk(31..=60), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, log_offset_ids, diff --git a/rust/worker/src/execution/operators/projection.rs b/rust/worker/src/execution/operators/projection.rs index 39555af4ead..b555520b4de 100644 --- a/rust/worker/src/execution/operators/projection.rs +++ b/rust/worker/src/execution/operators/projection.rs @@ -184,12 +184,11 @@ mod tests { /// - Compacted: Upsert [1..=100] async fn setup_projection_input(offset_ids: Vec) -> ProjectionInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; ProjectionInput { - logs: generator.generate_chunk(81..=120), + logs: upsert_generator.generate_chunk(81..=120), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, offset_ids, diff --git a/rust/worker/src/execution/operators/spann_centers_search.rs b/rust/worker/src/execution/operators/spann_centers_search.rs index 064e5a9b762..53e8e37aba7 100644 --- a/rust/worker/src/execution/operators/spann_centers_search.rs +++ b/rust/worker/src/execution/operators/spann_centers_search.rs @@ -28,7 +28,7 @@ pub(crate) struct SpannCentersSearchOutput { } #[derive(Error, Debug)] -pub(crate) enum SpannCentersSearchError { +pub enum SpannCentersSearchError { #[error("Error creating spann segment reader")] SpannSegmentReaderCreationError, #[error("Error querying RNG")] diff --git a/rust/worker/src/execution/operators/spann_fetch_pl.rs b/rust/worker/src/execution/operators/spann_fetch_pl.rs index 16eaaee975e..d1a88ef29a7 100644 --- a/rust/worker/src/execution/operators/spann_fetch_pl.rs +++ b/rust/worker/src/execution/operators/spann_fetch_pl.rs @@ -22,7 +22,7 @@ pub(crate) struct SpannFetchPlOutput { } #[derive(Error, Debug)] -pub(crate) enum SpannFetchPlError { +pub enum SpannFetchPlError { #[error("Error creating spann segment reader")] 
SpannSegmentReaderCreationError, #[error("Error querying reader")] diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index 17d5b1249d7..83e9c813d1d 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,10 +1,8 @@ mod assignment; mod compactor; -mod config; mod memberlist; mod server; mod sysdb; -mod system; mod tracing; mod utils; @@ -15,9 +13,11 @@ use tokio::select; use tokio::signal::unix::{signal, SignalKind}; // Required for benchmark +pub mod config; pub mod execution; pub mod log; pub mod segment; +pub mod system; const CONFIG_PATH_ENV_VAR: &str = "CONFIG_PATH"; diff --git a/rust/worker/src/log/log.rs b/rust/worker/src/log/log.rs index 9681a99b7ca..c4e5a9d5532 100644 --- a/rust/worker/src/log/log.rs +++ b/rust/worker/src/log/log.rs @@ -40,7 +40,7 @@ pub(crate) struct CollectionRecord { } #[derive(Clone, Debug)] -pub(crate) enum Log { +pub enum Log { Grpc(GrpcLog), #[allow(dead_code)] InMemory(InMemoryLog), @@ -95,7 +95,7 @@ impl Log { } #[derive(Clone, Debug)] -pub(crate) struct GrpcLog { +pub struct GrpcLog { #[allow(clippy::type_complexity)] client: LogServiceClient< interceptor::InterceptedService< @@ -329,7 +329,7 @@ impl ChromaError for UpdateCollectionLogOffsetError { // This is used for testing only, it represents a log record that is stored in memory // internal to a mock log implementation #[derive(Clone)] -pub(crate) struct InternalLogRecord { +pub struct InternalLogRecord { pub(crate) collection_id: CollectionUuid, pub(crate) log_offset: i64, pub(crate) log_ts: i64, @@ -349,13 +349,12 @@ impl Debug for InternalLogRecord { // This is used for testing only #[derive(Clone, Debug)] -pub(crate) struct InMemoryLog { +pub struct InMemoryLog { collection_to_log: HashMap>, offsets: HashMap, } impl InMemoryLog { - #[cfg(test)] pub fn new() -> InMemoryLog { InMemoryLog { collection_to_log: HashMap::new(), @@ -450,3 +449,9 @@ impl InMemoryLog { Ok(()) } } + +impl Default for InMemoryLog { + fn default() -> Self { + Self::new() + } +} diff --git a/rust/worker/src/log/mod.rs b/rust/worker/src/log/mod.rs index 3af94d9aed6..7a9c1047ce5 100644 --- a/rust/worker/src/log/mod.rs +++ b/rust/worker/src/log/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod config; #[allow(clippy::module_inception)] -pub(crate) mod log; +pub mod log; #[allow(dead_code)] pub mod test; diff --git a/rust/worker/src/log/test.rs b/rust/worker/src/log/test.rs index ff7c12f8c0f..5c098cb27b0 100644 --- a/rust/worker/src/log/test.rs +++ b/rust/worker/src/log/test.rs @@ -8,35 +8,33 @@ use rand::{ pub const TEST_EMBEDDING_DIMENSION: usize = 6; -pub struct LogGenerator -where - G: Fn(usize) -> OperationRecord, -{ - pub generator: G, +pub trait LogGenerator { + fn generate_vec(&self, offsets: O) -> Vec + where + O: Iterator; + fn generate_chunk(&self, offsets: O) -> Chunk + where + O: Iterator, + { + Chunk::new(self.generate_vec(offsets).into()) + } } -impl LogGenerator +impl LogGenerator for G where G: Fn(usize) -> OperationRecord, { - pub fn generate_vec(&self, offsets: O) -> Vec + fn generate_vec(&self, offsets: O) -> Vec where O: Iterator, { offsets .map(|log_offset| LogRecord { log_offset: log_offset as i64, - record: (self.generator)(log_offset), + record: self(log_offset), }) .collect() } - - pub fn generate_chunk(&self, offsets: O) -> Chunk - where - O: Iterator, - { - Chunk::new(self.generate_vec(offsets).into()) - } } pub fn int_as_id(value: usize) -> String { diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 
e22807897f9..53e51e481d5 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -27,7 +27,7 @@ pub struct HnswIndexParamsFromSegment { } #[derive(Clone)] -pub(crate) struct DistributedHNSWSegmentWriter { +pub struct DistributedHNSWSegmentWriter { index: HnswIndexRef, hnsw_index_provider: HnswIndexProvider, pub(crate) id: SegmentUuid, @@ -86,7 +86,7 @@ impl DistributedHNSWSegmentWriter { } } - pub(crate) async fn from_segment( + pub async fn from_segment( segment: &Segment, dimensionality: usize, hnsw_index_provider: HnswIndexProvider, @@ -96,7 +96,6 @@ impl DistributedHNSWSegmentWriter { // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this if !segment.file_path.is_empty() { - println!("Loading HNSW index from files"); // Check if its in the providers cache, if not load the index from the files let index_id = match &segment.file_path.get(HNSW_INDEX) { None => { @@ -272,9 +271,9 @@ impl SegmentFlusher for DistributedHNSWSegmentWriter { } #[derive(Clone)] -pub(crate) struct DistributedHNSWSegmentReader { +pub struct DistributedHNSWSegmentReader { index: HnswIndexRef, - pub(crate) id: SegmentUuid, + pub id: SegmentUuid, } impl Debug for DistributedHNSWSegmentReader { @@ -300,7 +299,6 @@ impl DistributedHNSWSegmentReader { // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this if !segment.file_path.is_empty() { - println!("Loading HNSW index from files"); // Check if its in the providers cache, if not load the index from the files let index_id = match &segment.file_path.get(HNSW_INDEX) { None => { diff --git a/rust/worker/src/segment/mod.rs b/rust/worker/src/segment/mod.rs index 14920256194..a3d8b648e26 100644 --- a/rust/worker/src/segment/mod.rs +++ b/rust/worker/src/segment/mod.rs @@ -1,11 +1,11 @@ pub(crate) mod config; -pub(crate) mod distributed_hnsw_segment; pub mod test; pub(crate) mod utils; pub(crate) use types::*; // Required for benchmark +pub mod distributed_hnsw_segment; pub mod metadata_segment; pub mod record_segment; pub mod spann_segment; diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index f69ad5e7f04..584596d47a4 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -948,17 +948,18 @@ mod tests { // The same record segment writer should be able to run concurrently on different threads without conflict #[test] fn test_max_offset_id_shuttle() { + let test_segment = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Runtime creation should not fail") + .block_on(async { TestSegment::default() }); shuttle::check_random( - || { + move || { let log_partition_size = 100; let stack_size = 1 << 22; let thread_count = 4; - let log_generator = LogGenerator { - generator: upsert_generator, - }; let max_log_offset = thread_count * log_partition_size; - let logs = log_generator.generate_vec(1..=max_log_offset); - let test_segment = TestSegment::default(); + let logs = upsert_generator.generate_vec(1..=max_log_offset); let batches = logs .chunks(log_partition_size) diff --git a/rust/worker/src/segment/test.rs b/rust/worker/src/segment/test.rs index 25b9e8cbb90..6879058ae4e 100644 --- a/rust/worker/src/segment/test.rs +++ b/rust/worker/src/segment/test.rs @@ -1,20 +1,24 @@ use 
std::sync::atomic::AtomicU32; use chroma_blockstore::{provider::BlockfileProvider, test_arrow_blockfile_provider}; +use chroma_index::{hnsw_provider::HnswIndexProvider, test_hnsw_index_provider}; use chroma_types::{ - test_segment, Chunk, Collection, CollectionUuid, LogRecord, OperationRecord, Segment, + test_segment, Chunk, Collection, CollectionAndSegments, CollectionUuid, LogRecord, Segment, SegmentScope, }; use crate::log::test::{LogGenerator, TEST_EMBEDDING_DIMENSION}; use super::{ - materialize_logs, metadata_segment::MetadataSegmentWriter, record_segment::RecordSegmentWriter, - SegmentFlusher, SegmentWriter, + distributed_hnsw_segment::DistributedHNSWSegmentWriter, materialize_logs, + metadata_segment::MetadataSegmentWriter, record_segment::RecordSegmentWriter, SegmentFlusher, + SegmentWriter, }; +#[derive(Clone)] pub struct TestSegment { pub blockfile_provider: BlockfileProvider, + pub hnsw_provider: HnswIndexProvider, pub collection: Collection, pub metadata_segment: Segment, pub record_segment: Segment, @@ -22,8 +26,30 @@ pub struct TestSegment { } impl TestSegment { + pub fn new_with_dimension(dimension: usize) -> Self { + let collection_uuid = CollectionUuid::new(); + let collection = Collection { + collection_id: collection_uuid, + name: "Test Collection".to_string(), + metadata: None, + dimension: Some(dimension as i32), + tenant: "Test Tenant".to_string(), + database: String::new(), + log_position: 0, + version: 0, + }; + Self { + blockfile_provider: test_arrow_blockfile_provider(2 << 22), + hnsw_provider: test_hnsw_index_provider(), + collection, + metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA), + record_segment: test_segment(collection_uuid, SegmentScope::RECORD), + vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR), + } + } + // WARN: The size of the log chunk should not be too large - async fn compact_log(&mut self, logs: Chunk, next_offset: usize) { + pub async fn compact_log(&mut self, logs: Chunk, next_offset: usize) { let materialized_logs = materialize_logs( &None, &logs, @@ -57,22 +83,45 @@ impl TestSegment { .await .expect("Should be able to initiaize record writer."); record_writer - .apply_materialized_log_chunk(materialized_logs) + .apply_materialized_log_chunk(materialized_logs.clone()) .await .expect("Should be able to apply materialized log."); self.record_segment.file_path = record_writer .commit() .await - .expect("Should be able to commit metadata.") + .expect("Should be able to commit record.") .flush() .await - .expect("Should be able to flush metadata."); + .expect("Should be able to flush record."); + + let vector_writer = DistributedHNSWSegmentWriter::from_segment( + &self.vector_segment, + self.collection + .dimension + .expect("Collection dimension should be set") as usize, + self.hnsw_provider.clone(), + ) + .await + .expect("Should be able to initialize vector writer"); + + vector_writer + .apply_materialized_log_chunk(materialized_logs) + .await + .expect("Should be able to apply materialized log."); + + self.vector_segment.file_path = vector_writer + .commit() + .await + .expect("Should be able to commit vector.") + .flush() + .await + .expect("Should be able to flush vector."); } - pub async fn populate_with_generator(&mut self, size: usize, generator: &LogGenerator) + pub async fn populate_with_generator(&mut self, size: usize, generator: G) where - G: Fn(usize) -> OperationRecord, + G: LogGenerator, { let ids: Vec<_> = (1..=size).collect(); for chunk in ids.chunks(100) { @@ -90,23 +139,17 @@ impl 
TestSegment { impl Default for TestSegment { fn default() -> Self { - let collection_uuid = CollectionUuid::new(); - let collection = Collection { - collection_id: collection_uuid, - name: "Test Collection".to_string(), - metadata: None, - dimension: Some(TEST_EMBEDDING_DIMENSION as i32), - tenant: "Test Tenant".to_string(), - database: String::new(), - log_position: 0, - version: 0, - }; + Self::new_with_dimension(TEST_EMBEDDING_DIMENSION) + } +} + +impl From for CollectionAndSegments { + fn from(value: TestSegment) -> Self { Self { - blockfile_provider: test_arrow_blockfile_provider(2 << 22), - collection, - metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA), - record_segment: test_segment(collection_uuid, SegmentScope::RECORD), - vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR), + collection: value.collection, + metadata_segment: value.metadata_segment, + record_segment: value.record_segment, + vector_segment: value.vector_segment, } } } diff --git a/rust/worker/src/system/mod.rs b/rust/worker/src/system/mod.rs index 9c9b7117faa..7c92abb7e26 100644 --- a/rust/worker/src/system/mod.rs +++ b/rust/worker/src/system/mod.rs @@ -8,6 +8,6 @@ mod wrapped_message; // Re-export types pub(crate) use receiver::*; -pub(crate) use system::*; -pub(crate) use types::*; +pub use system::*; +pub use types::*; pub(crate) use wrapped_message::*; diff --git a/rust/worker/src/system/system.rs b/rust/worker/src/system/system.rs index b78b43b526e..bdd3f955b07 100644 --- a/rust/worker/src/system/system.rs +++ b/rust/worker/src/system/system.rs @@ -14,7 +14,7 @@ use tokio::{pin, select}; use tracing::{trace_span, Instrument, Span}; #[derive(Clone, Debug)] -pub(crate) struct System { +pub struct System { inner: Arc, } @@ -32,7 +32,7 @@ impl System { } } - pub(crate) fn start_component(&self, component: C) -> ComponentHandle + pub fn start_component(&self, component: C) -> ComponentHandle where C: Component + Send + 'static, { @@ -96,6 +96,12 @@ impl System { } } +impl Default for System { + fn default() -> Self { + Self::new() + } +} + async fn stream_loop(stream: S, ctx: &ComponentContext) where C: StreamHandler + Handler, diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs index 339515d8a11..b96b05fb98e 100644 --- a/rust/worker/src/system/types.rs +++ b/rust/worker/src/system/types.rs @@ -22,7 +22,7 @@ pub(crate) enum ComponentState { } #[derive(Debug, PartialEq, Clone, Copy)] -pub(crate) enum ComponentRuntime { +pub enum ComponentRuntime { Inherit, Dedicated, } @@ -37,7 +37,7 @@ pub(crate) enum ComponentRuntime { /// - queue_size: The size of the queue to use for the component before it starts dropping messages /// - on_start: Called when the component is started #[async_trait] -pub(crate) trait Component: Send + Sized + Debug + 'static { +pub trait Component: Send + Sized + Debug + 'static { fn get_name() -> &'static str; fn queue_size(&self) -> usize; fn runtime() -> ComponentRuntime { @@ -180,7 +180,7 @@ impl Clone for ComponentSender { /// - join_handle: The join handle for the component, used to join on the component /// - sender: A channel to send messages to the component #[derive(Debug)] -pub(crate) struct ComponentHandle { +pub struct ComponentHandle { cancellation_token: tokio_util::sync::CancellationToken, state: Arc>, join_handle: Option, @@ -271,7 +271,7 @@ impl ComponentHandle { } /// The component context is passed to all Component Handler methods -pub(crate) struct ComponentContext +pub struct ComponentContext where C: Component 
+ 'static, {
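
For reference, a minimal sketch (not part of the patch) of how the new Sift1MData helpers might be driven outside the Criterion benches. It assumes a standalone binary with tokio (rt-multi-thread + macros) and anyhow available; the crate path matches the import used by benches/load.rs, and only methods introduced above are called.

// Sketch: download/decode the SIFT1M dataset via the new helpers.
use chroma_benchmark::datasets::sift::Sift1MData;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // First call downloads and caches base.fvecs, query.fvecs and groundtruth.ivecs.
    let mut sift = Sift1MData::init().await?;

    // Decode the first 10,000 base vectors; each has Sift1MData::dimension() floats.
    let base = sift.data_range(0..10_000).await?;
    assert_eq!(base[0].len(), Sift1MData::dimension());

    // Query vectors paired with their ground-truth nearest-neighbour ids.
    let queries = sift.query().await?;
    println!("{} base vectors, {} queries", base.len(), queries.len());
    Ok(())
}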