diff --git a/Cargo.lock b/Cargo.lock index fd22dcfa10..09b6dfd699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1931,6 +1931,7 @@ dependencies = [ "arrow2", "async-stream", "common-daft-config", + "common-file-formats", "daft-core", "daft-dsl", "daft-local-execution", diff --git a/Cargo.toml b/Cargo.toml index 67334d8b0d..4af4918398 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -194,6 +194,7 @@ chrono-tz = "0.10.0" comfy-table = "7.1.1" common-daft-config = {path = "src/common/daft-config"} common-error = {path = "src/common/error", default-features = false} +common-file-formats = {path = "src/common/file-formats"} common-runtime = {path = "src/common/runtime", default-features = false} daft-core = {path = "src/daft-core"} daft-dsl = {path = "src/daft-dsl"} diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 9651106968..7e085df7f5 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -2,6 +2,7 @@ arrow2 = {workspace = true} async-stream = "0.3.6" common-daft-config = {workspace = true} +common-file-formats = {workspace = true} daft-core = {workspace = true} daft-dsl = {workspace = true} daft-local-execution = {workspace = true} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 70171ad0d4..3ba978a72f 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -190,8 +190,9 @@ impl SparkConnectService for DaftSparkConnectService { CommandType::RegisterFunction(_) => { unimplemented_err!("RegisterFunction not implemented") } - CommandType::WriteOperation(_) => { - unimplemented_err!("WriteOperation not implemented") + CommandType::WriteOperation(op) => { + let result = session.handle_write_command(op, operation).await?; + return Ok(Response::new(result)); } CommandType::CreateDataframeView(_) => { unimplemented_err!("CreateDataframeView not implemented") diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs index fba3cc850d..41baf88b09 100644 --- 
a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -11,6 +11,7 @@ use uuid::Uuid;
 
 use crate::{DaftSparkConnectService, Session};
 
 mod root;
+mod write;
 
 pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;
diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
new file mode 100644
index 0000000000..8e21433bd0
--- /dev/null
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -0,0 +1,163 @@
+use std::{collections::HashMap, future::ready};
+
+use common_daft_config::DaftExecutionConfig;
+use common_file_formats::FileFormat;
+use daft_local_execution::NativeExecutor;
+use eyre::{bail, WrapErr};
+use spark_connect::{
+    write_operation::{SaveMode, SaveType},
+    WriteOperation,
+};
+use tonic::Status;
+use tracing::warn;
+
+use crate::{
+    op::execute::{ExecuteStream, PlanIds},
+    session::Session,
+    translation,
+};
+
+impl Session {
+    /// Executes a Spark Connect `WriteOperation` and reports its completion.
+    ///
+    /// Only writing to a filesystem path in parquet format is supported. Sort columns,
+    /// partitioning columns, bucketing, options, clustering columns and any save mode other
+    /// than the unspecified one are accepted but ignored with a warning.
+    ///
+    /// The returned stream yields a single item: the response built by `PlanIds::finished`
+    /// once the write plan has run to completion, or a [`Status::internal`] describing why
+    /// the write could not be performed.
+    ///
+    /// # Errors
+    ///
+    /// This function itself always returns `Ok`; validation and execution failures are
+    /// delivered as the `Err` item of the returned stream.
+    pub async fn handle_write_command(
+        &self,
+        operation: WriteOperation,
+        operation_id: String,
+    ) -> Result<ExecuteStream, Status> {
+        use futures::{StreamExt, TryFutureExt};
+
+        let context = PlanIds {
+            session: self.client_side_session_id().to_string(),
+            server_side_session: self.server_side_session_id().to_string(),
+            operation: operation_id,
+        };
+
+        let finished = context.finished();
+
+        let result = async move {
+            let WriteOperation {
+                input,
+                source,
+                mode,
+                sort_column_names,
+                partitioning_columns,
+                bucket_by,
+                options,
+                clustering_columns,
+                save_type,
+            } = operation;
+
+            let Some(input) = input else {
+                bail!("Input is required");
+            };
+
+            let Some(source) = source else {
+                bail!("Source is required");
+            };
+
+            if source != "parquet" {
+                bail!("Unsupported source: {source}; only parquet is supported");
+            }
+
+            let Ok(mode) = SaveMode::try_from(mode) else {
+                bail!("Invalid save mode: {mode}");
+            };
+
+            if !sort_column_names.is_empty() {
+                // todo(completeness): implement sort
+                warn!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)");
+            }
+
+            if !partitioning_columns.is_empty() {
+                // todo(completeness): implement partitioning
+                warn!(
+                    "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)"
+                );
+            }
+
+            if let Some(bucket_by) = bucket_by {
+                // todo(completeness): implement bucketing
+                warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)");
+            }
+
+            if !options.is_empty() {
+                // todo(completeness): implement options
+                warn!("Ignoring options: {options:?} (not yet implemented)");
+            }
+
+            if !clustering_columns.is_empty() {
+                // todo(completeness): implement clustering
+                warn!("Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)");
+            }
+
+            // The match stays exhaustive so that a newly added save mode forces a decision here.
+            match mode {
+                SaveMode::Unspecified => {}
+                SaveMode::Append
+                | SaveMode::Overwrite
+                | SaveMode::ErrorIfExists
+                | SaveMode::Ignore => {
+                    // todo(completeness): implement save modes
+                    warn!("Ignoring save mode: {mode:?} (not yet implemented)");
+                }
+            }
+
+            let Some(save_type) = save_type else {
+                bail!("Save type is required");
+            };
+
+            let path = match save_type {
+                SaveType::Path(path) => path,
+                SaveType::Table(table) => {
+                    let name = table.table_name;
+                    bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead.");
+                }
+            };
+
+            let plan = translation::to_logical_plan(input)?;
+
+            let plan = plan
+                .table_write(&path, FileFormat::Parquet, None, None, None)
+                .wrap_err("Failed to create table write plan")?;
+
+            let optimized_plan = plan.optimize()?;
+            let cfg = DaftExecutionConfig::default();
+            let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
+            let mut result_stream = native_executor
+                .run(HashMap::new(), cfg.into(), None)?
+                .into_stream();
+
+            // Drain the stream so the write has actually finished before completion is
+            // reported; e.g. a client that writes a parquet file and reads it back
+            // immediately afterwards must see the finished file.
+            //
+            // Every item is checked so that a failure during execution fails the
+            // operation instead of being reported as a successful write.
+            while let Some(result) = result_stream.next().await {
+                result.wrap_err("Failed to execute write plan")?;
+            }
+
+            Ok(())
+        };
+
+        let result = result.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")));
+
+        let future = result.and_then(|_| ready(Ok(finished)));
+        let stream = futures::stream::once(future);
+
+        Ok(Box::pin(stream))
+    }
+}
diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 93c9e9bd4a..4fdfa4e00e 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -3,11 +3,14 @@ use eyre::{bail, Context};
 use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
-use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
+use crate::translation::logical_plan::{
+    aggregate::aggregate, project::project, range::range, read::read,
+};
 
 mod aggregate;
 mod project;
 mod range;
+mod read;
 
 pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
     if let Some(common) = relation.common {
@@ -25,6 +28,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
         RelType::Aggregate(a) => {
             aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
         }
+        RelType::Read(r) => read(r).wrap_err("Failed to apply table read to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs
new file mode 100644
index 0000000000..199d77da4b
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read.rs
@@ -0,0 +1,42 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use eyre::{bail, WrapErr};
+use spark_connect::read::ReadType;
+use tracing::warn;
+
+mod data_source;
+
+/// Translates a Spark Connect `Read` relation into a logical plan builder.
+///
+/// Only data-source reads are supported; reading from a named table fails with an error.
+/// The `is_streaming` flag is not honored: a read that sets it is planned like any other
+/// read, and a warning is logged.
+///
+/// # Errors
+///
+/// Returns an error if the read type is missing, if it refers to a named table, or if the
+/// data source cannot be translated.
+pub fn read(read: spark_connect::Read) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::Read {
+        is_streaming,
+        read_type,
+    } = read;
+
+    if is_streaming {
+        // todo(completeness): implement streaming reads
+        warn!("Ignoring is_streaming: {is_streaming}");
+    }
+
+    let Some(read_type) = read_type else {
+        bail!("Read type is required");
+    };
+
+    match read_type {
+        ReadType::NamedTable(table) => {
+            let name = table.unparsed_identifier;
+            bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead.");
+        }
+        ReadType::DataSource(source) => {
+            data_source::data_source(source).wrap_err("Failed to create data source")
+        }
+    }
+}
diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
new file mode 100644
index 0000000000..25642e35ee
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -0,0 +1,51 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use daft_scan::builder::ParquetScanBuilder;
+use eyre::{bail, ensure, WrapErr};
+use tracing::warn;
+
+/// Translates a Spark Connect `DataSource` read into a parquet scan.
+///
+/// `format` must be `"parquet"` and at least one path is required. A provided schema,
+/// options and predicates are not applied; each is ignored with a warning.
+///
+/// # Errors
+///
+/// Returns an error if the format is missing or is not `"parquet"`, if `paths` is empty, or
+/// if the parquet scan cannot be built.
+pub fn data_source(
+    data_source: spark_connect::read::DataSource,
+) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::read::DataSource {
+        format,
+        schema,
+        options,
+        paths,
+        predicates,
+    } = data_source;
+
+    let Some(format) = format else {
+        bail!("Format is required");
+    };
+
+    if format != "parquet" {
+        bail!("Unsupported format: {format}; only parquet is supported");
+    }
+
+    ensure!(!paths.is_empty(), "Paths are required");
+
+    if let Some(schema) = schema {
+        warn!("Ignoring schema: {schema:?}; not yet implemented");
+    }
+
+    if !options.is_empty() {
+        warn!("Ignoring options: {options:?}; not yet implemented");
+    }
+
+    if !predicates.is_empty() {
+        warn!("Ignoring predicates: {predicates:?}; not yet implemented");
+    }
+
+    ParquetScanBuilder::new(paths)
+        .finish()
+        .wrap_err("Failed to create parquet scan builder")
+}
diff --git
a/tests/connect/test_parquet.py b/tests/connect/test_parquet.py
new file mode 100644
index 0000000000..b356254fdf
--- /dev/null
+++ b/tests/connect/test_parquet.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import os
+import tempfile
+
+
+def test_write_parquet(spark_session):
+    """Round-trip ``range(10)`` through a parquet directory and check the data survives."""
+    # The context manager removes the directory even if an assertion below fails.
+    with tempfile.TemporaryDirectory() as temp_dir:
+        df = spark_session.range(10)
+
+        # Write DataFrame to parquet directory
+        parquet_dir = os.path.join(temp_dir, "test.parquet")
+        df.write.parquet(parquet_dir)
+
+        parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith(".parquet")]
+        assert len(parquet_files) > 0, "Expected at least one parquet file to be written"
+
+        # Read back from the parquet directory (not a specific file)
+        df_read = spark_session.read.parquet(parquet_dir)
+
+        # Compare after sorting so the check does not depend on the order
+        # in which rows are read back from the directory.
+        expected_ids = df.toPandas()["id"].sort_values(ignore_index=True)
+        actual_ids = df_read.toPandas()["id"].sort_values(ignore_index=True)
+        assert expected_ids.equals(actual_ids), "Data should be unchanged after write/read"