From f9658bbcf243151d02ae25b1f5eb13cafb9a06ec Mon Sep 17 00:00:00 2001
From: Yash Kothari
Date: Sun, 21 Apr 2024 10:05:26 -0400
Subject: [PATCH] hash aggregate running on multiple threads

---
 vayu-common/src/lib.rs    |  29 +++++--
 vayu-common/src/store.rs  |  53 ++++++------
 vayu/src/df2vayu.rs       |  14 ++++
 vayu/src/lib.rs           |  53 +++++++++---
 vayu/src/sinks.rs         |   2 +-
 vayuDB/README.md          |   4 +-
 vayuDB/src/dummy_tasks.rs | 132 +++++++++++++++--------------
 vayuDB/src/main.rs        | 172 ++++++++++++++++++++++++--------------
 vayuDB/src/scheduler.rs   |  52 ++++++------
 vayuDB/src/tpch_tasks.rs  |  26 ++++--
 10 files changed, 335 insertions(+), 202 deletions(-)

diff --git a/vayu-common/src/lib.rs b/vayu-common/src/lib.rs
index 5360664..f6a458a 100644
--- a/vayu-common/src/lib.rs
+++ b/vayu-common/src/lib.rs
@@ -1,8 +1,10 @@
 use arrow::record_batch::RecordBatch;
 use datafusion::common::Result;
+use datafusion::physical_plan::aggregates::AggregateMode;
 use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream};
 use std::sync::Arc;
 pub mod store;
+// use vayu::operators::aggregate::AggregateOperator;
 
 pub trait PhysicalOperator {
     fn name(&self) -> String;
@@ -26,20 +28,28 @@ pub enum SchedulerSourceType {
 
 #[derive(Clone)]
 pub enum SchedulerSinkType {
-    // StoreRecordBatch(i32),
+    StoreRecordBatch(i32),
+    // FinalAggregation(i32, AggregateOperator),
     BuildAndStoreHashMap(i32, Arc<dyn ExecutionPlan>),
     PrintOutput,
 }
 
+#[derive(Clone)]
+pub enum FinalizeSinkType {
+    PrintFromStore(i32),
+    FinalAggregate(Arc<dyn ExecutionPlan>, i32),
+}
+#[derive(Clone)]
 pub struct DatafusionPipeline {
     pub plan: Arc<dyn ExecutionPlan>,
     pub sink: Option<SchedulerSinkType>,
     pub id: i32,
 }
-pub struct DatafusionPipelineWithSource {
-    pub source: Arc<dyn ExecutionPlan>,
-    pub plan: Arc<dyn ExecutionPlan>,
-    pub sink: Option<SchedulerSinkType>,
+#[derive(Clone)]
+pub struct SchedulerPipeline {
+    pub source: Option<Arc<dyn ExecutionPlan>>,
+    pub pipeline: DatafusionPipeline,
+    pub finalize: FinalizeSinkType,
 }
 
 pub struct DatafusionPipelineWithData {
@@ -56,13 +66,18 @@ pub struct VayuPipelineWithData {
     pub data: RecordBatch,
 }
 pub struct Task {
-    pub pipelines: Vec<DatafusionPipelineWithSource>,
+    pub pipelines: Vec<SchedulerPipeline>,
+}
+
+pub enum VayuMessage {
+    Normal(DatafusionPipelineWithData),
+    Finalize((FinalizeSinkType, i32)),
 }
 impl Task {
     pub fn new() -> Self {
         Task { pipelines: vec![] }
     }
-    pub fn add_pipeline(&mut self, pipeline: DatafusionPipelineWithSource) {
+    pub fn add_pipeline(&mut self, pipeline: SchedulerPipeline) {
         self.pipelines.push(pipeline);
     }
 }
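The heart of the change: `SchedulerPipeline` splits a query into a data-parallel phase (a `DatafusionPipeline` fed by a source) and a one-shot `finalize` step. A minimal sketch of how these types compose, mirroring the usage in `vayuDB/src/tpch_tasks.rs` later in this patch (the helper function, its arguments, and the `id` value are illustrative, not part of the patch):

```rust
use std::sync::Arc;

use datafusion::physical_plan::ExecutionPlan;
use vayu_common::{DatafusionPipeline, FinalizeSinkType, SchedulerPipeline, SchedulerSinkType};

// Illustrative only: assemble a two-phase pipeline from the types above.
fn make_two_phase_pipeline(
    source: Arc<dyn ExecutionPlan>,       // scan node handed to the IO service
    partial_plan: Arc<dyn ExecutionPlan>, // per-worker partial aggregate
    final_plan: Arc<dyn ExecutionPlan>,   // final aggregate over the stored batches
    uuid: i32,                            // key into the global Store
) -> SchedulerPipeline {
    SchedulerPipeline {
        source: Some(source),
        pipeline: DatafusionPipeline {
            plan: partial_plan,
            // every worker appends its partial output under `uuid`
            sink: Some(SchedulerSinkType::StoreRecordBatch(uuid)),
            id: 1,
        },
        // once all batches are acked, one worker runs the final aggregate
        finalize: FinalizeSinkType::FinalAggregate(final_plan, uuid),
    }
}
```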
diff --git a/vayu-common/src/store.rs b/vayu-common/src/store.rs
index ff3b7a0..04ec3a4 100644
--- a/vayu-common/src/store.rs
+++ b/vayu-common/src/store.rs
@@ -1,8 +1,9 @@
+use arrow::record_batch::RecordBatch;
 use core::panic;
 use datafusion::physical_plan::joins::hash_join::JoinLeftData;
 use std::collections::HashMap;
 pub enum Blob {
-    // RecordBatchBlob(Vec<RecordBatch>),
+    RecordBatchBlob(Vec<RecordBatch>),
     HashMapBlob(JoinLeftData),
 }
 
@@ -13,28 +14,28 @@ impl Blob {
             _ => panic!("error"),
         }
     }
-    // pub fn get_records(self) -> Vec<RecordBatch> {
-    //     match self {
-    //         Blob::RecordBatchBlob(records) => records,
-    //         _ => panic!("error"),
-    //     }
-    // }
-    // pub fn append_records(&mut self, batches: Vec<RecordBatch>) {
-    //     match self {
-    //         Blob::RecordBatchBlob(records) => {
-    //             // TODO: check if schema is same
-    //             records.extend(batches)
-    //         }
-    //         _ => panic!("error"),
-    //     }
-    // }
+    pub fn get_records(self) -> Vec<RecordBatch> {
+        match self {
+            Blob::RecordBatchBlob(records) => records,
+            _ => panic!("error"),
+        }
+    }
+    pub fn append_records(&mut self, batches: Vec<RecordBatch>) {
+        match self {
+            Blob::RecordBatchBlob(records) => {
+                // TODO: check if schema is same
+                records.extend(batches)
+            }
+            _ => panic!("error"),
+        }
+    }
 }
 // right now this is a typedef of HashMap<i32, Blob>,
 // but we may need something else in the near future
 pub struct Store {
-    store: HashMap<i32, Blob>,
+    pub store: HashMap<i32, Blob>,
 }
 impl Store {
     pub fn new() -> Store {
@@ -45,15 +46,15 @@ impl Store {
     pub fn insert(&mut self, key: i32, value: Blob) {
         self.store.insert(key, value);
     }
-    // pub fn append(&mut self, key: i32, value: Vec<RecordBatch>) {
-    //     let blob = self.remove(key);
-    //     let mut blob = match blob {
-    //         Some(r) => r,
-    //         None => Blob::RecordBatchBlob(Vec::new()),
-    //     };
-    //     blob.append_records(value);
-    //     self.store.insert(key, blob);
-    // }
+    pub fn append(&mut self, key: i32, value: Vec<RecordBatch>) {
+        let blob = self.remove(key);
+        let mut blob = match blob {
+            Some(r) => r,
+            None => Blob::RecordBatchBlob(Vec::new()),
+        };
+        blob.append_records(value);
+        self.store.insert(key, blob);
+    }
     pub fn remove(&mut self, key: i32) -> Option<Blob> {
         self.store.remove(&key)
         // let x = self.store.remove(&key).unwrap().value();
diff --git a/vayu/src/df2vayu.rs b/vayu/src/df2vayu.rs
index 69d8868..4ac0f62 100644
--- a/vayu/src/df2vayu.rs
+++ b/vayu/src/df2vayu.rs
@@ -179,3 +179,17 @@ pub fn get_source_node(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
     }
     panic!("No source node found");
 }
+
+pub fn aggregate(exec: Arc<dyn ExecutionPlan>) -> AggregateOperator {
+    let p = exec.as_any();
+    let final_aggregate = if let Some(exec) = p.downcast_ref::<AggregateExec>() {
+        if !exec.group_by().expr().is_empty() {
+            panic!("group by present - not handled");
+        }
+        let tt = AggregateOperator::new(exec);
+        tt
+    } else {
+        panic!("not an aggregate");
+    };
+    final_aggregate
+}
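`Store::append` gives workers a shared accumulator: the first append under a key creates the `RecordBatchBlob`, later appends extend it, and the finalize step drains it with `remove` plus `get_records`. A small sketch against the API above (in the engine the store actually sits behind an `Arc<Mutex<Store>>`; the key and batch arguments here are placeholders):

```rust
use arrow::record_batch::RecordBatch;
use vayu_common::store::Store;

// Sketch: partial results from two workers accumulating under one key.
fn stash_and_drain(a: Vec<RecordBatch>, b: Vec<RecordBatch>) -> Vec<RecordBatch> {
    let mut store = Store::new();
    store.append(55, a); // creates the RecordBatchBlob
    store.append(55, b); // extends the same blob
    store.remove(55).unwrap().get_records()
}
```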
diff --git a/vayu/src/lib.rs b/vayu/src/lib.rs
index 76d63a0..25aed4b 100644
--- a/vayu/src/lib.rs
+++ b/vayu/src/lib.rs
@@ -1,16 +1,18 @@
 use arrow::array::RecordBatch;
 use arrow::util::pretty;
 use vayu_common::DatafusionPipelineWithData;
+use vayu_common::IntermediateOperator;
 use vayu_common::VayuPipeline;
 pub mod operators;
+use crate::operators::aggregate::AggregateOperator;
+use datafusion::physical_plan::coalesce_batches::concat_batches;
 use std::sync::{Arc, Mutex};
-
 pub mod sinks;
 use vayu_common::store::Store;
 pub mod df2vayu;
 pub struct VayuExecutionEngine {
     // this is per node store
-    pub store: Store,
+    // pub store: Store,
     // this is global store
     pub global_store: Arc<Mutex<Store>>,
     // Note: only one of them will survive, let's see which
@@ -19,10 +21,40 @@ impl VayuExecutionEngine {
     pub fn new(global_store: Arc<Mutex<Store>>) -> VayuExecutionEngine {
         VayuExecutionEngine {
-            store: Store::new(),
+            // store: Store::new(),
             global_store,
         }
     }
+    pub fn finalize(&mut self, sink: vayu_common::FinalizeSinkType) {
+        println!("running finalize");
+
+        match sink {
+            vayu_common::FinalizeSinkType::PrintFromStore(uuid) => {
+                println!("running print from store {uuid}");
+                let mut store = self.global_store.lock().unwrap();
+                let blob = store.remove(uuid);
+                println!("{:?}", store.store.keys());
+
+                drop(store);
+                let result = blob.unwrap().get_records();
+                pretty::print_batches(&result).unwrap();
+            }
+            vayu_common::FinalizeSinkType::FinalAggregate(plan, uuid) => {
+                println!("running FinalAggregate from store {uuid}");
+                let mut store = self.global_store.lock().unwrap();
+                let blob = store.remove(uuid);
+                println!("{:?}", store.store.keys());
+
+                drop(store);
+                let result = blob.unwrap().get_records();
+                let mut operator = df2vayu::aggregate(plan);
+                let batch = arrow::compute::concat_batches(&result[0].schema(), &result).unwrap();
+
+                let result = operator.execute(&batch).unwrap();
+                pretty::print_batches(&[result.clone()]).unwrap();
+            }
+        }
+    }
     pub fn sink(&mut self, sink: vayu_common::SchedulerSinkType, result: Vec<RecordBatch>) {
         println!(
             "running sink size {}x{}",
             result.len(),
             result[0].num_rows()
         );
         match sink {
             vayu_common::SchedulerSinkType::PrintOutput => {
                 pretty::print_batches(&result).unwrap();
             }
-            // vayu_common::SchedulerSinkType::StoreRecordBatch(uuid) => {
-            //     self.store.append(uuid, result);
-            // }
+            vayu_common::SchedulerSinkType::StoreRecordBatch(uuid) => {
+                println!("storing at store {uuid}");
+                let mut store = self.global_store.lock().unwrap();
+                store.append(uuid, result);
+
+                println!("{:?}", store.store.keys());
+                drop(store);
+            }
             vayu_common::SchedulerSinkType::BuildAndStoreHashMap(uuid, join_node) => {
                 let mut sink = sinks::HashMapSink::new(uuid, join_node);
                 let hashmap = sink.build_map(result);
                 println!("BuildAndStoreHashMap storing in uuid {uuid}");
-
-                // old store
-                // self.store.insert(uuid, hashmap.unwrap());
-                // new store
                 let mut map = self.global_store.lock().unwrap();
                 map.insert(uuid, hashmap.unwrap());
             }
diff --git a/vayu/src/sinks.rs b/vayu/src/sinks.rs
index 0965460..f3daa99 100644
--- a/vayu/src/sinks.rs
+++ b/vayu/src/sinks.rs
@@ -34,7 +34,7 @@ impl HashMapSink {
         }
     }
     pub fn build_map(&mut self, result: Vec<RecordBatch>) -> Option<Blob> {
-        let random_state = RandomState::with_seeds(0, 0, 0, 0);
+        let random_state: RandomState = RandomState::with_seeds(0, 0, 0, 0);
         let ctx: SessionContext = SessionContext::new();
         let reservation =
             MemoryConsumer::new("HashJoinInput").register(ctx.task_ctx().memory_pool());
diff --git a/vayuDB/README.md b/vayuDB/README.md
index dab2c4f..4657287 100644
--- a/vayuDB/README.md
+++ b/vayuDB/README.md
@@ -48,7 +48,7 @@ Scheduler will keep on sending the tasks.
 
 ## Common Vayu Structures
 ```
-pub struct DatafusionPipelineWithSource {
+pub struct SchedulerPipeline {
     pub source: Arc<dyn ExecutionPlan>,
     pub plan: Arc<dyn ExecutionPlan>,
     pub sink: SchedulerSinkType,
 }
 ```
@@ -60,7 +60,7 @@
 1. Scheduler
 ```
 pub fn new() -> Self
-pub fn get_pipeline(&mut self) -> Poll<DatafusionPipelineWithSource>
+pub fn get_pipeline(&mut self) -> Poll<SchedulerPipeline>
 pub fn ack_pipeline(&mut self, pipeline_id: i32);
 ```
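The README keeps the scheduler contract poll-based: hand out a pipeline when one is ready, and accept an ack by id once it finishes. A self-contained toy obeying that contract (the `Demo*` types are stand-ins; the real types live in `vayu-common` and `vayuDB/src/scheduler.rs`, where `get_pipeline` also takes an `id`):

```rust
use std::task::Poll;

// Illustrative stand-ins for the README contract above.
struct DemoPipeline { id: i32 }
struct DemoScheduler { queued: Vec<DemoPipeline> }

impl DemoScheduler {
    fn new() -> Self { DemoScheduler { queued: vec![DemoPipeline { id: 0 }] } }
    // hand out work until nothing is ready; the caller polls again later
    fn get_pipeline(&mut self) -> Poll<DemoPipeline> {
        match self.queued.pop() {
            Some(p) => Poll::Ready(p),
            None => Poll::Pending,
        }
    }
    fn ack_pipeline(&mut self, pipeline_id: i32) {
        println!("pipeline {pipeline_id} done");
    }
}
```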
diff --git a/vayuDB/src/dummy_tasks.rs b/vayuDB/src/dummy_tasks.rs
index 2e06b77..286ad8b 100644
--- a/vayuDB/src/dummy_tasks.rs
+++ b/vayuDB/src/dummy_tasks.rs
@@ -5,84 +5,90 @@ use datafusion::physical_plan::joins::HashJoinExec;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::CsvReadOptions;
 use datafusion::prelude::SessionContext;
+use futures::io::Sink;
 use std::path::Path;
 use std::sync::Arc;
 use vayu::df2vayu;
 use vayu::operators::join;
-use vayu_common::DatafusionPipelineWithSource;
+use vayu_common::DatafusionPipeline;
+use vayu_common::SchedulerPipeline;
 use vayu_common::Task;
 
-pub async fn test_filter_project_aggregate() -> Result<Task> {
-    // create local execution context
-    let ctx: SessionContext = SessionContext::new();
-    // register csv file with the execution context
+// pub async fn test_filter_project_aggregate() -> Result<Task> {
+// // create local execution context
+// let ctx: SessionContext = SessionContext::new();
+// // register csv file with the execution context
 
-    ctx.register_csv(
-        "aggregate_test_100",
-        &format!("./testing/data/csv/aggregate_test_100.csv"),
-        CsvReadOptions::new(),
-    )
-    .await?;
-    let sql = "SELECT count(c1),sum(c3),sum(c4),count(c13) FROM aggregate_test_100 WHERE (c3 < 0 AND c1='a') OR ( c4 > 0 AND c1='b' ) ";
-    let plan = get_execution_plan_from_sql(&ctx, sql).await?;
-    let source = df2vayu::get_source_node(plan.clone());
-    let mut task = Task::new();
+// ctx.register_csv(
+//     "aggregate_test_100",
+//     &format!("./testing/data/csv/aggregate_test_100.csv"),
+//     CsvReadOptions::new(),
+// )
+// .await?;
+// let sql = "SELECT count(c1),sum(c3),sum(c4),count(c13) FROM aggregate_test_100 WHERE (c3 < 0 AND c1='a') OR ( c4 > 0 AND c1='b' ) ";
+// let plan = get_execution_plan_from_sql(&ctx, sql).await?;
+// let source = df2vayu::get_source_node(plan.clone());
+// let mut task = Task::new();
 
-    let pipeline = DatafusionPipelineWithSource {
-        source,
-        plan,
-        sink: Some(vayu_common::SchedulerSinkType::PrintOutput),
-    };
-    task.add_pipeline(pipeline);
+// let pipeline = DatafusionPipeline {
+//     plan,
+//     sink: Some(vayu_common::SchedulerSinkType::StoreRecordBatch(1)),
+//     id: 1,
+// };
+// let finalize = Sink{
 
-    return Ok(task);
-}
+// }
+// let pipeline = SchedulerPipeline { source, pipeline };
+// task.add_pipeline(pipeline);
 
-pub async fn test_hash_join() -> Result<Task> {
-    let ctx: SessionContext = SessionContext::new();
-    // register csv file with the execution context
-    ctx.register_csv(
-        "a",
-        &format!("./testing/data/csv/join_test_A.csv"),
-        CsvReadOptions::new(),
-    )
-    .await?;
-    ctx.register_csv(
-        "b",
-        &format!("./testing/data/csv/join_test_B.csv"),
-        CsvReadOptions::new(),
-    )
-    .await?;
+// return Ok(task);
+// }
 
-    // get execution plan from the sql query
-    let sql = "SELECT * FROM a,b WHERE a.a1 = b.b1 ";
-    let plan = get_execution_plan_from_sql(&ctx, sql).await?;
-    let mut task = Task::new();
+// pub async fn test_hash_join() -> Result<Task> {
+// let ctx: SessionContext = SessionContext::new();
+// // register csv file with the execution context
+// ctx.register_csv(
+//     "a",
+//     &format!("./testing/data/csv/join_test_A.csv"),
+//     CsvReadOptions::new(),
+// )
+// .await?;
+// ctx.register_csv(
+//     "b",
+//     &format!("./testing/data/csv/join_test_B.csv"),
+//     CsvReadOptions::new(),
+// )
+// .await?;
 
-    let uuid = 42;
-    let (join_node, build_plan) = df2vayu::get_hash_build_pipeline(plan.clone(), uuid);
 
-    let build_source_pipeline = df2vayu::get_source_node(build_plan.clone());
-    let sink = vayu_common::SchedulerSinkType::BuildAndStoreHashMap(uuid, join_node);
-    let build_pipeline = DatafusionPipelineWithSource {
-        source: build_source_pipeline,
-        plan: build_plan,
-        sink: Some(sink),
-    };
-    task.add_pipeline(build_pipeline);
-    // TODO: set this uuid in probe also
-    let probe_plan = plan.clone();
-    let probe_source_node = df2vayu::get_source_node(probe_plan.clone());
+// // get execution plan from the sql query
+// let sql = "SELECT * FROM a,b WHERE a.a1 = b.b1 ";
+// let plan = get_execution_plan_from_sql(&ctx, sql).await?;
+// let mut task = Task::new();
 
-    let probe_pipeline = DatafusionPipelineWithSource {
-        source: probe_source_node,
-        plan: probe_plan,
-        sink: Some(vayu_common::SchedulerSinkType::PrintOutput),
-    };
-    task.add_pipeline(probe_pipeline);
+// let uuid = 42;
+// let (join_node, build_plan) = df2vayu::get_hash_build_pipeline(plan.clone(), uuid);
+// let build_source_pipeline = df2vayu::get_source_node(build_plan.clone());
+// let sink = vayu_common::SchedulerSinkType::BuildAndStoreHashMap(uuid, join_node);
+// let build_pipeline = SchedulerPipeline {
+//     source: build_source_pipeline,
+//     plan: build_plan,
+//     sink: Some(sink),
+// };
+// task.add_pipeline(build_pipeline);
+// // TODO: set this uuid in probe also
+// let probe_plan = plan.clone();
+// let probe_source_node = df2vayu::get_source_node(probe_plan.clone());
 
-    Ok(task)
-}
+// let probe_pipeline = SchedulerPipeline {
+//     source: probe_source_node,
+//     plan: probe_plan,
+//     sink: Some(vayu_common::SchedulerSinkType::PrintOutput),
+// };
+// task.add_pipeline(probe_pipeline);
+
+// Ok(task)
+// }
 
 pub async fn get_execution_plan_from_sql(
     ctx: &SessionContext,
diff --git a/vayuDB/src/main.rs b/vayuDB/src/main.rs
index 24b8078..610c04f 100644
--- a/vayuDB/src/main.rs
+++ b/vayuDB/src/main.rs
@@ -1,7 +1,9 @@
 use crossbeam_channel::{bounded, Receiver, Sender};
+use datafusion::arrow::array::ArrowNativeTypeOp;
 use std::collections::HashMap;
 use std::task::Poll;
 use std::thread;
+use vayu_common::SchedulerPipeline;
 use vayu_common::{DatafusionPipeline, DatafusionPipelineWithData};
 mod dummy_tasks;
 mod io_service;
@@ -11,22 +13,35 @@ use std::collections::LinkedList;
 use std::sync::Arc;
 use std::sync::Mutex;
 use vayu_common;
+use vayu_common::VayuMessage;
 use vayu_common::store::Store;
+
+use crate::scheduler::Scheduler;
 fn start_worker(
-    receiver: Receiver<DatafusionPipelineWithData>,
+    receiver: Receiver<VayuMessage>,
     sender: Sender<(usize, i32)>,
     global_store: Arc<Mutex<Store>>,
     thread_id: usize,
 ) {
     let mut executor = vayu::VayuExecutionEngine::new(global_store);
     // Receive structs sent over the channel
-    while let Ok(pipeline) = receiver.recv() {
-        let pipeline_id = pipeline.pipeline.id;
-        println!("{thread_id}:got a pipeline for the thread, executing ...");
-        executor.execute(pipeline);
-        println!("{thread_id}:done executing ...");
-        sender.send((thread_id, pipeline_id)).unwrap();
+    while let Ok(message) = receiver.recv() {
+        match message {
+            VayuMessage::Normal(pipeline) => {
+                let pipeline_id = pipeline.pipeline.id;
+                println!("{thread_id}:got a pipeline for the thread, executing ...");
+                executor.execute(pipeline);
+                println!("{thread_id}:done executing ...");
+                sender.send((thread_id, pipeline_id)).unwrap();
+            }
+            VayuMessage::Finalize((sink, pipeline_id)) => {
+                println!("{thread_id}:got a finalize pipeline for the thread, executing ...");
+                executor.finalize(sink);
+                println!("{thread_id}:done executing ...");
+                sender.send((thread_id, pipeline_id)).unwrap();
+            }
+        }
     }
 }
..."); + sender.send((thread_id, pipeline_id)).unwrap(); + } + } } } @@ -49,7 +64,7 @@ fn main() { bounded(0); // vector to store main_thread->worker channels - let mut senders: Vec> = Vec::new(); + let mut senders: Vec> = Vec::new(); let mut free_threads: LinkedList = LinkedList::new(); for thread_num in 0..num_threads { @@ -82,77 +97,110 @@ fn main() { let mut io_service = io_service::IOService::new(); // TODO: create task_queue - buffer tasks - let mut request_pipeline_map: HashMap = HashMap::new(); + let mut request_pipeline_map: HashMap = HashMap::new(); + let mut completed_pipeline_list: LinkedList = LinkedList::new(); // right now a pipeline would be assigned to a worker only when it is free // but we will poll some extra pipelines from the scheduler and send it to the io service // so that we can start working on it once any worker is free let mut next_id = 0; - if request_pipeline_map.len() == 0 { - // poll scheduler for a new task - let pipeline = scheduler.get_pipeline(next_id); - if let Poll::Ready(pipeline) = pipeline { - // TODO: add support for multiple dependent pipeline - println!("got a pipeline from scheduler"); - - // submit the source request to io service - let request_num = io_service.submit_request(pipeline.source); - println!("sent the request to the io_service"); - - // insert the pipeline into the local map - request_pipeline_map.insert( - request_num, - DatafusionPipeline { - plan: pipeline.plan, - sink: pipeline.sink, - id: next_id, - }, - ); - next_id += 1; - } + // poll scheduler for a new task + let pipeline = scheduler.get_pipeline(next_id); + if let Poll::Ready(mut pipeline) = pipeline { + // TODO: add support for multiple dependent pipeline + println!("got a pipeline from scheduler"); + + let source = pipeline.source.take().unwrap(); + // submit the source request to io service + let request_num = io_service.submit_request(source); + println!("sent the request to the io_service"); + + // insert the pipeline into the local map + request_pipeline_map.insert(request_num, (pipeline, 0)); + next_id += 1; } + loop { - if let Ok((thread_id, finished_pipeline_id)) = informer_receiver.try_recv() { + if let Ok((thread_id, request_num)) = informer_receiver.try_recv() { println!("got ack from thread {}", thread_id); - if finished_pipeline_id != -1 { - scheduler.ack_pipeline(finished_pipeline_id); + if request_num != -1 { + let pipeline = request_pipeline_map.remove(&request_num); + match pipeline { + Some((pipeline, processing_count)) => { + let processing_count = processing_count - 1; + println!("current processing count is {processing_count}"); + if processing_count == 0 { + completed_pipeline_list.push_back(pipeline); + } else { + request_pipeline_map.insert(request_num, (pipeline, processing_count)); + } + } + None => { + println!("inform scheduler we are done"); + // scheduler.ack_pipeline(request_num); + } + } } + // add in the queue free_threads.push_back(thread_id); } - if let Some(&thread_id) = free_threads.front() { - // println!("free thread available"); - // poll io_service for a response - let response = io_service.poll_response(); - if let Poll::Ready((request_num, data)) = response { - if data.is_none() { - let pipeline = request_pipeline_map.remove(&request_num); - } else { - let data = data.unwrap(); - free_threads.pop_front(); - println!("got a response from the io_service"); - - // TODO: handle when a source gives multiple record batches - // get the pipeline from the local map + if free_threads.len() == 0 { + continue; + } + // check from finalize 
+        if !completed_pipeline_list.is_empty() {
+            println!("removing item from completed list");
+
+            let thread_id = free_threads.pop_front().unwrap();
+            // this pipeline's data is finished
+            let pipeline = completed_pipeline_list.pop_front().unwrap();
+            let msg = VayuMessage::Finalize((pipeline.finalize, pipeline.pipeline.id));
+            senders[thread_id].send(msg).expect("Failed to send struct");
+            println!("finalize:sent the pipeline and the data to the worker");
+            continue;
+        }
+        // println!("free thread available");
+        // poll io_service for a response
+        let response = io_service.poll_response();
+        if let Poll::Ready((request_num, data)) = response {
+            if data.is_none() {
+                let mv = request_pipeline_map.get(&request_num);
+
+                assert!(mv.is_some());
+                let (_, processing_count) = mv.unwrap();
+                println!("no more data left. processing count is {processing_count}");
+                if processing_count.is_zero() {
-                    let pipeline = request_pipeline_map.remove(&request_num);
-                    assert!(pipeline.is_some());
-                    let pipeline = pipeline.unwrap();
-                    let pipeline2 = DatafusionPipeline {
-                        plan: pipeline.plan.clone(),
-                        sink: pipeline.sink.clone(),
-                        id: pipeline.id,
-                    };
-                    request_pipeline_map.insert(request_num, pipeline2);
-                    // send over channel
-                    let msg = DatafusionPipelineWithData { pipeline, data };
-                    senders[thread_id].send(msg).expect("Failed to send struct");
-                    println!("sent the pipeline and the data to the worker");
+                    let pipeline = request_pipeline_map.remove(&request_num);
+                    let (pipeline, _) = pipeline.unwrap();
+                    completed_pipeline_list.push_back(pipeline);
                 }
-                // assign the next pipeline to some other worker
-                // worker_id = round_robin(worker_id, num_threads);
+            } else {
+                let data = data.unwrap();
+                let thread_id = free_threads.pop_front().unwrap();
+                println!("got a response from the io_service");
+
+                // get the pipeline from the local map
+                let pipeline = request_pipeline_map.remove(&request_num);
+
+                assert!(pipeline.is_some());
+                let (pipeline, processing_count) = pipeline.unwrap();
+
+                request_pipeline_map.insert(request_num, (pipeline.clone(), processing_count + 1));
+
+                // send over channel
+                let msg = VayuMessage::Normal(DatafusionPipelineWithData {
+                    pipeline: pipeline.pipeline,
+                    data,
+                });
+                senders[thread_id].send(msg).expect("Failed to send struct");
+                println!("sent the pipeline and the data to the worker");
             }
+            // assign the next pipeline to some other worker
+            // worker_id = round_robin(worker_id, num_threads);
         }
     }
 }
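Dispatch and acks in the new main loop are tied together by a per-request reference count: sending one batch to a worker bumps the count, each worker ack decrements it, and only at zero does the pipeline graduate to `completed_pipeline_list` for its finalize step. The ack side, condensed into a standalone function (a sketch with simplified types; the real map stores `(SchedulerPipeline, i32)`):

```rust
use std::collections::{HashMap, LinkedList};

// Sketch of the ack handler: `P` stands in for SchedulerPipeline.
fn on_worker_ack<P>(
    request_num: i32,
    in_flight: &mut HashMap<i32, (P, i32)>,
    completed: &mut LinkedList<P>,
) {
    if let Some((pipeline, count)) = in_flight.remove(&request_num) {
        let count = count - 1;
        if count == 0 {
            // last outstanding batch finished: queue the finalize step
            completed.push_back(pipeline);
        } else {
            in_flight.insert(request_num, (pipeline, count));
        }
    }
}
```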
diff --git a/vayuDB/src/scheduler.rs b/vayuDB/src/scheduler.rs
index d24f686..7b7d35b 100644
--- a/vayuDB/src/scheduler.rs
+++ b/vayuDB/src/scheduler.rs
@@ -1,8 +1,8 @@
-use crate::dummy_tasks::{test_filter_project_aggregate, test_hash_join};
+// use crate::dummy_tasks::test_hash_join;
 use crate::tpch_tasks::test_tpchq1;
 use datafusion_benchmarks::tpch;
 use std::{hash::Hash, task::Poll};
-use vayu_common::DatafusionPipelineWithSource;
+use vayu_common::SchedulerPipeline;
 #[derive(PartialEq)]
 enum HashJoinState {
     CanSendBuild,
@@ -14,7 +14,7 @@ pub struct Scheduler {
     turn: usize,
     // stored_id: i32,
     state: HashJoinState,
-    probe_pipeline: Option<DatafusionPipelineWithSource>,
+    probe_pipeline: Option<SchedulerPipeline>,
 }
 
 impl Scheduler {
@@ -26,34 +26,34 @@ impl Scheduler {
         }
     }
 
-    pub fn get_pipeline(&mut self, id: i32) -> Poll<DatafusionPipelineWithSource> {
-        // let mut task = futures::executor::block_on(test_tpchq1()).unwrap();
-        // let pipeline = task.pipelines.remove(0);
-        // return Poll::Ready(pipeline);
-
-        let mut task = futures::executor::block_on(test_filter_project_aggregate()).unwrap();
+    pub fn get_pipeline(&mut self, id: i32) -> Poll<SchedulerPipeline> {
+        let mut task = futures::executor::block_on(test_tpchq1()).unwrap();
         let pipeline = task.pipelines.remove(0);
         return Poll::Ready(pipeline);
 
+        // let mut task = futures::executor::block_on(test_filter_project_aggregate()).unwrap();
+        // let pipeline = task.pipelines.remove(0);
+        // return Poll::Ready(pipeline);
+
         self.turn = 1 - self.turn;
-        if self.turn == 0 && self.state == HashJoinState::CanSendBuild {
-            let mut task = futures::executor::block_on(test_hash_join()).unwrap();
-            self.probe_pipeline = Some(task.pipelines.remove(1));
-            let build_pipeline = task.pipelines.remove(0);
 
-            self.state = HashJoinState::BuildSent(id);
-            return Poll::Ready(build_pipeline);
-        } else if self.turn == 0 && self.state == HashJoinState::CanSendProbe {
-            self.state = HashJoinState::ProbeSent(id);
-            assert!(self.probe_pipeline.is_some());
-            let probe_pipeline = self.probe_pipeline.take().unwrap();
-            return Poll::Ready(probe_pipeline);
-        } else {
-            let mut task = futures::executor::block_on(test_filter_project_aggregate()).unwrap();
-            let pipeline = task.pipelines.remove(0);
-            return Poll::Ready(pipeline);
-            // return Poll::Pending;
-        }
+        // if self.turn == 0 && self.state == HashJoinState::CanSendBuild {
+        //     let mut task = futures::executor::block_on(test_hash_join()).unwrap();
+        //     self.probe_pipeline = Some(task.pipelines.remove(1));
+        //     let build_pipeline = task.pipelines.remove(0);
 
+        //     self.state = HashJoinState::BuildSent(id);
+        //     return Poll::Ready(build_pipeline);
+        // } else if self.turn == 0 && self.state == HashJoinState::CanSendProbe {
+        //     self.state = HashJoinState::ProbeSent(id);
+        //     assert!(self.probe_pipeline.is_some());
+        //     let probe_pipeline = self.probe_pipeline.take().unwrap();
+        //     return Poll::Ready(probe_pipeline);
+        // } else {
+        //     let mut task = futures::executor::block_on(test_filter_project_aggregate()).unwrap();
+        //     let pipeline = task.pipelines.remove(0);
+        //     return Poll::Ready(pipeline);
+        //     // return Poll::Pending;
+        // }
     }
     pub fn ack_pipeline(&mut self, ack_id: i32) {
         match self.state {
diff --git a/vayuDB/src/tpch_tasks.rs b/vayuDB/src/tpch_tasks.rs
index 42730f0..62a47ed 100644
--- a/vayuDB/src/tpch_tasks.rs
+++ b/vayuDB/src/tpch_tasks.rs
@@ -11,7 +11,10 @@ use std::path::Path;
 use std::path::PathBuf;
 use std::process::exit;
 use vayu::df2vayu;
-use vayu_common::DatafusionPipelineWithSource;
+use vayu::operators::aggregate::AggregateOperator;
+use vayu_common::DatafusionPipeline;
+use vayu_common::SchedulerPipeline;
+use vayu_common::SchedulerSinkType;
 use vayu_common::Task;
 fn get_tpch_data_path() -> Result<PathBuf> {
     let path = std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string());
@@ -76,18 +79,31 @@ pub async fn test_tpchq1() -> Result<Task> {
     let sql = queries.get(0).unwrap();
     let plan = get_execution_plan_from_sql(&ctx, sql).await.unwrap();
+    let final_aggregate = plan.clone();
+    // let final_aggregate = AggregateOperator::new(final_aggregate);
+
+    let plan = plan.children().get(0).unwrap().clone();
     println!(
         "=== Physical plan ===\n{}\n",
         displayable(plan.as_ref()).indent(true)
     );
-    let source = df2vayu::get_source_node(plan.clone());
+    let source = Some(df2vayu::get_source_node(plan.clone()));
     let mut task = Task::new();
-    let pipeline = DatafusionPipelineWithSource {
-        source,
+    let uuid = 55;
+    let pipeline = DatafusionPipeline {
         plan,
-        sink: Some(vayu_common::SchedulerSinkType::PrintOutput),
+        sink: Some(SchedulerSinkType::StoreRecordBatch(uuid)),
+        id: 1,
+    };
+
+    let finalize = vayu_common::FinalizeSinkType::FinalAggregate(final_aggregate, uuid);
+    let pipeline = SchedulerPipeline {
+        source,
+        pipeline,
+        finalize,
+    };
 
     task.add_pipeline(pipeline);
 
     return Ok(task);
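End to end, the patch turns TPC-H Q1 into a two-phase aggregate: workers sink partial batches into the shared store, then a single finalize pass concatenates and reduces them. A sketch of the two phases with the driver loop stripped away (all engines must share one `Arc<Mutex<Store>>`; `partials` and `final_plan` are placeholders for the scanned partial outputs and the `AggregateExec` kept aside in `test_tpchq1`):

```rust
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::ExecutionPlan;
use vayu::VayuExecutionEngine;
use vayu_common::{FinalizeSinkType, SchedulerSinkType};

// Sketch: what the workers collectively do for one hash aggregate.
fn two_phase_aggregate(
    workers: &mut [VayuExecutionEngine], // all built over the same global store
    partials: Vec<Vec<RecordBatch>>,     // per-worker partial-aggregate output
    final_plan: Arc<dyn ExecutionPlan>,  // final AggregateExec over stored batches
) {
    let uuid = 55;
    // phase 1: every worker appends its partial result under `uuid`
    for (engine, batches) in workers.iter_mut().zip(partials) {
        engine.sink(SchedulerSinkType::StoreRecordBatch(uuid), batches);
    }
    // phase 2: one worker drains the store, concatenates the batches, and
    // runs the final AggregateOperator (printing the result, per finalize())
    workers[0].finalize(FinalizeSinkType::FinalAggregate(final_plan, uuid));
}
```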