From 12e9d6d65428d50802fad587c50e9cd2197d5d92 Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Thu, 19 Dec 2024 17:22:11 -0800 Subject: [PATCH] [CLN] Implement Orchestrator trait and cleanups --- rust/worker/benches/get.rs | 5 +- rust/worker/benches/query.rs | 1 + .../src/compactor/compaction_manager.rs | 8 +- rust/worker/src/execution/dispatcher.rs | 6 +- rust/worker/src/execution/operator.rs | 15 +- .../src/execution/operators/count_records.rs | 2 +- .../src/execution/operators/fetch_log.rs | 2 +- rust/worker/src/execution/operators/filter.rs | 2 +- .../src/execution/operators/knn_hnsw.rs | 2 +- .../worker/src/execution/operators/knn_log.rs | 2 +- .../src/execution/operators/knn_merge.rs | 2 +- .../src/execution/operators/knn_projection.rs | 2 +- rust/worker/src/execution/operators/limit.rs | 2 +- .../execution/operators/prefetch_record.rs | 2 +- .../src/execution/operators/spann_bf_pl.rs | 2 +- .../operators/spann_centers_search.rs | 2 +- .../src/execution/operators/spann_fetch_pl.rs | 2 +- .../execution/operators/spann_knn_merge.rs | 2 +- .../src/execution/orchestration/common.rs | 32 -- .../src/execution/orchestration/compact.rs | 328 +++++-------- .../src/execution/orchestration/count.rs | 448 +++++------------- .../worker/src/execution/orchestration/get.rs | 116 ++--- .../worker/src/execution/orchestration/knn.rs | 142 ++---- .../src/execution/orchestration/knn_filter.rs | 142 +++--- .../worker/src/execution/orchestration/mod.rs | 2 +- .../execution/orchestration/orchestrator.rs | 111 +++++ .../src/execution/orchestration/spann_knn.rs | 166 ++----- rust/worker/src/execution/worker_thread.rs | 2 +- .../src/memberlist/memberlist_provider.rs | 2 +- rust/worker/src/segment/spann_segment.rs | 2 +- rust/worker/src/server.rs | 36 +- rust/worker/src/system/executor.rs | 2 +- rust/worker/src/system/scheduler.rs | 2 +- rust/worker/src/system/types.rs | 4 +- 34 files changed, 608 insertions(+), 990 deletions(-) delete mode 100644 rust/worker/src/execution/orchestration/common.rs create mode 100644 rust/worker/src/execution/orchestration/orchestrator.rs diff --git a/rust/worker/benches/get.rs b/rust/worker/benches/get.rs index 04290c38500..263f4d07f80 100644 --- a/rust/worker/benches/get.rs +++ b/rust/worker/benches/get.rs @@ -11,7 +11,10 @@ use load::{ }; use worker::{ config::RootConfig, - execution::{dispatcher::Dispatcher, orchestration::get::GetOrchestrator}, + execution::{ + dispatcher::Dispatcher, + orchestration::{get::GetOrchestrator, orchestrator::Orchestrator}, + }, segment::test::TestSegment, system::{ComponentHandle, System}, }; diff --git a/rust/worker/benches/query.rs b/rust/worker/benches/query.rs index 589cf27ba34..5fac8d05829 100644 --- a/rust/worker/benches/query.rs +++ b/rust/worker/benches/query.rs @@ -21,6 +21,7 @@ use worker::{ orchestration::{ knn::KnnOrchestrator, knn_filter::{KnnFilterOrchestrator, KnnFilterOutput}, + orchestrator::Orchestrator, }, }, segment::test::TestSegment, diff --git a/rust/worker/src/compactor/compaction_manager.rs b/rust/worker/src/compactor/compaction_manager.rs index 380cb5a978b..d69d29ec68d 100644 --- a/rust/worker/src/compactor/compaction_manager.rs +++ b/rust/worker/src/compactor/compaction_manager.rs @@ -4,6 +4,7 @@ use crate::compactor::types::CompactionJob; use crate::compactor::types::ScheduleMessage; use crate::config::CompactionServiceConfig; use crate::execution::dispatcher::Dispatcher; +use crate::execution::orchestration::orchestrator::Orchestrator; use crate::execution::orchestration::CompactOrchestrator; use crate::execution::orchestration::CompactionResponse; use crate::log::log::Log; @@ -115,7 +116,6 @@ impl CompactionManager { Some(ref system) => { let orchestrator = CompactOrchestrator::new( compaction_job.clone(), - system.clone(), compaction_job.collection_id, self.log.clone(), self.sysdb.clone(), @@ -129,14 +129,14 @@ impl CompactionManager { self.max_partition_size, ); - match orchestrator.run().await { + match orchestrator.run(system.clone()).await { Ok(result) => { tracing::info!("Compaction Job completed: {:?}", result); return Ok(result); } Err(e) => { tracing::error!("Compaction Job failed: {:?}", e); - return Err(e); + return Err(Box::new(e)); } } } @@ -280,7 +280,7 @@ impl Component for CompactionManager { self.compaction_manager_queue_size } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { + async fn start(&mut self, ctx: &crate::system::ComponentContext) -> () { println!("Starting CompactionManager"); ctx.scheduler .schedule(ScheduleMessage {}, self.compaction_interval, ctx, || { diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 5f29cc5761e..a7230f121f6 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -188,7 +188,7 @@ impl Component for Dispatcher { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { self.spawn_workers(&mut ctx.system.clone(), ctx.receiver()); } } @@ -314,7 +314,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { // dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS); ctx.scheduler @@ -377,7 +377,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { // dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS); ctx.scheduler diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index d82aaaec501..9bc0e631492 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -51,15 +51,6 @@ where } } -impl TaskError -where - Err: Debug + ChromaError + 'static, -{ - pub(super) fn boxed(self) -> Box { - Box::new(self) - } -} - /// A task result is a wrapper around the result of a task. /// It contains the task id for tracking purposes. #[derive(Debug)] @@ -94,12 +85,12 @@ where } /// A message type used by the dispatcher to send tasks to worker threads. -pub(crate) type TaskMessage = Box; +pub type TaskMessage = Box; /// A task wrapper is a trait that can be used to run a task. We use it to /// erase the I, O types from the Task struct so that tasks. #[async_trait] -pub(crate) trait TaskWrapper: Send + Debug { +pub trait TaskWrapper: Send + Debug { fn get_name(&self) -> &'static str; async fn run(&self); #[allow(dead_code)] @@ -264,7 +255,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { let task = wrap(Box::new(MockOperator {}), (), ctx.receiver()); self.dispatcher.send(task, None).await.unwrap(); } diff --git a/rust/worker/src/execution/operators/count_records.rs b/rust/worker/src/execution/operators/count_records.rs index cf67764c2dd..ae468218985 100644 --- a/rust/worker/src/execution/operators/count_records.rs +++ b/rust/worker/src/execution/operators/count_records.rs @@ -2,12 +2,12 @@ use crate::{ execution::operator::Operator, segment::record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError}, }; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, LogRecord, Operation, Segment}; use std::collections::HashSet; use thiserror::Error; -use tonic::async_trait; #[derive(Debug)] pub(crate) struct CountRecordsOperator {} diff --git a/rust/worker/src/execution/operators/fetch_log.rs b/rust/worker/src/execution/operators/fetch_log.rs index 4feb421da3d..e49f06ae0e3 100644 --- a/rust/worker/src/execution/operators/fetch_log.rs +++ b/rust/worker/src/execution/operators/fetch_log.rs @@ -1,9 +1,9 @@ use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH}; +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, CollectionUuid, LogRecord}; use thiserror::Error; -use tonic::async_trait; use tracing::trace; use crate::{ diff --git a/rust/worker/src/execution/operators/filter.rs b/rust/worker/src/execution/operators/filter.rs index 8296a82802c..408f7dc19f8 100644 --- a/rust/worker/src/execution/operators/filter.rs +++ b/rust/worker/src/execution/operators/filter.rs @@ -3,6 +3,7 @@ use std::{ ops::{BitAnd, BitOr, Bound}, }; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::metadata::types::MetadataIndexError; @@ -13,7 +14,6 @@ use chroma_types::{ }; use roaring::RoaringBitmap; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ diff --git a/rust/worker/src/execution/operators/knn_hnsw.rs b/rust/worker/src/execution/operators/knn_hnsw.rs index 2618ce7cd4d..4e80971d026 100644 --- a/rust/worker/src/execution/operators/knn_hnsw.rs +++ b/rust/worker/src/execution/operators/knn_hnsw.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_distance::{normalize, DistanceFunction}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::SignedRoaringBitmap; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, segment::distributed_hnsw_segment::DistributedHNSWSegmentReader, diff --git a/rust/worker/src/execution/operators/knn_log.rs b/rust/worker/src/execution/operators/knn_log.rs index 1db07266b2a..ef333fdb44b 100644 --- a/rust/worker/src/execution/operators/knn_log.rs +++ b/rust/worker/src/execution/operators/knn_log.rs @@ -1,11 +1,11 @@ use std::collections::BinaryHeap; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::{normalize, DistanceFunction}; use chroma_error::ChromaError; use chroma_types::{MaterializedLogOperation, Segment, SignedRoaringBitmap}; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, diff --git a/rust/worker/src/execution/operators/knn_merge.rs b/rust/worker/src/execution/operators/knn_merge.rs index 545589a511d..fa3981328bb 100644 --- a/rust/worker/src/execution/operators/knn_merge.rs +++ b/rust/worker/src/execution/operators/knn_merge.rs @@ -1,4 +1,4 @@ -use tonic::async_trait; +use async_trait::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/operators/knn_projection.rs b/rust/worker/src/execution/operators/knn_projection.rs index 0b3cd9dc2fc..7883006320f 100644 --- a/rust/worker/src/execution/operators/knn_projection.rs +++ b/rust/worker/src/execution/operators/knn_projection.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; use chroma_types::Segment; use thiserror::Error; -use tonic::async_trait; use tracing::trace; use crate::execution::{operator::Operator, operators::projection::ProjectionInput}; diff --git a/rust/worker/src/execution/operators/limit.rs b/rust/worker/src/execution/operators/limit.rs index 7fa04855943..7ae2e129ed7 100644 --- a/rust/worker/src/execution/operators/limit.rs +++ b/rust/worker/src/execution/operators/limit.rs @@ -1,11 +1,11 @@ use std::{cmp::Ordering, num::TryFromIntError, sync::atomic}; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, LogRecord, MaterializedLogOperation, Segment, SignedRoaringBitmap}; use roaring::RoaringBitmap; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ diff --git a/rust/worker/src/execution/operators/prefetch_record.rs b/rust/worker/src/execution/operators/prefetch_record.rs index 6b11fabe458..8a5deee479e 100644 --- a/rust/worker/src/execution/operators/prefetch_record.rs +++ b/rust/worker/src/execution/operators/prefetch_record.rs @@ -1,8 +1,8 @@ use std::collections::HashSet; +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ diff --git a/rust/worker/src/execution/operators/spann_bf_pl.rs b/rust/worker/src/execution/operators/spann_bf_pl.rs index fdad8676eec..c274717f664 100644 --- a/rust/worker/src/execution/operators/spann_bf_pl.rs +++ b/rust/worker/src/execution/operators/spann_bf_pl.rs @@ -1,11 +1,11 @@ use std::collections::BinaryHeap; +use async_trait::async_trait; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::types::SpannPosting; use chroma_types::SignedRoaringBitmap; use thiserror::Error; -use tonic::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/operators/spann_centers_search.rs b/rust/worker/src/execution/operators/spann_centers_search.rs index 53e8e37aba7..6dc6d9edf06 100644 --- a/rust/worker/src/execution/operators/spann_centers_search.rs +++ b/rust/worker/src/execution/operators/spann_centers_search.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::utils::rng_query; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, diff --git a/rust/worker/src/execution/operators/spann_fetch_pl.rs b/rust/worker/src/execution/operators/spann_fetch_pl.rs index d1a88ef29a7..0732ef364c4 100644 --- a/rust/worker/src/execution/operators/spann_fetch_pl.rs +++ b/rust/worker/src/execution/operators/spann_fetch_pl.rs @@ -1,7 +1,7 @@ +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::types::SpannPosting; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::{Operator, OperatorType}, diff --git a/rust/worker/src/execution/operators/spann_knn_merge.rs b/rust/worker/src/execution/operators/spann_knn_merge.rs index c1ac147a04e..85b6fd42320 100644 --- a/rust/worker/src/execution/operators/spann_knn_merge.rs +++ b/rust/worker/src/execution/operators/spann_knn_merge.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, collections::BinaryHeap}; -use tonic::async_trait; +use async_trait::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/orchestration/common.rs b/rust/worker/src/execution/orchestration/common.rs deleted file mode 100644 index ff72803f3e8..00000000000 --- a/rust/worker/src/execution/orchestration/common.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::system::{Component, ComponentContext}; -use chroma_error::ChromaError; - -/// Terminate the orchestrator with an error -/// This function sends an error to the result channel and cancels the orchestrator -/// so it stops processing -/// # Arguments -/// * `result_channel` - The result channel to send the error to -/// * `error` - The error to send -/// * `ctx` - The component context -/// # Panics -/// This function panics if the result channel is not set -pub(super) fn terminate_with_error( - mut result_channel: Option>>, - error: E, - ctx: &ComponentContext, -) where - C: Component, - E: ChromaError, -{ - let result_channel = result_channel - .take() - .expect("Invariant violation. Result channel is not set."); - match result_channel.send(Err(error)) { - Ok(_) => (), - Err(_) => { - tracing::error!("Result channel dropped before sending error"); - } - } - // Cancel the orchestrator so it stops processing - ctx.cancellation_token.cancel(); -} diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index fa189c1008b..fb241c5f2fc 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -1,6 +1,9 @@ use super::super::operator::wrap; +use super::orchestrator::Orchestrator; use crate::compactor::CompactionJob; use crate::execution::dispatcher::Dispatcher; +use crate::execution::operator::TaskError; +use crate::execution::operator::TaskMessage; use crate::execution::operator::TaskResult; use crate::execution::operators::fetch_log::FetchLogError; use crate::execution::operators::fetch_log::FetchLogOperator; @@ -20,7 +23,6 @@ use crate::execution::operators::write_segments::WriteSegmentsInput; use crate::execution::operators::write_segments::WriteSegmentsOperator; use crate::execution::operators::write_segments::WriteSegmentsOperatorError; use crate::execution::operators::write_segments::WriteSegmentsOutput; -use crate::execution::orchestration::common::terminate_with_error; use crate::log::log::Log; use crate::segment::distributed_hnsw_segment::DistributedHNSWSegmentWriter; use crate::segment::metadata_segment::MetadataSegmentWriter; @@ -29,11 +31,10 @@ use crate::segment::record_segment::RecordSegmentWriter; use crate::sysdb::sysdb::GetCollectionsError; use crate::sysdb::sysdb::GetSegmentsError; use crate::sysdb::sysdb::SysDb; -use crate::system::Component; +use crate::system::ChannelError; +use crate::system::ComponentContext; use crate::system::ComponentHandle; use crate::system::Handler; -use crate::system::ReceiverForMessage; -use crate::system::System; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; @@ -46,7 +47,8 @@ use std::sync::atomic; use std::sync::atomic::AtomicU32; use std::sync::Arc; use thiserror::Error; -use tracing::Span; +use tokio::sync::oneshot::error::RecvError; +use tokio::sync::oneshot::Sender; use uuid::Uuid; /** The state of the orchestrator. @@ -67,20 +69,25 @@ understand. We can always add more abstraction later if we need it. #[derive(Debug)] enum ExecutionState { Pending, - PullLogs, Partition, Write, Flush, Register, } +#[derive(Clone, Debug)] +struct CompactWriters { + metadata: MetadataSegmentWriter<'static>, + record: RecordSegmentWriter, + vector: Box, +} + #[derive(Debug)] pub struct CompactOrchestrator { id: Uuid, compaction_job: CompactionJob, state: ExecutionState, // Component Execution - system: System, collection_id: CollectionUuid, // Dependencies log: Box, @@ -93,16 +100,11 @@ pub struct CompactOrchestrator { // Dispatcher dispatcher: ComponentHandle, // Shared writers - writers: Option<( - RecordSegmentWriter, - Box, - MetadataSegmentWriter<'static>, - )>, + writers: Option, // number of write segments tasks num_write_tasks: i32, // Result Channel - result_channel: - Option>>>, + result_channel: Option>>, // Next offset id next_offset_id: Arc, max_compaction_size: usize, @@ -140,11 +142,35 @@ impl ChromaError for GetSegmentWritersError { } #[derive(Error, Debug)] -enum CompactionError { - #[error("Task dispatch failed")] - DispatchFailure, - #[error("Result channel dropped")] - ResultChannelDropped, +pub enum CompactionError { + #[error("Panic running task: {0}")] + Panic(String), + #[error("FetchLog error: {0}")] + FetchLog(#[from] FetchLogError), + #[error("Partition error: {0}")] + Partition(#[from] PartitionError), + #[error("WriteSegments error: {0}")] + WriteSegments(#[from] WriteSegmentsOperatorError), + #[error("Regester error: {0}")] + Register(#[from] RegisterError), + #[error("Error sending message through channel: {0}")] + Channel(#[from] ChannelError), + #[error("Error receiving final result: {0}")] + Result(#[from] RecvError), + #[error("{0}")] + Generic(#[from] Box), +} + +impl From> for CompactionError +where + E: Into, +{ + fn from(value: TaskError) -> Self { + match value { + TaskError::Panic(e) => CompactionError::Panic(e.unwrap_or_default()), + TaskError::TaskFailed(e) => e.into(), + } + } } impl ChromaError for CompactionError { @@ -167,16 +193,13 @@ impl CompactOrchestrator { #[allow(clippy::too_many_arguments)] pub fn new( compaction_job: CompactionJob, - system: System, collection_id: CollectionUuid, log: Box, sysdb: Box, blockfile_provider: BlockfileProvider, hnsw_index_provider: HnswIndexProvider, dispatcher: ComponentHandle, - result_channel: Option< - tokio::sync::oneshot::Sender>>, - >, + result_channel: Option>>, record_segment: Option, next_offset_id: Arc, max_compaction_size: usize, @@ -186,7 +209,6 @@ impl CompactOrchestrator { id: Uuid::new_v4(), compaction_job, state: ExecutionState::Pending, - system, collection_id, log, sysdb, @@ -204,76 +226,38 @@ impl CompactOrchestrator { } } - async fn fetch_log( - &mut self, - self_address: Box>>, - ctx: &crate::system::ComponentContext, - ) { - self.state = ExecutionState::PullLogs; - let operator = FetchLogOperator { - log_client: self.log.clone(), - batch_size: 100, - // Here we do not need to be inclusive since the compaction job - // offset is the one after the last compaction offset - start_log_offset_id: self.compaction_job.offset as u32, - maximum_fetch_count: Some(self.max_compaction_size as u32), - collection_uuid: self.collection_id, - }; - let task = wrap(Box::new(operator), (), self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching pull logs for compaction {:?}", e); - terminate_with_error( - self.result_channel.take(), - Box::new(CompactionError::DispatchFailure), - ctx, - ); - } - } - } - async fn partition( &mut self, records: Chunk, - self_address: Box>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Partition; let operator = PartitionOperator::new(); tracing::info!("Sending N Records: {:?}", records.len()); println!("Sending N Records: {:?}", records.len()); let input = PartitionInput::new(records, self.max_partition_size); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching partition for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e - ) - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn write( &mut self, partitions: Vec>, - self_address: Box< - dyn ReceiverForMessage>, - >, ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Write; - if let Err(e) = self.init_segment_writers().await { - tracing::error!("Error creating writers for compaction {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); + let init_res = self.init_segment_writers().await; + if self.ok_or_terminate(init_res, ctx).is_none() { return; } let (record_segment_writer, hnsw_segment_writer, metadata_segment_writer) = match self.writers.clone() { - Some((rec, hnsw, mt)) => (Some(rec), Some(hnsw), Some(mt)), + Some(writers) => ( + Some(writers.record), + Some(writers.vector), + Some(writers.metadata), + ), None => (None, None, None), }; @@ -292,16 +276,8 @@ impl CompactOrchestrator { .clone(), self.next_offset_id.clone(), ); - let task = wrap(operator, input, self_address.clone()); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching writers for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e) - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } } @@ -310,7 +286,7 @@ impl CompactOrchestrator { record_segment_writer: RecordSegmentWriter, hnsw_segment_writer: Box, metadata_segment_writer: MetadataSegmentWriter<'static>, - self_address: Box>>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Flush; @@ -321,24 +297,15 @@ impl CompactOrchestrator { metadata_segment_writer, ); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching flush to S3 for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e - ); - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn register( &mut self, log_position: i64, segment_flush_info: Arc<[SegmentFlushInfo]>, - self_address: Box>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Register; let operator = RegisterOperator::new(); @@ -352,17 +319,8 @@ impl CompactOrchestrator { self.log.clone(), ); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching register for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e - ); - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn init_segment_writers(&mut self) -> Result<(), Box> { @@ -490,42 +448,52 @@ impl CompactOrchestrator { return Err(Box::new(GetSegmentWritersError::HnswSegmentWriterError)); } }; - self.writers = Some(( - record_segment_writer, - hnsw_segment_writer, - mt_segment_writer, - )) + self.writers = Some(CompactWriters { + metadata: mt_segment_writer, + record: record_segment_writer, + vector: hnsw_segment_writer, + }) } Ok(()) } - - pub(crate) async fn run(mut self) -> Result> { - println!("Running compaction job: {:?}", self.compaction_job); - let (tx, rx) = tokio::sync::oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = self.system.clone().start_component(self); - let result = rx.await; - handle.stop(); - result - .map_err(|_| Box::new(CompactionError::ResultChannelDropped) as Box)? - } } // ============== Component Implementation ============== #[async_trait] -impl Component for CompactOrchestrator { - fn get_name() -> &'static str { - "Compaction orchestrator" +impl Orchestrator for CompactOrchestrator { + type Output = CompactionResponse; + type Error = CompactionError; + + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - fn queue_size(&self) -> usize { - 1000 // TODO: make configurable + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap( + Box::new(FetchLogOperator { + log_client: self.log.clone(), + batch_size: 100, + // Here we do not need to be inclusive since the compaction job + // offset is the one after the last compaction offset + start_log_offset_id: self.compaction_job.offset as u32, + maximum_fetch_count: Some(self.max_compaction_size as u32), + collection_uuid: self.collection_id, + }), + (), + ctx.receiver(), + )] } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { - self.fetch_log(ctx.receiver(), ctx).await; + fn set_result_channel(&mut self, sender: Sender>) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender> { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -539,13 +507,9 @@ impl Handler> for CompactOrchestrator message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let records = match message { - Ok(result) => result, - Err(e) => { - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let records = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(recs) => recs, + None => todo!(), }; tracing::info!("Pulled Records: {:?}", records.len()); let final_record_pulled = records.get(records.len() - 1); @@ -553,7 +517,7 @@ impl Handler> for CompactOrchestrator Some(record) => { self.pulled_log_offset = Some(record.log_offset); tracing::info!("Pulled Logs Up To Offset: {:?}", self.pulled_log_offset); - self.partition(records, ctx.receiver()).await; + self.partition(records, ctx).await; } None => { tracing::error!( @@ -574,16 +538,11 @@ impl Handler> for CompactOrchestrato message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let records = match message { - Ok(result) => result.records, - Err(e) => { - tracing::error!("Error partitioning records: {:?}", e); - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let records = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(recs) => recs.records, + None => todo!(), }; - self.write(records, ctx.receiver(), ctx).await; + self.write(records, ctx).await; } } @@ -596,33 +555,22 @@ impl Handler> for Co message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let output = match message { - Ok(output) => { - self.num_write_tasks -= 1; - output - } - Err(e) => { - tracing::error!("Error writing segments: {:?}", e); - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; + self.num_write_tasks -= 1; if self.num_write_tasks == 0 { if let (Some(rec), Some(hnsw), Some(mt)) = ( output.record_segment_writer, output.hnsw_segment_writer, output.metadata_segment_writer, ) { - self.flush_s3(rec, hnsw, mt, ctx.receiver()).await; + self.flush_s3(rec, hnsw, mt, ctx).await; } else { // There is nothing to flush, proceed to register - self.register( - self.pulled_log_offset.unwrap(), - Arc::new([]), - ctx.receiver(), - ) - .await; + self.register(self.pulled_log_offset.unwrap(), Arc::new([]), ctx) + .await; } } } @@ -637,22 +585,16 @@ impl Handler>> for CompactOrchest message: TaskResult>, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - match message { - Ok(msg) => { - // Unwrap should be safe here as we are guaranteed to have a value by construction - self.register( - self.pulled_log_offset.unwrap(), - msg.segment_flush_info, - ctx.receiver(), - ) - .await; - } - Err(e) => { - tracing::error!("Error flushing to S3: {:?}", e); - terminate_with_error(self.result_channel.take(), e.boxed(), ctx); - } - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, + }; + self.register( + self.pulled_log_offset.unwrap(), + output.segment_flush_info, + ctx, + ) + .await; } } @@ -665,26 +607,16 @@ impl Handler> for CompactOrchestrator message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - // Return execution state to the compaction manager - let result_channel = self - .result_channel - .take() - .expect("Invariant violation. Result channel is not set."); - - match message { - Ok(_) => { - let response = CompactionResponse { + self.terminate_with_result( + message + .into_inner() + .map_err(|e| e.into()) + .map(|_| CompactionResponse { id: self.id, compaction_job: self.compaction_job.clone(), message: "Compaction Complete".to_string(), - }; - let _ = result_channel.send(Ok(response)); - } - Err(e) => { - tracing::error!("Error registering compaction: {:?}", e); - terminate_with_error(Some(result_channel), Box::new(e), ctx); - } - } + }), + ctx, + ); } } diff --git a/rust/worker/src/execution/orchestration/count.rs b/rust/worker/src/execution/orchestration/count.rs index 54cbde8690b..8d91359b83c 100644 --- a/rust/worker/src/execution/orchestration/count.rs +++ b/rust/worker/src/execution/orchestration/count.rs @@ -1,332 +1,133 @@ -use crate::execution::dispatcher::Dispatcher; -use crate::execution::operator::{wrap, TaskResult}; -use crate::execution::operators::count_records::{ - CountRecordsError, CountRecordsInput, CountRecordsOperator, CountRecordsOutput, -}; -use crate::execution::operators::fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}; -use crate::execution::orchestration::common::terminate_with_error; -use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError}; -use crate::system::{Component, ComponentContext, ComponentHandle, Handler}; -use crate::{log::log::Log, sysdb::sysdb::SysDb, system::System}; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Collection, CollectionUuid, Segment, SegmentType, SegmentUuid}; +use chroma_types::CollectionAndSegments; use thiserror::Error; -use tracing::Span; -use uuid::Uuid; +use tokio::sync::oneshot::{error::RecvError, Sender}; + +use crate::{ + execution::{ + dispatcher::Dispatcher, + operator::{wrap, TaskError, TaskMessage, TaskResult}, + operators::{ + count_records::{ + CountRecordsError, CountRecordsInput, CountRecordsOperator, CountRecordsOutput, + }, + fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, + }, + }, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, +}; -#[derive(Debug)] -pub(crate) struct CountQueryOrchestrator { - // Component Execution - system: System, - // Query state - metadata_segment_id: Uuid, - collection_id: CollectionUuid, - // State fetched or created for query execution - record_segment: Option, - collection: Option, - // Services - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - // Result channel - result_channel: Option>>>, - // Request version context - collection_version: u32, - log_position: u64, -} +use super::orchestrator::Orchestrator; #[derive(Error, Debug)] -enum CountQueryOrchestratorError { - #[error("Blockfile metadata segment with id: {0} not found")] - BlockfileMetadataSegmentNotFound(Uuid), - #[error("Get segments error: {0}")] - GetSegmentsError(#[from] GetSegmentsError), - #[error("Record segment not found for collection: {0}")] - RecordSegmentNotFound(CollectionUuid), - #[error("System Time Error")] - SystemTimeError(#[from] std::time::SystemTimeError), - #[error("Collection not found for id: {0}")] - CollectionNotFound(CollectionUuid), - #[error("Get collection error: {0}")] - GetCollectionError(#[from] GetCollectionsError), - #[error("Collection version mismatch")] - CollectionVersionMismatch, - #[error("Task dispatch failed")] - DispatchFailure, +pub enum CountError { + #[error("Error sending message through channel: {0}")] + Channel(#[from] ChannelError), + #[error("Error running Fetch Log Operator: {0}")] + FetchLog(#[from] FetchLogError), + #[error("Error running Count Record Operator: {0}")] + CountRecord(#[from] CountRecordsError), + #[error("Panic running task: {0}")] + Panic(String), + #[error("Error receiving final result: {0}")] + Result(#[from] RecvError), } -impl ChromaError for CountQueryOrchestratorError { +impl ChromaError for CountError { fn code(&self) -> ErrorCodes { match self { - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound(_) => { - ErrorCodes::NotFound - } - CountQueryOrchestratorError::GetSegmentsError(e) => e.code(), - CountQueryOrchestratorError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, - CountQueryOrchestratorError::SystemTimeError(_) => ErrorCodes::Internal, - CountQueryOrchestratorError::CollectionNotFound(_) => ErrorCodes::NotFound, - CountQueryOrchestratorError::GetCollectionError(e) => e.code(), - CountQueryOrchestratorError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, - CountQueryOrchestratorError::DispatchFailure => ErrorCodes::Internal, + CountError::Channel(e) => e.code(), + CountError::FetchLog(e) => e.code(), + CountError::CountRecord(e) => e.code(), + CountError::Panic(_) => ErrorCodes::Aborted, + CountError::Result(_) => ErrorCodes::Internal, } } } -impl CountQueryOrchestrator { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - system: System, - metadata_segment_id: &Uuid, - collection_id: &CollectionUuid, - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - collection_version: u32, - log_position: u64, - ) -> Self { - Self { - system, - metadata_segment_id: *metadata_segment_id, - collection_id: *collection_id, - record_segment: None, - collection: None, - log, - sysdb, - dispatcher, - blockfile_provider, - result_channel: None, - collection_version, - log_position, - } - } - - async fn start(&mut self, ctx: &ComponentContext) { - println!("Starting Count Query Orchestrator"); - // Populate the orchestrator with the initial state - The Record Segment and the Collection - let metdata_segment = self - .get_metadata_segment_from_id( - self.sysdb.clone(), - &self.metadata_segment_id, - &self.collection_id, - ) - .await; - - let metadata_segment = match metdata_segment { - Ok(segment) => segment, - Err(e) => { - tracing::error!("Error getting metadata segment: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - let collection_id = metadata_segment.collection; - - let record_segment = self - .get_record_segment_from_collection_id(self.sysdb.clone(), &collection_id) - .await; - - let record_segment = match record_segment { - Ok(segment) => segment, - Err(e) => { - tracing::error!("Error getting record segment: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - let collection = match self - .get_collection_from_id(self.sysdb.clone(), &collection_id, ctx) - .await - { - Ok(collection) => collection, - Err(e) => { - tracing::error!("Error getting collection: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - // If the collection version does not match the request version then we terminate with an error - if collection.version as u32 != self.collection_version { - terminate_with_error( - self.result_channel.take(), - Box::new(CountQueryOrchestratorError::CollectionVersionMismatch), - ctx, - ); - return; +impl From> for CountError +where + E: Into, +{ + fn from(value: TaskError) -> Self { + match value { + TaskError::Panic(e) => CountError::Panic(e.unwrap_or_default()), + TaskError::TaskFailed(e) => e.into(), } - - self.record_segment = Some(record_segment); - self.collection = Some(collection); - self.fetch_log(ctx).await; } +} - // shared - async fn fetch_log(&mut self, ctx: &ComponentContext) { - println!("Count query orchestrator pulling logs"); - - let collection = self - .collection - .as_ref() - .expect("Invariant violation. Collection is not set before pull logs state."); - - let operator = FetchLogOperator { - log_client: self.log.clone(), - batch_size: 100, - // The collection log position is inclusive, and we want to start from the next log. - // Note that we query using the incoming log position this is critical for correctness. - start_log_offset_id: self.log_position as u32 + 1, - maximum_fetch_count: None, - collection_uuid: collection.collection_id, - }; +type CountOutput = usize; +type CountResult = Result; - let task = wrap(Box::new(operator), (), ctx.receiver()); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - // Log an error - this implies the dispatcher was dropped somehow - // and is likely fatal - println!("Error sending Count Query task: {:?}", e); - terminate_with_error( - self.result_channel.take(), - Box::new(CountQueryOrchestratorError::DispatchFailure), - ctx, - ); - } - } - } +#[derive(Debug)] +pub struct CountOrchestrator { + // Orchestrator parameters + blockfile_provider: BlockfileProvider, + dispatcher: ComponentHandle, + queue: usize, - // shared - async fn get_metadata_segment_from_id( - &self, - mut sysdb: Box, - metadata_segment_id: &Uuid, - collection_id: &CollectionUuid, - ) -> Result> { - let segments = sysdb - .get_segments( - Some(SegmentUuid(*metadata_segment_id)), - None, - None, - *collection_id, - ) - .await; - let segment = match segments { - Ok(segments) => { - if segments.is_empty() { - return Err(Box::new( - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound( - *metadata_segment_id, - ), - )); - } - segments[0].clone() - } - Err(e) => { - return Err(Box::new(CountQueryOrchestratorError::GetSegmentsError(e))); - } - }; + // Collection and segments + collection_and_segments: CollectionAndSegments, - if segment.r#type != SegmentType::BlockfileMetadata { - return Err(Box::new( - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound(*metadata_segment_id), - )); - } - Ok(segment) - } + // Fetch logs + fetch_log: FetchLogOperator, - // shared - async fn get_record_segment_from_collection_id( - &self, - mut sysdb: Box, - collection_id: &CollectionUuid, - ) -> Result> { - let segments = sysdb - .get_segments( - None, - Some(SegmentType::BlockfileRecord.into()), - None, - *collection_id, - ) - .await; + // Result channel + result_channel: Option>>, +} - match segments { - Ok(segments) => { - if segments.is_empty() { - return Err(Box::new( - CountQueryOrchestratorError::RecordSegmentNotFound(*collection_id), - )); - } - // Unwrap is safe as we know at least one segment exists from - // the check above - Ok(segments.into_iter().next().unwrap()) - } - Err(e) => Err(Box::new(CountQueryOrchestratorError::GetSegmentsError(e))), +impl CountOrchestrator { + pub(crate) fn new( + blockfile_provider: BlockfileProvider, + dispatcher: ComponentHandle, + queue: usize, + collection_and_segments: CollectionAndSegments, + fetch_log: FetchLogOperator, + ) -> Self { + Self { + blockfile_provider, + dispatcher, + collection_and_segments, + queue, + fetch_log, + result_channel: None, } } +} - // shared - async fn get_collection_from_id( - &self, - mut sysdb: Box, - collection_id: &CollectionUuid, - _ctx: &ComponentContext, - ) -> Result> { - let collections = sysdb - .get_collections(Some(*collection_id), None, None, None) - .await; +#[async_trait] +impl Orchestrator for CountOrchestrator { + type Output = CountOutput; + type Error = CountError; - match collections { - Ok(collections) => { - if collections.is_empty() { - return Err(Box::new(CountQueryOrchestratorError::CollectionNotFound( - *collection_id, - ))); - } - // Unwrap is safe as we know at least one collection exists from - // the check above - Ok(collections.into_iter().next().unwrap()) - } - Err(e) => Err(Box::new(CountQueryOrchestratorError::GetCollectionError(e))), - } + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - /// Run the orchestrator and return the result. - /// # Note - /// Use this over spawning the component directly. This method will start the component and - /// wait for it to finish before returning the result. - pub(crate) async fn run(mut self) -> Result> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = self.system.clone().start_component(self); - let result = rx.await; - handle.stop(); - result.unwrap() + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } -} -#[async_trait] -impl Component for CountQueryOrchestrator { - fn get_name() -> &'static str { - "Count Query Orchestrator" + fn queue_size(&self) -> usize { + self.queue } - fn queue_size(&self) -> usize { - 1000 // TODO: make this configurable + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { - self.start(ctx).await; + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } #[async_trait] -impl Handler> for CountQueryOrchestrator { +impl Handler> for CountOrchestrator { type Result = (); async fn handle( @@ -334,37 +135,25 @@ impl Handler> for CountQueryOrchestrat message: TaskResult, ctx: &ComponentContext, ) { - let message = message.into_inner(); - match message { - Ok(logs) => { - let operator = CountRecordsOperator::new(); - let input = CountRecordsInput::new( - self.record_segment - .as_ref() - .expect("Expect segment") - .clone(), - self.blockfile_provider.clone(), - logs, - ); - let msg = wrap(operator, input, ctx.receiver()); - match self.dispatcher.send(msg, None).await { - Ok(_) => (), - Err(e) => { - // Log an error - this implies the dispatcher was dropped somehow - // and is likely fatal - println!("Error sending Count Query task: {:?}", e); - } - } - } - Err(e) => { - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - } - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, + }; + let task = wrap( + CountRecordsOperator::new(), + CountRecordsInput::new( + self.collection_and_segments.record_segment.clone(), + self.blockfile_provider.clone(), + output, + ), + ctx.receiver(), + ); + self.send(task, ctx).await; } } #[async_trait] -impl Handler> for CountQueryOrchestrator { +impl Handler> for CountOrchestrator { type Result = (); async fn handle( @@ -372,23 +161,12 @@ impl Handler> for CountQueryOr message: TaskResult, ctx: &ComponentContext, ) { - let message = message.into_inner(); - let msg = match message { - Ok(m) => m, - Err(e) => { - return terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - } - }; - let channel = self - .result_channel - .take() - .expect("Expect channel to be present"); - match channel.send(Ok(msg.count)) { - Ok(_) => (), - Err(_) => { - // Log an error - this implied the listener was dropped - println!("[CountQueryOrchestrator] Result channel dropped before sending result"); - } - } + self.terminate_with_result( + message + .into_inner() + .map_err(|e| e.into()) + .map(|output| output.count), + ctx, + ); } } diff --git a/rust/worker/src/execution/orchestration/get.rs b/rust/worker/src/execution/orchestration/get.rs index 9eaf2995e8a..2b955e3fc78 100644 --- a/rust/worker/src/execution/orchestration/get.rs +++ b/rust/worker/src/execution/orchestration/get.rs @@ -1,15 +1,14 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::CollectionAndSegments; use thiserror::Error; -use tokio::sync::oneshot::{self, error::RecvError, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::{error::RecvError, Sender}; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskError, TaskResult}, + operator::{wrap, TaskError, TaskMessage, TaskResult}, operators::{ fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, filter::{FilterError, FilterInput, FilterOperator, FilterOutput}, @@ -17,11 +16,12 @@ use crate::{ prefetch_record::{PrefetchRecordError, PrefetchRecordOperator, PrefetchRecordOutput}, projection::{ProjectionError, ProjectionInput, ProjectionOperator, ProjectionOutput}, }, - orchestration::common::terminate_with_error, }, - system::{ChannelError, Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, }; +use super::orchestrator::Orchestrator; + #[derive(Error, Debug)] pub enum GetError { #[error("Error sending message through channel: {0}")] @@ -166,42 +166,33 @@ impl GetOrchestrator { result_channel: None, } } +} - pub async fn run(mut self, system: System) -> GetResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } +#[async_trait] +impl Orchestrator for GetOrchestrator { + type Output = GetOutput; + type Error = GetError; - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let get_err = err.into(); - tracing::error!("Error running orchestrator: {}", &get_err); - terminate_with_error(self.result_channel.take(), get_err, ctx); + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } -} -#[async_trait] -impl Component for GetOrchestrator { - fn get_name() -> &'static str { - "Get Orchestrator" + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } fn queue_size(&self) -> usize { self.queue } - async fn on_start(&mut self, ctx: &ComponentContext) { - let task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - return; - } + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -214,12 +205,9 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.fetched_logs = Some(output.clone()); @@ -234,9 +222,7 @@ impl Handler> for GetOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } @@ -249,12 +235,9 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; let task = wrap( Box::new(self.limit.clone()), @@ -271,9 +254,7 @@ impl Handler> for GetOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } @@ -286,12 +267,9 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; let input = ProjectionInput { @@ -311,18 +289,13 @@ impl Handler> for GetOrchestrator { input.clone(), ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); + + if !self.send(prefetch_task, ctx).await { + return; } let task = wrap(Box::new(self.projection.clone()), input, ctx.receiver()); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } @@ -348,17 +321,6 @@ impl Handler> for GetOrchestrator message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx); } } diff --git a/rust/worker/src/execution/orchestration/knn.rs b/rust/worker/src/execution/orchestration/knn.rs index c57f6d9605e..2ffcd26ee1a 100644 --- a/rust/worker/src/execution/orchestration/knn.rs +++ b/rust/worker/src/execution/orchestration/knn.rs @@ -1,12 +1,11 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; -use tokio::sync::oneshot::{self, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::Sender; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskResult}, + operator::{wrap, TaskMessage, TaskResult}, operators::{ knn::{KnnOperator, RecordDistance}, knn_hnsw::{KnnHnswError, KnnHnswInput, KnnHnswOutput}, @@ -20,12 +19,14 @@ use crate::{ PrefetchRecordOutput, }, }, - orchestration::common::terminate_with_error, }, - system::{Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ComponentContext, ComponentHandle, Handler}, }; -use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; +use super::{ + knn_filter::{KnnError, KnnFilterOutput, KnnOutput, KnnResult}, + orchestrator::Orchestrator, +}; /// The `KnnOrchestrator` finds the nearest neighbor of a target embedding given the search domain. /// When used together with `KnnFilterOrchestrator`, they evaluate a `.query(...)` query @@ -146,6 +147,11 @@ impl KnnOrchestrator { knn_projection: KnnProjectionOperator, ) -> Self { let fetch = knn.fetch; + let knn_segment_distances = if knn_filter_output.hnsw_reader.is_none() { + Some(Vec::new()) + } else { + None + }; Self { blockfile_provider, dispatcher, @@ -153,31 +159,13 @@ impl KnnOrchestrator { knn_filter_output, knn, knn_log_distances: None, - knn_segment_distances: None, + knn_segment_distances, merge: KnnMergeOperator { fetch }, knn_projection, result_channel: None, } } - pub async fn run(mut self, system: System) -> KnnResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); - } - async fn try_start_knn_merge_operator(&mut self, ctx: &ComponentContext) { if let (Some(log_distances), Some(segment_distances)) = ( self.knn_log_distances.as_ref(), @@ -191,24 +179,23 @@ impl KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } } #[async_trait] -impl Component for KnnOrchestrator { - fn get_name() -> &'static str { - "Knn Orchestrator" - } +impl Orchestrator for KnnOrchestrator { + type Output = KnnOutput; + type Error = KnnError; - fn queue_size(&self) -> usize { - self.queue + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - async fn on_start(&mut self, ctx: &ComponentContext) { + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + let mut tasks = Vec::new(); + let knn_log_task = wrap( Box::new(self.knn.clone()), KnnLogInput { @@ -220,14 +207,7 @@ impl Component for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(knn_log_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - return; - } + tasks.push(knn_log_task); if let Some(hnsw_reader) = self.knn_filter_output.hnsw_reader.as_ref().cloned() { let knn_segment_task = wrap( @@ -243,17 +223,24 @@ impl Component for KnnOrchestrator { }, ctx.receiver(), ); - - if let Err(err) = self - .dispatcher - .send(knn_segment_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } - } else { - self.knn_segment_distances = Some(Vec::new()) + tasks.push(knn_segment_task); } + + tasks + } + + fn queue_size(&self) -> usize { + self.queue + } + + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -266,12 +253,9 @@ impl Handler> for KnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.knn_log_distances = Some(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -287,12 +271,9 @@ impl Handler> for KnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.knn_segment_distances = Some(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -327,13 +308,7 @@ impl Handler> for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(prefetch_task, ctx).await; let projection_task = wrap( Box::new(self.knn_projection.clone()), @@ -345,13 +320,7 @@ impl Handler> for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(projection_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(projection_task, ctx).await; } } @@ -377,17 +346,6 @@ impl Handler> for KnnOrchest message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx); } } diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs index 4ccc549d789..54632c3755b 100644 --- a/rust/worker/src/execution/orchestration/knn_filter.rs +++ b/rust/worker/src/execution/orchestration/knn_filter.rs @@ -1,17 +1,16 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::HnswIndexProvider; use chroma_types::{CollectionAndSegments, Segment}; use thiserror::Error; -use tokio::sync::oneshot::{self, error::RecvError, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::{error::RecvError, Sender}; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskError, TaskResult}, + operator::{wrap, TaskError, TaskMessage, TaskResult}, operators::{ fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, filter::{FilterError, FilterInput, FilterOperator, FilterOutput}, @@ -22,7 +21,6 @@ use crate::{ spann_centers_search::SpannCentersSearchError, spann_fetch_pl::SpannFetchPlError, }, - orchestration::common::terminate_with_error, }, segment::{ distributed_hnsw_segment::{ @@ -30,9 +28,11 @@ use crate::{ }, utils::distance_function_from_segment, }, - system::{ChannelError, Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, }; +use super::orchestrator::Orchestrator; + #[derive(Error, Debug)] pub enum KnnError { #[error("Error sending message through channel: {0}")] @@ -189,42 +189,33 @@ impl KnnFilterOrchestrator { result_channel: None, } } +} - pub async fn run(mut self, system: System) -> KnnFilterResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } +#[async_trait] +impl Orchestrator for KnnFilterOrchestrator { + type Output = KnnFilterOutput; + type Error = KnnError; - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } -} -#[async_trait] -impl Component for KnnFilterOrchestrator { - fn get_name() -> &'static str { - "Knn Filter Orchestrator" + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } fn queue_size(&self) -> usize { self.queue } - async fn on_start(&mut self, ctx: &ComponentContext) { - let task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - return; - } + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -237,12 +228,9 @@ impl Handler> for KnnFilterOrchestrato message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.fetched_logs = Some(output.clone()); @@ -257,9 +245,7 @@ impl Handler> for KnnFilterOrchestrato }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } @@ -272,28 +258,28 @@ impl Handler> for KnnFilterOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; - let collection_dimension = match self.collection_and_segments.collection.dimension { - Some(dimension) => dimension as u32, - None => { - self.terminate_with_error(ctx, KnnError::NoCollectionDimension); - return; - } + let collection_dimension = match self.ok_or_terminate( + self.collection_and_segments + .collection + .dimension + .ok_or(KnnError::NoCollectionDimension), + ctx, + ) { + Some(dim) => dim as u32, + None => return, + }; + let distance_function = match self.ok_or_terminate( + distance_function_from_segment(&self.collection_and_segments.vector_segment) + .map_err(|_| KnnError::InvalidDistanceFunction), + ctx, + ) { + Some(distance_function) => distance_function, + None => return, }; - let distance_function = - match distance_function_from_segment(&self.collection_and_segments.vector_segment) { - Ok(distance_function) => distance_function, - Err(_) => { - self.terminate_with_error(ctx, KnnError::InvalidDistanceFunction); - return; - } - }; let hnsw_reader = match DistributedHNSWSegmentReader::from_segment( &self.collection_and_segments.vector_segment, collection_dimension as usize, @@ -307,29 +293,23 @@ impl Handler> for KnnFilterOrchestrator { } Err(err) => { - self.terminate_with_error(ctx, *err); + self.terminate_with_result(Err((*err).into()), ctx); return; } }; - if let Some(chan) = self.result_channel.take() { - if chan - .send(Ok(KnnFilterOutput { - logs: self - .fetched_logs - .take() - .expect("FetchLogOperator should have finished already"), - distance_function, - filter_output: output, - hnsw_reader, - record_segment: self.collection_and_segments.record_segment.clone(), - vector_segment: self.collection_and_segments.vector_segment.clone(), - dimension: collection_dimension as usize, - })) - .is_err() - { - tracing::error!("Error sending final result"); - }; - } + let output = KnnFilterOutput { + logs: self + .fetched_logs + .take() + .expect("FetchLogOperator should have finished already"), + distance_function, + filter_output: output, + hnsw_reader, + record_segment: self.collection_and_segments.record_segment.clone(), + vector_segment: self.collection_and_segments.vector_segment.clone(), + dimension: collection_dimension as usize, + }; + self.terminate_with_result(Ok(output), ctx); } } diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index d9b83d6e48a..58a9c0eb942 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1,4 +1,3 @@ -mod common; mod compact; mod count; mod spann_knn; @@ -8,3 +7,4 @@ pub(crate) use count::*; pub mod get; pub mod knn; pub mod knn_filter; +pub mod orchestrator; diff --git a/rust/worker/src/execution/orchestration/orchestrator.rs b/rust/worker/src/execution/orchestration/orchestrator.rs new file mode 100644 index 00000000000..ddc96f93c26 --- /dev/null +++ b/rust/worker/src/execution/orchestration/orchestrator.rs @@ -0,0 +1,111 @@ +use core::fmt::Debug; +use std::any::type_name; + +use async_trait::async_trait; +use chroma_error::ChromaError; +use tokio::sync::oneshot::{self, error::RecvError, Sender}; +use tracing::Span; + +use crate::{ + execution::{dispatcher::Dispatcher, operator::TaskMessage}, + system::{ChannelError, Component, ComponentContext, ComponentHandle, System}, +}; + +#[async_trait] +pub trait Orchestrator: Debug + Send + Sized + 'static { + type Output: Send; + type Error: ChromaError + From + From; + + /// Returns the handle of the dispatcher + fn dispatcher(&self) -> ComponentHandle; + + /// Returns a vector of starting tasks that should be run in sequence + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec; + + fn name() -> &'static str { + type_name::() + } + + fn queue_size(&self) -> usize { + 1000 + } + + /// Runs the orchestrator in a system and returns the result + async fn run(mut self, system: System) -> Result { + let (tx, rx) = oneshot::channel(); + self.set_result_channel(tx); + let mut handle = system.start_component(self); + let res = rx.await; + handle.stop(); + res? + } + + /// Sends a task to the dispatcher and return whether the task is successfully sent + async fn send(&mut self, task: TaskMessage, ctx: &ComponentContext) -> bool { + let res = self.dispatcher().send(task, Some(Span::current())).await; + self.ok_or_terminate(res, ctx).is_some() + } + + /// Sets the result channel of the orchestrator + fn set_result_channel(&mut self, sender: Sender>); + + /// Takes the result channel of the orchestrator. The channel should have been set when this is invoked + fn take_result_channel(&mut self) -> Sender>; + + /// Terminate the orchestrator with a result + fn terminate_with_result( + &mut self, + res: Result, + ctx: &ComponentContext, + ) { + let cancel = if let Err(err) = &res { + tracing::error!("Error running {}: {}", Self::name(), err); + true + } else { + false + }; + + let channel = self.take_result_channel(); + if channel.send(res).is_err() { + tracing::error!("Error sending result for {}", Self::name()); + }; + + if cancel { + ctx.cancellation_token.cancel(); + } + } + + /// Terminate the orchestrator if the result is an error. Returns the output if any. + fn ok_or_terminate>( + &mut self, + res: Result, + ctx: &ComponentContext, + ) -> Option { + match res { + Ok(output) => Some(output), + Err(error) => { + self.terminate_with_result(Err(error.into()), ctx); + None + } + } + } +} + +#[async_trait] +impl Component for O { + fn get_name() -> &'static str { + Self::name() + } + + fn queue_size(&self) -> usize { + self.queue_size() + } + + async fn start(&mut self, ctx: &ComponentContext) { + for task in self.initial_tasks(ctx) { + if !self.send(task, ctx).await { + break; + } + } + } +} diff --git a/rust/worker/src/execution/orchestration/spann_knn.rs b/rust/worker/src/execution/orchestration/spann_knn.rs index da4422be14b..a190055b800 100644 --- a/rust/worker/src/execution/orchestration/spann_knn.rs +++ b/rust/worker/src/execution/orchestration/spann_knn.rs @@ -1,14 +1,13 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::{normalize, DistanceFunction}; use chroma_index::hnsw_provider::HnswIndexProvider; -use tokio::sync::oneshot::{self, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::Sender; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskResult}, + operator::{wrap, TaskMessage, TaskResult}, operators::{ knn::{KnnOperator, RecordDistance}, knn_log::{KnnLogError, KnnLogInput, KnnLogOutput}, @@ -31,13 +30,15 @@ use crate::{ SpannKnnMergeError, SpannKnnMergeInput, SpannKnnMergeOperator, SpannKnnMergeOutput, }, }, - orchestration::common::terminate_with_error, }, segment::spann_segment::SpannSegmentReaderContext, - system::{Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ComponentContext, ComponentHandle, Handler}, }; -use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; +use super::{ + knn_filter::{KnnError, KnnFilterOutput, KnnOutput, KnnResult}, + orchestrator::Orchestrator, +}; // TODO(Sanket): Make these configurable. const RNG_FACTOR: f32 = 1.0; @@ -127,24 +128,6 @@ impl SpannKnnOrchestrator { } } - pub async fn run(mut self, system: System) -> KnnResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); - } - async fn try_start_knn_merge_operator(&mut self, ctx: &ComponentContext) { if self.heads_searched && self.num_outstanding_bf_pl == 0 { // This is safe because self.records is only used once and that is during merge. @@ -155,24 +138,23 @@ impl SpannKnnOrchestrator { SpannKnnMergeInput { records }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } } #[async_trait] -impl Component for SpannKnnOrchestrator { - fn get_name() -> &'static str { - "Spann Knn Orchestrator" - } +impl Orchestrator for SpannKnnOrchestrator { + type Output = KnnOutput; + type Error = KnnError; - fn queue_size(&self) -> usize { - self.queue + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - async fn on_start(&mut self, ctx: &ComponentContext) { + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + let mut tasks = Vec::new(); + let knn_log_task = wrap( Box::new(self.log_knn.clone()), KnnLogInput { @@ -184,16 +166,8 @@ impl Component for SpannKnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(knn_log_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - return; - } + tasks.push(knn_log_task); - // Invoke Head search operator. let reader_context = SpannSegmentReaderContext { segment: self.knn_filter_output.vector_segment.clone(), blockfile_provider: self.blockfile_provider.clone(), @@ -212,14 +186,23 @@ impl Component for SpannKnnOrchestrator { }, ctx.receiver(), ); + tasks.push(head_search_task); - if let Err(err) = self - .dispatcher - .send(head_search_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + tasks + } + + fn queue_size(&self) -> usize { + self.queue + } + + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -232,12 +215,9 @@ impl Handler> for SpannKnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.records.push(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -255,12 +235,9 @@ impl Handler> message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Set state that is used for tracking when we are ready for merging. self.heads_searched = true; @@ -283,13 +260,7 @@ impl Handler> ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(fetch_pl_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(fetch_pl_task, ctx).await; } } } @@ -303,12 +274,9 @@ impl Handler> for SpannKnnOrch message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Spawn brute force posting list task. let bf_pl_task = wrap( @@ -327,13 +295,7 @@ impl Handler> for SpannKnnOrch ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(bf_pl_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(bf_pl_task, ctx).await; } } @@ -346,12 +308,9 @@ impl Handler> for SpannKnnOrchestrat message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Update state tracking for merging. self.num_outstanding_bf_pl -= 1; @@ -389,13 +348,7 @@ impl Handler> for SpannKnnOr }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(prefetch_task, ctx).await; let projection_task = wrap( Box::new(self.knn_projection.clone()), @@ -407,13 +360,7 @@ impl Handler> for SpannKnnOr }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(projection_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(projection_task, ctx).await; } } @@ -439,17 +386,6 @@ impl Handler> for SpannKnnOr message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx); } } diff --git a/rust/worker/src/execution/worker_thread.rs b/rust/worker/src/execution/worker_thread.rs index 9a968980247..da5c54d59e8 100644 --- a/rust/worker/src/execution/worker_thread.rs +++ b/rust/worker/src/execution/worker_thread.rs @@ -45,7 +45,7 @@ impl Component for WorkerThread { ComponentRuntime::Dedicated } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { let req = TaskRequestMessage::new(ctx.receiver()); let _req = self.dispatcher.send(req, None).await; // TODO: what to do with resp? diff --git a/rust/worker/src/memberlist/memberlist_provider.rs b/rust/worker/src/memberlist/memberlist_provider.rs index c8462a6cc82..0857ebc8f38 100644 --- a/rust/worker/src/memberlist/memberlist_provider.rs +++ b/rust/worker/src/memberlist/memberlist_provider.rs @@ -183,7 +183,7 @@ impl Component for CustomResourceMemberlistProvider { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { self.connect_to_kube_stream(ctx); } } diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index be5e7cbf2a7..c4d207096e7 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; @@ -11,7 +12,6 @@ use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWri use chroma_types::SegmentUuid; use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType}; use thiserror::Error; -use tonic::async_trait; use uuid::Uuid; use super::{ diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 6e68734be5d..fdc0c093b30 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -1,5 +1,6 @@ use std::iter::once; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_config::Configurable; use chroma_error::ChromaError; @@ -23,7 +24,7 @@ use crate::{ operators::{fetch_log::FetchLogOperator, knn_projection::KnnProjectionOperator}, orchestration::{ get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::KnnFilterOrchestrator, - CountQueryOrchestrator, + orchestrator::Orchestrator, CountOrchestrator, }, }, log::log::Log, @@ -41,13 +42,13 @@ pub struct WorkerServer { dispatcher: Option>, // Service dependencies log: Box, - sysdb: Box, + _sysdb: Box, hnsw_index_provider: HnswIndexProvider, blockfile_provider: BlockfileProvider, port: u16, } -#[async_trait::async_trait] +#[async_trait] impl Configurable for WorkerServer { async fn try_from_config(config: &QueryServiceConfig) -> Result> { let sysdb_config = &config.sysdb; @@ -85,7 +86,7 @@ impl Configurable for WorkerServer { Ok(WorkerServer { dispatcher: None, system: None, - sysdb, + _sysdb: sysdb, log, hnsw_index_provider, blockfile_provider, @@ -153,22 +154,19 @@ impl WorkerServer { .scan .ok_or(Status::invalid_argument("Invalid Scan Operator"))?; - let collection_and_segments = CollectionAndSegments::try_from(scan)?; - let collection = &collection_and_segments.collection; + let collection_and_segments = scan.try_into()?; + let fetch_log = self.fetch_log(&collection_and_segments); - let count_orchestrator = CountQueryOrchestrator::new( - self.clone_system()?, - &collection_and_segments.metadata_segment.id.0, - &collection.collection_id, - self.log.clone(), - self.sysdb.clone(), - self.clone_dispatcher()?, + let count_orchestrator = CountOrchestrator::new( self.blockfile_provider.clone(), - collection.version as u32, - collection.log_position as u64, + self.clone_dispatcher()?, + // TODO: Make this configurable + 1000, + collection_and_segments, + fetch_log, ); - match count_orchestrator.run().await { + match count_orchestrator.run(self.clone_system()?).await { Ok(count) => Ok(Response::new(CountResult { count: count as u32, })), @@ -321,7 +319,7 @@ impl WorkerServer { } } -#[tonic::async_trait] +#[async_trait] impl QueryExecutor for WorkerServer { async fn count(&self, count: Request) -> Result, Status> { // Note: We cannot write a middleware that instruments every service rpc @@ -364,7 +362,7 @@ impl QueryExecutor for WorkerServer { } #[cfg(debug_assertions)] -#[tonic::async_trait] +#[async_trait] impl chroma_proto::debug_server::Debug for WorkerServer { async fn get_info( &self, @@ -420,7 +418,7 @@ mod tests { let mut server = WorkerServer { dispatcher: None, system: None, - sysdb: Box::new(SysDb::Test(sysdb)), + _sysdb: Box::new(SysDb::Test(sysdb)), log: Box::new(Log::InMemory(log)), hnsw_index_provider: test_hnsw_index_provider(), blockfile_provider: segments.blockfile_provider, diff --git a/rust/worker/src/system/executor.rs b/rust/worker/src/system/executor.rs index 2f064a8067e..09779e45340 100644 --- a/rust/worker/src/system/executor.rs +++ b/rust/worker/src/system/executor.rs @@ -53,7 +53,7 @@ where mut channel: tokio::sync::mpsc::Receiver>, ) { self.handler - .on_start(&ComponentContext { + .start(&ComponentContext { system: self.inner.system.clone(), sender: self.inner.sender.clone(), cancellation_token: self.inner.cancellation_token.clone(), diff --git a/rust/worker/src/system/scheduler.rs b/rust/worker/src/system/scheduler.rs index 29428da0844..725dca0c671 100644 --- a/rust/worker/src/system/scheduler.rs +++ b/rust/worker/src/system/scheduler.rs @@ -204,7 +204,7 @@ mod tests { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) -> () { + async fn start(&mut self, ctx: &ComponentContext) -> () { let duration = Duration::from_millis(100); ctx.scheduler .schedule(ScheduleMessage {}, duration, ctx, || None); diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs index b96b05fb98e..1ff3e9365f7 100644 --- a/rust/worker/src/system/types.rs +++ b/rust/worker/src/system/types.rs @@ -43,7 +43,7 @@ pub trait Component: Send + Sized + Debug + 'static { fn runtime() -> ComponentRuntime { ComponentRuntime::Inherit } - async fn on_start(&mut self, _ctx: &ComponentContext) -> () {} + async fn start(&mut self, _ctx: &ComponentContext) -> () {} } /// A handler is a component that can process messages of a given type. @@ -346,7 +346,7 @@ mod tests { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) -> () { + async fn start(&mut self, ctx: &ComponentContext) -> () { let test_stream = stream::iter(vec![1, 2, 3]); self.register_stream(test_stream, ctx); }