From fd179b82f5196645865428f583253ea3f502f775 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 13 Nov 2024 16:11:55 -0800
Subject: [PATCH] simplify a lot of redundant logic

---
 src/daft-connect/src/op/execute.rs         |  22 +--
 src/daft-connect/src/op/execute/root.rs    |  81 ++++++++++-
 src/daft-connect/src/op/execute/write.rs   |  10 +-
 src/daft-connect/src/translation.rs        |   2 -
 src/daft-connect/src/translation/stream.rs | 130 ------------------
 .../src/translation/stream/range.rs        |  47 -------
 src/daft-logical-plan/src/builder.rs       |   1 +
 7 files changed, 87 insertions(+), 206 deletions(-)
 delete mode 100644 src/daft-connect/src/translation/stream.rs
 delete mode 100644 src/daft-connect/src/translation/stream/range.rs

diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs
index a8714433ed..ecc660aa9f 100644
--- a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -1,31 +1,17 @@
-use std::{
-    collections::{HashMap, HashSet},
-    future::ready,
-    path::PathBuf,
-    pin::Pin,
-};
+use std::pin::Pin;
 
 use arrow2::io::ipc::write::StreamWriter;
-use common_daft_config::DaftExecutionConfig;
-use common_file_formats::FileFormat;
-use daft_scan::builder::{parquet_scan, ParquetScanBuilder};
 use daft_table::Table;
 use eyre::Context;
-use futures::{stream, Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt, TryStreamExt};
 use spark_connect::{
     execute_plan_response::{ArrowBatch, ResponseType, ResultComplete},
     spark_connect_service_server::SparkConnectService,
-    write_operation::{SaveMode, SaveType},
-    ExecutePlanResponse, Relation, WriteOperation,
+    ExecutePlanResponse,
 };
-use tonic::Status;
-use tracing::{error, warn};
 use uuid::Uuid;
 
-use crate::{
-    invalid_argument_err, not_found_err, translation::relation_to_stream, DaftSparkConnectService,
-    Session,
-};
+use crate::{DaftSparkConnectService, Session};
 
 mod root;
 mod write;
diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index 5abe6568a4..673415b3d3 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -1,13 +1,14 @@
 use std::future::ready;
 
+use common_daft_config::DaftExecutionConfig;
 use futures::stream;
-use spark_connect::Relation;
-use tonic::Status;
+use spark_connect::{ExecutePlanResponse, Relation};
+use tonic::{codegen::tokio_stream::wrappers::UnboundedReceiverStream, Status};
 
 use crate::{
     op::execute::{ExecuteStream, PlanIds},
     session::Session,
-    translation::relation_to_stream,
+    translation,
 };
 
 impl Session {
@@ -21,13 +22,81 @@ impl Session {
         let context = PlanIds {
             session: self.client_side_session_id().to_string(),
             server_side_session: self.server_side_session_id().to_string(),
-            operation: operation_id.clone(),
+            operation: operation_id,
         };
 
         let finished = context.finished();
 
-        let stream = relation_to_stream(command, context)
-            .map_err(|e| Status::internal(e.to_string()))?
+        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<eyre::Result<ExecutePlanResponse>>();
+
+        std::thread::spawn(move || {
+            let plan = match translation::to_logical_plan(command) {
+                Ok(plan) => plan,
+                Err(e) => {
+                    tx.send(Err(eyre::eyre!(e))).unwrap();
+                    return;
+                }
+            };
+
+            let logical_plan = plan.logical_plan.build();
+            let physical_plan = match daft_local_plan::translate(&logical_plan) {
+                Ok(plan) => plan,
+                Err(e) => {
+                    tx.send(Err(eyre::eyre!(e))).unwrap();
+                    return;
+                }
+            };
+
+            let cfg = DaftExecutionConfig::default();
+            let result = match daft_local_execution::run_local(
+                &physical_plan,
+                plan.partition,
+                cfg.into(),
+                None,
+            ) {
+                Ok(result) => result,
+                Err(e) => {
+                    tx.send(Err(eyre::eyre!(e))).unwrap();
+                    return;
+                }
+            };
+
+            for result in result {
+                let result = match result {
+                    Ok(result) => result,
+                    Err(e) => {
+                        tx.send(Err(eyre::eyre!(e))).unwrap();
+                        return;
+                    }
+                };
+
+                let tables = match result.get_tables() {
+                    Ok(tables) => tables,
+                    Err(e) => {
+                        tx.send(Err(eyre::eyre!(e))).unwrap();
+                        return;
+                    }
+                };
+
+                for table in tables.as_slice() {
+                    let response = context.gen_response(table);
+
+                    let response = match response {
+                        Ok(response) => response,
+                        Err(e) => {
+                            tx.send(Err(eyre::eyre!(e))).unwrap();
+                            return;
+                        }
+                    };
+
+                    tx.send(Ok(response)).unwrap();
+                }
+            }
+        });
+
+        let stream = UnboundedReceiverStream::new(rx);
+
+        let stream = stream
             .map_err(|e| Status::internal(e.to_string()))
             .chain(stream::once(ready(Ok(finished))));
 
diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
index 7d8013f33e..8f72a2da33 100644
--- a/src/daft-connect/src/op/execute/write.rs
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -136,9 +136,13 @@ impl Session {
         println!("physical plan: {physical_plan:#?}");
 
         let cfg = DaftExecutionConfig::default();
-        let results =
-            daft_local_execution::run_local(&physical_plan, plan_builder.partition, cfg.into(), None)
-                .unwrap();
+        let results = daft_local_execution::run_local(
+            &physical_plan,
+            plan_builder.partition,
+            cfg.into(),
+            None,
+        )
+        .unwrap();
 
         // todo: remove
         std::thread::scope(|s| {
diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs
index b36efcbd47..125aa6e884 100644
--- a/src/daft-connect/src/translation.rs
+++ b/src/daft-connect/src/translation.rs
@@ -2,8 +2,6 @@
 mod logical_plan;
 mod schema;
-mod stream;
 
 pub use logical_plan::to_logical_plan;
 pub use schema::relation_to_schema;
-pub use stream::relation_to_stream;
diff --git a/src/daft-connect/src/translation/stream.rs b/src/daft-connect/src/translation/stream.rs
deleted file mode 100644
index f2f3a35034..0000000000
--- a/src/daft-connect/src/translation/stream.rs
+++ /dev/null
@@ -1,130 +0,0 @@
-//! Relation handling for Spark Connect protocol.
-//!
-//! A Relation represents a structured dataset or transformation in Spark Connect.
-//! It can be either a base relation (direct data source) or derived relation
-//! (result of operations on other relations).
-//!
-//! The protocol represents relations as trees of operations where:
-//! - Each node is a Relation with metadata and an operation type
-//! - Operations can reference other relations, forming a DAG
-//! - The tree describes how to derive the final result
-//!
-//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age
-//!
-//! ```text
-//! Aggregate (grouping by age)
-//!   ↳ Filter (department = 'Engineering')
-//!     ↳ Read (employees table)
-//! ```
-//!
-//! Relations abstract away:
-//! - Physical storage details
-//! - Distributed computation
-//! - Query optimization
-//! - Data source specifics
-//!
-//! This allows Spark to optimize and execute queries efficiently across a cluster
-//! while providing a consistent API regardless of the underlying data source.
-//! ```mermaid
-//!
-//! ```
-
-use std::collections::HashMap;
-
-use common_daft_config::DaftExecutionConfig;
-use common_error::DaftError;
-use eyre::{eyre, Context};
-use futures::{Stream, StreamExt, TryStreamExt};
-use spark_connect::{relation::RelType, ExecutePlanResponse, Relation};
-use tonic::codegen::tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
-use tracing::trace;
-
-mod range;
-use range::range;
-
-use crate::{
-    op::execute::{ExecuteRichStream, ExecuteStream, PlanIds},
-    translation,
-    translation::to_logical_plan,
-};
-
-pub fn relation_to_stream(relation: Relation, context: PlanIds) -> eyre::Result<ExecuteRichStream> {
-    // First check common fields if needed
-    if let Some(common) = &relation.common {
-        // contains metadata shared across all relation types
-        // Log or handle common fields if necessary
-        trace!("Processing relation with plan_id: {:?}", common.plan_id);
-    }
-
-    let rel_type = relation.rel_type.ok_or_else(|| eyre!("rel_type is None"))?;
-
-    match rel_type {
-        RelType::Range(input) => {
-            let stream = range(input, &context).wrap_err("parsing Range")?;
-            Ok(Box::pin(stream))
-        }
-        RelType::Read(read) => {
-            let builder = translation::logical_plan::read(read)?;
-            let logical_plan = builder.logical_plan.build();
-
-            let (tx, rx) =
-                tokio::sync::mpsc::unbounded_channel::<eyre::Result<ExecutePlanResponse>>();
-
-            std::thread::spawn(move || {
-                let physical_plan = match daft_local_plan::translate(&logical_plan) {
-                    Ok(plan) => plan,
-                    Err(e) => {
-                        tx.send(Err(eyre!(e))).unwrap();
-                        return;
-                    }
-                };
-
-                let cfg = DaftExecutionConfig::default();
-                let result = daft_local_execution::run_local(
-                    &physical_plan,
-                    builder.partition,
-                    cfg.into(),
-                    None,
-                )
-                .unwrap();
-
-                for result in result {
-                    let result = match result {
-                        Ok(result) => result,
-                        Err(e) => {
-                            tx.send(Err(eyre!(e))).unwrap();
-                            continue;
-                        }
-                    };
-
-                    let tables = match result.get_tables() {
-                        Ok(tables) => tables,
-                        Err(e) => {
-                            tx.send(Err(eyre!(e))).unwrap();
-                            continue;
-                        }
-                    };
-
-                    for table in tables.as_slice() {
-                        let response = context.gen_response(table);
-
-                        let response = match response {
-                            Ok(response) => response,
-                            Err(e) => {
-                                tx.send(Err(eyre!(e))).unwrap();
-                                continue;
-                            }
-                        };
-
-                        tx.send(Ok(response)).unwrap();
-                    }
-                }
-            });
-
-            let recv_stream = UnboundedReceiverStream::new(rx);
-
-            Ok(Box::pin(recv_stream))
-        }
-        other => Err(eyre!("Unsupported top-level relation: {other:?}")),
-    }
-}
diff --git a/src/daft-connect/src/translation/stream/range.rs b/src/daft-connect/src/translation/stream/range.rs
deleted file mode 100644
index 0c68b10353..0000000000
--- a/src/daft-connect/src/translation/stream/range.rs
+++ /dev/null
@@ -1,47 +0,0 @@
-use std::future::ready;
-
-use daft_core::prelude::Series;
-use daft_schema::prelude::Schema;
-use daft_table::Table;
-use eyre::{ensure, Context};
-use futures::{stream, Stream};
-use spark_connect::{ExecutePlanResponse, Range};
-
-use crate::op::execute::PlanIds;
-
-pub fn range(
-    range: Range,
-    context: &PlanIds,
-) -> eyre::Result<impl Stream<Item = eyre::Result<ExecutePlanResponse>> + Unpin> {
-    let Range {
-        start,
-        end,
-        step,
-        num_partitions,
-    } = range;
-
-    let start = start.unwrap_or(0);
-    ensure!(num_partitions.is_none(), "num_partitions is not supported");
-
-    let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
-    ensure!(step > 0, "step must be greater than 0");
-
-    let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect();
-    let len = arrow_array.len();
-
-    let singleton_series = Series::try_from((
-        "range",
-        Box::new(arrow_array) as Box<dyn arrow2::array::Array>,
-    ))
-    .wrap_err("creating singleton series")?;
-
-    let singleton_table = Table::new_with_size(
-        Schema::new(vec![singleton_series.field().clone()])?,
-        vec![singleton_series],
-        len,
-    )?;
-
-    let response = context.gen_response(&singleton_table)?;
-
-    Ok(stream::once(ready(Ok(response))))
-}
diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs
index 13dfa44c8b..1a7638af8d 100644
--- a/src/daft-logical-plan/src/builder.rs
+++ b/src/daft-logical-plan/src/builder.rs
@@ -144,6 +144,7 @@ impl LogicalPlanBuilder {
         size_bytes: usize,
         num_rows: usize,
     ) -> Self {
+        use crate::InMemoryInfo;
        let source_info = SourceInfo::InMemory(InMemoryInfo::new_not_python(
            schema.clone(),
            partition_key.into(),
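---

Note on the pattern (reviewer illustration, not part of the patch): the core change in `root.rs` bridges Daft's blocking local-execution pipeline to the async gRPC response stream by spawning a dedicated OS thread and forwarding each result through a tokio unbounded channel; the receiver is wrapped in `UnboundedReceiverStream` and a final completion message is chained on. Below is a minimal, self-contained sketch of that pattern, assuming the `tokio` (features `rt-multi-thread`, `macros`), `tokio-stream`, and `futures` crates. The `u64` payload, `String` error type, and `999` completion marker are hypothetical stand-ins for `ExecutePlanResponse`, the `eyre` errors, and `context.finished()`.

```rust
use futures::{stream, StreamExt};
use tokio_stream::wrappers::UnboundedReceiverStream;

#[tokio::main]
async fn main() {
    // Channel carrying Result values from a blocking producer thread to the
    // async consumer, as the patch does with eyre::Result<ExecutePlanResponse>.
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<u64, String>>();

    // Dedicated OS thread, mirroring the patch's std::thread::spawn: the
    // blocking execution pipeline stays off the async runtime and pushes
    // each result into the channel. On the first error, the patch sends
    // Err(..) and returns; here we just stop if the receiver is gone.
    std::thread::spawn(move || {
        for i in 0..3u64 {
            if tx.send(Ok(i)).is_err() {
                return; // receiver dropped; stop producing
            }
        }
    });

    // The receiver becomes a Stream, and a final completion item is chained
    // on, like chaining `context.finished()` after the result batches.
    let responses = UnboundedReceiverStream::new(rx)
        .chain(stream::once(std::future::ready(Ok(999)))); // 999: stand-in completion marker

    let collected: Vec<Result<u64, String>> = responses.collect().await;
    println!("{collected:?}"); // [Ok(0), Ok(1), Ok(2), Ok(999)]
}
```

An unbounded channel keeps the blocking producer from ever needing to await, at the cost of unbounded buffering if the client consumes slowly; that trade-off matches the patch's choice of `unbounded_channel` over a bounded one.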