Skip to content

Commit

Permalink
[FEAT] connect: add parquet support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent 4f49d30 commit a3217f3
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ chrono-tz = "0.8.4"
comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-error = {path = "src/common/error", default-features = false}
common-file-formats = {path = "src/common/file-formats"}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
daft-hash = {path = "src/daft-hash"}
Expand Down
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
arrow2 = {workspace = true}
async-stream = "0.3.6"
common-daft-config = {workspace = true}
common-file-formats = {workspace = true}
daft-core = {workspace = true}
daft-dsl = {workspace = true}
daft-local-execution = {workspace = true}
Expand Down
5 changes: 3 additions & 2 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ impl SparkConnectService for DaftSparkConnectService {
CommandType::RegisterFunction(_) => {
unimplemented_err!("RegisterFunction not implemented")
}
CommandType::WriteOperation(_) => {
unimplemented_err!("WriteOperation not implemented")
CommandType::WriteOperation(op) => {
let result = session.handle_write_command(op, operation).await?;
return Ok(Response::new(result))
}
CommandType::CreateDataframeView(_) => {
unimplemented_err!("CreateDataframeView not implemented")
Expand Down
1 change: 1 addition & 0 deletions src/daft-connect/src/op/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use uuid::Uuid;
use crate::{DaftSparkConnectService, Session};

mod root;
mod write;

pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

Expand Down
206 changes: 206 additions & 0 deletions src/daft-connect/src/op/execute/write.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use std::{collections::HashMap, future::ready};

use common_daft_config::DaftExecutionConfig;
use common_file_formats::FileFormat;
use eyre::{bail, WrapErr};
use futures::stream;
use spark_connect::{
write_operation::{SaveMode, SaveType},
ExecutePlanResponse, Relation, WriteOperation,
};
use tokio_util::sync::CancellationToken;
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
use tracing::warn;

use crate::{
invalid_argument_err,
op::execute::{ExecuteStream, PlanIds},
session::Session,
translation,
};

impl Session {
pub async fn handle_write_command(
&self,
operation: WriteOperation,
operation_id: String,
) -> Result<ExecuteStream, Status> {
use futures::{StreamExt, TryStreamExt};

let context = PlanIds {
session: self.client_side_session_id().to_string(),
server_side_session: self.server_side_session_id().to_string(),
operation: operation_id,
};

let finished = context.finished();

// operation: WriteOperation {
// input: Some(
// Relation {
// common: Some(
// RelationCommon {
// source_info: "",
// plan_id: Some(
// 0,
// ),
// origin: None,
// },
// ),
// rel_type: Some(
// Range(
// Range {
// start: Some(
// 0,
// ),
// end: 10,
// step: 1,
// num_partitions: None,
// },
// ),
// ),
// },
// ),
// source: Some(
// "parquet",
// ),
// mode: Unspecified,
// sort_column_names: [],
// partitioning_columns: [],
// bucket_by: None,
// options: {},
// clustering_columns: [],
// save_type: Some(
// Path(
// "/var/folders/zy/g1zccty96bg_frmz9x0198zh0000gn/T/tmpxki7yyr0/test.parquet",
// ),
// ),
// }

let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(16);
std::thread::spawn(move || {
let result = (|| -> eyre::Result<()> {
let WriteOperation {
input,
source,
mode,
sort_column_names,
partitioning_columns,
bucket_by,
options,
clustering_columns,
save_type,
} = operation;

let Some(input) = input else {
bail!("Input is required");
};

let Some(source) = source else {
bail!("Source is required");
};

if source != "parquet" {
bail!("Unsupported source: {source}; only parquet is supported");
}

let Ok(mode) = SaveMode::try_from(mode) else {
bail!("Invalid save mode: {mode}");
};

if !sort_column_names.is_empty() {
// todo(completeness): implement sort
warn!(
"Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)"
);
}

if !partitioning_columns.is_empty() {
// todo(completeness): implement partitioning
warn!("Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)");
}

if let Some(bucket_by) = bucket_by {
// todo(completeness): implement bucketing
warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)");
}

if !options.is_empty() {
// todo(completeness): implement options
warn!("Ignoring options: {options:?} (not yet implemented)");
}

if !clustering_columns.is_empty() {
// todo(completeness): implement clustering
warn!(
"Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)"
);
}

match mode {
SaveMode::Unspecified => {}
SaveMode::Append => {}
SaveMode::Overwrite => {}
SaveMode::ErrorIfExists => {}
SaveMode::Ignore => {}
}

let Some(save_type) = save_type else {
return bail!("Save type is required");
};

let path = match save_type {
SaveType::Path(path) => path,
SaveType::Table(table) => {
let name = table.table_name;
bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead.");
}
};

let plan = translation::to_logical_plan(input)?;

let plan = plan
.table_write(&path, FileFormat::Parquet, None, None, None)
.wrap_err("Failed to create table write plan")?;

let logical_plan = plan.build();
let physical_plan = daft_local_plan::translate(&logical_plan)?;

let cfg = DaftExecutionConfig::default();

// "hot" flow not a "cold" flow
let iterator = daft_local_execution::run_local(
&physical_plan,
HashMap::new(),
cfg.into(),
None,
CancellationToken::new(), // todo: maybe implement cancelling
)?;

for _ignored in iterator {

}

// this is so we make sure the operation is actually done
// before we return
//
// an example where this is important is if we write to a parquet file
// and then read immediately after, we need to wait for the write to finish

Ok(())
})();

if let Err(e) = result {
tx.blocking_send(Err(e)).unwrap();
}
});

let stream = ReceiverStream::new(rx);

let stream = stream
.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")))
.chain(stream::once(ready(Ok(finished))));

Ok(Box::pin(stream))
}
}
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, read::read};

mod aggregate;
mod project;
mod read;
mod range;

pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
Expand All @@ -24,6 +25,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::Read(r) => read(r).wrap_err("Failed to apply table read to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
29 changes: 29 additions & 0 deletions src/daft-connect/src/translation/logical_plan/read.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
use spark_connect::read::ReadType;
use tracing::warn;

mod data_source;

pub fn read(read: spark_connect::Read) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::Read {
is_streaming,
read_type,
} = read;

warn!("Ignoring is_streaming: {is_streaming}");

let Some(read_type) = read_type else {
bail!("Read type is required");
};

match read_type {
ReadType::NamedTable(table) => {
let name = table.unparsed_identifier;
bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead.");
}
ReadType::DataSource(source) => {
data_source::data_source(source).wrap_err("Failed to create data source")
}
}
}
42 changes: 42 additions & 0 deletions src/daft-connect/src/translation/logical_plan/read/data_source.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use daft_logical_plan::LogicalPlanBuilder;
use daft_scan::builder::ParquetScanBuilder;
use eyre::{bail, ensure, WrapErr};
use tracing::warn;

pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::read::DataSource {
format,
schema,
options,
paths,
predicates,
} = data_source;

let Some(format) = format else {
bail!("Format is required");
};

if format != "parquet" {
bail!("Unsupported format: {format}; only parquet is supported");
}

ensure!(!paths.is_empty(), "Paths are required");

if let Some(schema) = schema {
warn!("Ignoring schema: {schema:?}; not yet implemented");
}

if !options.is_empty() {
warn!("Ignoring options: {options:?}; not yet implemented");
}

if !predicates.is_empty() {
warn!("Ignoring predicates: {predicates:?}; not yet implemented");
}

let builder = ParquetScanBuilder::new(paths)
.finish()
.wrap_err("Failed to create parquet scan builder")?;

Ok(builder)
}
36 changes: 36 additions & 0 deletions tests/connect/test_parquet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import tempfile
import shutil
import os


def test_write_parquet(spark_session):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Create DataFrame from range(10)
df = spark_session.range(10)

# Write DataFrame to parquet directory
parquet_dir = os.path.join(temp_dir, "test.parquet")
df.write.parquet(parquet_dir)

# List all files in the parquet directory
parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith('.parquet')]
print(f"Parquet files in directory: {parquet_files}")

# Assert there is at least one parquet file
assert len(parquet_files) > 0, "Expected at least one parquet file to be written"

# Read back from the parquet directory (not specific file)
df_read = spark_session.read.parquet(parquet_dir)

# Verify the data is unchanged
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"

finally:
# Clean up temp directory
shutil.rmtree(temp_dir)

0 comments on commit a3217f3

Please sign in to comment.