diff --git a/Cargo.lock b/Cargo.lock
index 33c6acd968..bd5e776001 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1932,11 +1932,14 @@ name = "daft-connect"
 version = "0.3.0-dev0"
 dependencies = [
  "arrow2",
+ "async-stream",
  "common-daft-config",
+ "daft-core",
+ "daft-dsl",
  "daft-local-execution",
- "daft-local-plan",
  "daft-logical-plan",
  "daft-scan",
+ "daft-schema",
  "daft-table",
  "dashmap",
  "eyre",
@@ -1944,7 +1947,6 @@ dependencies = [
  "pyo3",
  "spark-connect",
  "tokio",
- "tokio-util",
  "tonic",
  "tracing",
  "uuid 1.10.0",
diff --git a/Cargo.toml b/Cargo.toml
index 79f933dad9..be1146166a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -194,11 +194,14 @@ chrono-tz = "0.8.4"
 comfy-table = "7.1.1"
 common-daft-config = {path = "src/common/daft-config"}
 common-error = {path = "src/common/error", default-features = false}
+daft-core = {path = "src/daft-core"}
+daft-dsl = {path = "src/daft-dsl"}
 daft-hash = {path = "src/daft-hash"}
 daft-local-execution = {path = "src/daft-local-execution"}
 daft-local-plan = {path = "src/daft-local-plan"}
 daft-logical-plan = {path = "src/daft-logical-plan"}
 daft-scan = {path = "src/daft-scan"}
+daft-schema = {path = "src/daft-schema"}
 daft-table = {path = "src/daft-table"}
 derivative = "2.2.0"
 derive_builder = "0.20.2"
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 368af56011..b821f58115 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1208,10 +1208,11 @@ def sql_expr(sql: str) -> PyExpr: ...
 def list_sql_functions() -> list[SQLFunctionStub]: ...
 def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
 def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
-def connect_start(addr: str) -> ConnectionHandle: ...
+def connect_start(addr: str = "sc://0.0.0.0:0") -> ConnectionHandle: ...

 class ConnectionHandle:
     def shutdown(self) -> None: ...
+    def port(self) -> int: ...

 # expr numeric ops
 def abs(expr: PyExpr) -> PyExpr: ...
diff --git a/src/common/py-serde/src/python.rs b/src/common/py-serde/src/python.rs
index e634743f2d..10e467e931 100644
--- a/src/common/py-serde/src/python.rs
+++ b/src/common/py-serde/src/python.rs
@@ -49,12 +49,19 @@ impl<'de> Visitor<'de> for PyObjectVisitor {
     where
         E: DeError,
     {
-        Python::with_gil(|py| {
-            py.import_bound(pyo3::intern!(py, "daft.pickle"))
-                .and_then(|m| m.getattr(pyo3::intern!(py, "loads")))
-                .and_then(|f| Ok(f.call1((v,))?.into()))
-                .map_err(|e| DeError::custom(e.to_string()))
-        })
+        self.visit_bytes(&v)
+    }
+
+    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+    where
+        A: serde::de::SeqAccess<'de>,
+    {
+        let mut v: Vec<u8> = Vec::with_capacity(seq.size_hint().unwrap_or_default());
+        while let Some(elem) = seq.next_element()?
{ + v.push(elem); + } + + self.visit_bytes(&v) } } diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 7955bdacd4..9651106968 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,10 +1,13 @@ [dependencies] arrow2 = {workspace = true} +async-stream = "0.3.6" common-daft-config = {workspace = true} +daft-core = {workspace = true} +daft-dsl = {workspace = true} daft-local-execution = {workspace = true} -daft-local-plan = {workspace = true} daft-logical-plan = {workspace = true} daft-scan = {workspace = true} +daft-schema = {workspace = true} daft-table = {workspace = true} dashmap = "6.1.0" eyre = "0.6.12" @@ -12,13 +15,12 @@ futures = "0.3.31" pyo3 = {workspace = true, optional = true} spark-connect = {workspace = true} tokio = {version = "1.40.0", features = ["full"]} -tokio-util = {workspace = true} tonic = "0.12.3" tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} [features] -python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python"] +python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python"] [lints] workspace = true diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 95cd9ce75a..70171ad0d4 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -22,7 +22,7 @@ use spark_connect::{ ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, }; use tonic::{transport::Server, Request, Response, Status}; -use tracing::info; +use tracing::{debug, info}; use uuid::Uuid; use crate::session::Session; @@ -37,6 +37,7 @@ pub mod util; #[cfg_attr(feature = "python", pyo3::pyclass)] pub struct ConnectionHandle { shutdown_signal: Option>, + port: u16, } #[cfg_attr(feature = "python", pyo3::pymethods)] @@ -47,12 +48,19 @@ impl ConnectionHandle { }; shutdown_signal.send(()).unwrap(); } + + pub fn port(&self) -> u16 { + self.port + } } pub fn start(addr: &str) -> eyre::Result { info!("Daft-Connect server listening on {addr}"); let addr = util::parse_spark_connect_address(addr)?; + let listener = std::net::TcpListener::bind(addr)?; + let port = listener.local_addr()?.port(); + let service = DaftSparkConnectService::default(); info!("Daft-Connect server listening on {addr}"); @@ -61,25 +69,40 @@ pub fn start(addr: &str) -> eyre::Result { let handle = ConnectionHandle { shutdown_signal: Some(shutdown_signal), + port, }; std::thread::spawn(move || { let runtime = tokio::runtime::Runtime::new().unwrap(); - let result = runtime - .block_on(async { - tokio::select! { - result = Server::builder() - .add_service(SparkConnectServiceServer::new(service)) - .serve(addr) => { - result - } - _ = shutdown_receiver => { - info!("Received shutdown signal"); - Ok(()) + let result = runtime.block_on(async { + let incoming = { + let listener = tokio::net::TcpListener::from_std(listener) + .wrap_err("Failed to create TcpListener from std::net::TcpListener")?; + + async_stream::stream! { + loop { + match listener.accept().await { + Ok((stream, _)) => yield Ok(stream), + Err(e) => yield Err(e), + } } } - }) - .wrap_err_with(|| format!("Failed to start server on {addr}")); + }; + + let result = tokio::select! 
{ + result = Server::builder() + .add_service(SparkConnectServiceServer::new(service)) + .serve_with_incoming(incoming)=> { + result + } + _ = shutdown_receiver => { + info!("Received shutdown signal"); + Ok(()) + } + }; + + result.wrap_err_with(|| format!("Failed to start server on {addr}")) + }); if let Err(e) = result { eprintln!("Daft-Connect server error: {e:?}"); @@ -286,22 +309,22 @@ impl SparkConnectService for DaftSparkConnectService { Ok(schema) => schema, Err(e) => { return invalid_argument_err!( - "Failed to translate relation to schema: {e}" + "Failed to translate relation to schema: {e:?}" ); } }; - let schema = analyze_plan_response::DdlParse { - parsed: Some(result), + let schema = analyze_plan_response::Schema { + schema: Some(result), }; let response = AnalyzePlanResponse { session_id, server_side_session_id: String::new(), - result: Some(analyze_plan_response::Result::DdlParse(schema)), + result: Some(analyze_plan_response::Result::Schema(schema)), }; - println!("response: {response:#?}"); + debug!("response: {response:#?}"); Ok(Response::new(response)) } @@ -363,7 +386,7 @@ impl SparkConnectService for DaftSparkConnectService { #[cfg(feature = "python")] #[pyo3::pyfunction] -#[pyo3(name = "connect_start")] +#[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))] pub fn py_connect_start(addr: &str) -> pyo3::PyResult { start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}"))) } diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index e0491cd17e..1e1fac147b 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,9 +1,9 @@ use std::{collections::HashMap, future::ready}; use common_daft_config::DaftExecutionConfig; +use daft_local_execution::NativeExecutor; use futures::stream; use spark_connect::{ExecutePlanResponse, Relation}; -use tokio_util::sync::CancellationToken; use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; use crate::{ @@ -28,37 +28,32 @@ impl Session { let finished = context.finished(); - let (tx, rx) = tokio::sync::mpsc::channel::>(16); - std::thread::spawn(move || { - let result = (|| -> eyre::Result<()> { + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + tokio::spawn(async move { + let execution_fut = async { let plan = translation::to_logical_plan(command)?; - let logical_plan = plan.build(); - // TODO(desmond): It looks like we don't currently do optimizer passes here before translation. - let physical_plan = daft_local_plan::translate(&logical_plan)?; - + let optimized_plan = plan.optimize()?; let cfg = DaftExecutionConfig::default(); - let results = daft_local_execution::run_local( - &physical_plan, - HashMap::new(), - cfg.into(), - None, - CancellationToken::new(), // todo: maybe implement cancelling - )?; + let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; + let mut result_stream = native_executor + .run(HashMap::new(), cfg.into(), None)? 
+ .into_stream(); - for result in results { + while let Some(result) = result_stream.next().await { let result = result?; let tables = result.get_tables()?; - for table in tables.as_slice() { let response = context.gen_response(table)?; - tx.blocking_send(Ok(response)).unwrap(); + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } } } Ok(()) - })(); + }; - if let Err(e) = result { - tx.blocking_send(Err(e)).unwrap(); + if let Err(e) = execution_fut.await { + let _ = tx.send(Err(e)).await; } }); diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index 125aa6e884..bb2d73b507 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -1,7 +1,13 @@ //! Translation between Spark Connect and Daft +mod datatype; +mod expr; +mod literal; mod logical_plan; mod schema; +pub use datatype::to_spark_datatype; +pub use expr::to_daft_expr; +pub use literal::to_daft_literal; pub use logical_plan::to_logical_plan; pub use schema::relation_to_schema; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs new file mode 100644 index 0000000000..9a40844464 --- /dev/null +++ b/src/daft-connect/src/translation/datatype.rs @@ -0,0 +1,114 @@ +use daft_schema::dtype::DataType; +use spark_connect::data_type::Kind; +use tracing::warn; + +pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { + match datatype { + DataType::Null => spark_connect::DataType { + kind: Some(Kind::Null(spark_connect::data_type::Null { + type_variation_reference: 0, + })), + }, + DataType::Boolean => spark_connect::DataType { + kind: Some(Kind::Boolean(spark_connect::data_type::Boolean { + type_variation_reference: 0, + })), + }, + DataType::Int8 => spark_connect::DataType { + kind: Some(Kind::Byte(spark_connect::data_type::Byte { + type_variation_reference: 0, + })), + }, + DataType::Int16 => spark_connect::DataType { + kind: Some(Kind::Short(spark_connect::data_type::Short { + type_variation_reference: 0, + })), + }, + DataType::Int32 => spark_connect::DataType { + kind: Some(Kind::Integer(spark_connect::data_type::Integer { + type_variation_reference: 0, + })), + }, + DataType::Int64 => spark_connect::DataType { + kind: Some(Kind::Long(spark_connect::data_type::Long { + type_variation_reference: 0, + })), + }, + DataType::UInt8 => spark_connect::DataType { + kind: Some(Kind::Byte(spark_connect::data_type::Byte { + type_variation_reference: 0, + })), + }, + DataType::UInt16 => spark_connect::DataType { + kind: Some(Kind::Short(spark_connect::data_type::Short { + type_variation_reference: 0, + })), + }, + DataType::UInt32 => spark_connect::DataType { + kind: Some(Kind::Integer(spark_connect::data_type::Integer { + type_variation_reference: 0, + })), + }, + DataType::UInt64 => spark_connect::DataType { + kind: Some(Kind::Long(spark_connect::data_type::Long { + type_variation_reference: 0, + })), + }, + DataType::Float32 => spark_connect::DataType { + kind: Some(Kind::Float(spark_connect::data_type::Float { + type_variation_reference: 0, + })), + }, + DataType::Float64 => spark_connect::DataType { + kind: Some(Kind::Double(spark_connect::data_type::Double { + type_variation_reference: 0, + })), + }, + DataType::Decimal128(precision, scale) => spark_connect::DataType { + kind: Some(Kind::Decimal(spark_connect::data_type::Decimal { + scale: Some(*scale as i32), + precision: Some(*precision as i32), + type_variation_reference: 0, + })), + }, + DataType::Timestamp(unit, _) => { + warn!("Ignoring 
time unit {unit:?} for timestamp type"); + spark_connect::DataType { + kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp { + type_variation_reference: 0, + })), + } + } + DataType::Date => spark_connect::DataType { + kind: Some(Kind::Date(spark_connect::data_type::Date { + type_variation_reference: 0, + })), + }, + DataType::Binary => spark_connect::DataType { + kind: Some(Kind::Binary(spark_connect::data_type::Binary { + type_variation_reference: 0, + })), + }, + DataType::Utf8 => spark_connect::DataType { + kind: Some(Kind::String(spark_connect::data_type::String { + type_variation_reference: 0, + collation: String::new(), // todo(correctness): is this correct? + })), + }, + DataType::Struct(fields) => spark_connect::DataType { + kind: Some(Kind::Struct(spark_connect::data_type::Struct { + fields: fields + .iter() + .map(|f| spark_connect::data_type::StructField { + name: f.name.clone(), + data_type: Some(to_spark_datatype(&f.dtype)), + nullable: true, // todo(correctness): is this correct? + metadata: None, + }) + .collect(), + type_variation_reference: 0, + })), + }, + _ => unimplemented!("Unsupported datatype: {datatype:?}"), + } +} diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs new file mode 100644 index 0000000000..bcbadf9737 --- /dev/null +++ b/src/daft-connect/src/translation/expr.rs @@ -0,0 +1,105 @@ +use std::sync::Arc; + +use eyre::{bail, Context}; +use spark_connect::{expression as spark_expr, Expression}; +use tracing::warn; +use unresolved_function::unresolved_to_daft_expr; + +use crate::translation::to_daft_literal; + +mod unresolved_function; + +pub fn to_daft_expr(expression: &Expression) -> eyre::Result { + if let Some(common) = &expression.common { + warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + }; + + let Some(expr) = &expression.expr_type else { + bail!("Expression is required"); + }; + + match expr { + spark_expr::ExprType::Literal(l) => to_daft_literal(l), + spark_expr::ExprType::UnresolvedAttribute(attr) => { + let spark_expr::UnresolvedAttribute { + unparsed_identifier, + plan_id, + is_metadata_column, + } = attr; + + if let Some(plan_id) = plan_id { + warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); + } + + if let Some(is_metadata_column) = is_metadata_column { + warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); + } + + Ok(daft_dsl::col(unparsed_identifier.as_str())) + } + spark_expr::ExprType::UnresolvedFunction(f) => { + unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function") + } + spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"), + spark_expr::ExprType::UnresolvedStar(_) => { + bail!("Unresolved star expressions not yet supported") + } + spark_expr::ExprType::Alias(alias) => { + let spark_expr::Alias { + expr, + name, + metadata, + } = &**alias; + + let Some(expr) = expr else { + bail!("Alias expr is required"); + }; + + let [name] = name.as_slice() else { + bail!("Alias name is required and currently only works with a single string; got {name:?}"); + }; + + if let Some(metadata) = metadata { + bail!("Alias metadata is not yet supported; got {metadata:?}"); + } + + let child = to_daft_expr(expr)?; + + let name = Arc::from(name.as_str()); + + Ok(child.alias(name)) + } + spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"), + spark_expr::ExprType::UnresolvedRegex(_) => { + bail!("Unresolved 
regex expressions not yet supported") + } + spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"), + spark_expr::ExprType::LambdaFunction(_) => { + bail!("Lambda function expressions not yet supported") + } + spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"), + spark_expr::ExprType::UnresolvedExtractValue(_) => { + bail!("Unresolved extract value expressions not yet supported") + } + spark_expr::ExprType::UpdateFields(_) => { + bail!("Update fields expressions not yet supported") + } + spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => { + bail!("Unresolved named lambda variable expressions not yet supported") + } + spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => { + bail!("Common inline user defined function expressions not yet supported") + } + spark_expr::ExprType::CallFunction(_) => { + bail!("Call function expressions not yet supported") + } + spark_expr::ExprType::NamedArgumentExpression(_) => { + bail!("Named argument expressions not yet supported") + } + spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"), + spark_expr::ExprType::TypedAggregateExpression(_) => { + bail!("Typed aggregate expressions not yet supported") + } + spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"), + } +} diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs new file mode 100644 index 0000000000..ffb8c802ce --- /dev/null +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -0,0 +1,44 @@ +use daft_core::count_mode::CountMode; +use eyre::{bail, Context}; +use spark_connect::expression::UnresolvedFunction; + +use crate::translation::to_daft_expr; + +pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result { + let UnresolvedFunction { + function_name, + arguments, + is_distinct, + is_user_defined_function, + } = f; + + let arguments: Vec<_> = arguments.iter().map(to_daft_expr).try_collect()?; + + if *is_distinct { + bail!("Distinct not yet supported"); + } + + if *is_user_defined_function { + bail!("User-defined functions not yet supported"); + } + + match function_name.as_str() { + "count" => handle_count(arguments).wrap_err("Failed to handle count function"), + n => bail!("Unresolved function {n} not yet supported"), + } +} + +pub fn handle_count(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + let count = arg.count(CountMode::All); + + Ok(count) +} diff --git a/src/daft-connect/src/translation/literal.rs b/src/daft-connect/src/translation/literal.rs new file mode 100644 index 0000000000..f6a26db84a --- /dev/null +++ b/src/daft-connect/src/translation/literal.rs @@ -0,0 +1,52 @@ +use daft_core::datatypes::IntervalValue; +use eyre::bail; +use spark_connect::expression::{literal::LiteralType, Literal}; + +// todo(test): add tests for this esp in Python +pub fn to_daft_literal(literal: &Literal) -> eyre::Result { + let Some(literal) = &literal.literal_type else { + bail!("Literal is required"); + }; + + match literal { + LiteralType::Array(_) => bail!("Array literals not yet supported"), + LiteralType::Binary(bytes) => Ok(daft_dsl::lit(bytes.as_slice())), + LiteralType::Boolean(b) => Ok(daft_dsl::lit(*b)), + LiteralType::Byte(_) => 
bail!("Byte literals not yet supported"), + LiteralType::CalendarInterval(_) => { + bail!("Calendar interval literals not yet supported") + } + LiteralType::Date(d) => Ok(daft_dsl::lit(*d)), + LiteralType::DayTimeInterval(_) => { + bail!("Day-time interval literals not yet supported") + } + LiteralType::Decimal(_) => bail!("Decimal literals not yet supported"), + LiteralType::Double(d) => Ok(daft_dsl::lit(*d)), + LiteralType::Float(f) => { + let f = f64::from(*f); + Ok(daft_dsl::lit(f)) + } + LiteralType::Integer(i) => Ok(daft_dsl::lit(*i)), + LiteralType::Long(l) => Ok(daft_dsl::lit(*l)), + LiteralType::Map(_) => bail!("Map literals not yet supported"), + LiteralType::Null(_) => { + // todo(correctness): is it ok to assume type is i32 here? + Ok(daft_dsl::null_lit()) + } + LiteralType::Short(_) => bail!("Short literals not yet supported"), + LiteralType::String(s) => Ok(daft_dsl::lit(s.as_str())), + LiteralType::Struct(_) => bail!("Struct literals not yet supported"), + LiteralType::Timestamp(ts) => { + // todo(correctness): is it ok that the type is different logically? + Ok(daft_dsl::lit(*ts)) + } + LiteralType::TimestampNtz(ts) => { + // todo(correctness): is it ok that the type is different logically? + Ok(daft_dsl::lit(*ts)) + } + LiteralType::YearMonthInterval(value) => { + let interval = IntervalValue::new(*value, 0, 0); + Ok(daft_dsl::lit(interval)) + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 58255e2ef9..947e0cd0d3 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,8 +1,14 @@ use daft_logical_plan::LogicalPlanBuilder; -use eyre::{bail, ensure, Context}; -use spark_connect::{relation::RelType, Range, Relation}; +use eyre::{bail, Context}; +use spark_connect::{relation::RelType, Relation}; use tracing::warn; +use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; + +mod aggregate; +mod project; +mod range; + pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); @@ -14,55 +20,10 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { match rel_type { RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), + RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"), + RelType::Aggregate(a) => { + aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") + } plan => bail!("Unsupported relation type: {plan:?}"), } } - -fn range(range: Range) -> eyre::Result { - #[cfg(not(feature = "python"))] - bail!("Range operations require Python feature to be enabled"); - - #[cfg(feature = "python")] - { - use daft_scan::python::pylib::ScanOperatorHandle; - use pyo3::prelude::*; - let Range { - start, - end, - step, - num_partitions, - } = range; - - let partitions = num_partitions.unwrap_or(1); - - ensure!(partitions > 0, "num_partitions must be greater than 0"); - - let start = start.unwrap_or(0); - - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); - - let plan = Python::with_gil(|py| { - let range_module = PyModule::import_bound(py, "daft.io._range") - .wrap_err("Failed to import range module")?; - - let range = range_module - .getattr(pyo3::intern!(py, "RangeScanOperator")) - .wrap_err("Failed to get range function")?; - - 
let range = range - .call1((start, end, step, partitions)) - .wrap_err("Failed to create range scan operator")? - .to_object(py); - - let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; - - let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; - - eyre::Result::<_>::Ok(plan) - }) - .wrap_err("Failed to create range scan")?; - - Ok(plan) - } -} diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs new file mode 100644 index 0000000000..193ca4d088 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs @@ -0,0 +1,72 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{bail, WrapErr}; +use spark_connect::aggregate::GroupType; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result { + let spark_connect::Aggregate { + input, + group_type, + grouping_expressions, + aggregate_expressions, + pivot, + grouping_sets, + } = aggregate; + + let Some(input) = input else { + bail!("input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let group_type = GroupType::try_from(group_type) + .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?; + + assert_groupby(group_type)?; + + if let Some(pivot) = pivot { + bail!("Pivot not yet supported; got {pivot:?}"); + } + + if !grouping_sets.is_empty() { + bail!("Grouping sets not yet supported; got {grouping_sets:?}"); + } + + let grouping_expressions: Vec<_> = grouping_expressions + .iter() + .map(to_daft_expr) + .try_collect()?; + + let aggregate_expressions: Vec<_> = aggregate_expressions + .iter() + .map(to_daft_expr) + .try_collect()?; + + let plan = plan + .aggregate(aggregate_expressions.clone(), grouping_expressions.clone()) + .wrap_err_with(|| format!("Failed to apply aggregate to logical plan aggregate_expressions={aggregate_expressions:?} grouping_expressions={grouping_expressions:?}"))?; + + Ok(plan) +} + +fn assert_groupby(plan: GroupType) -> eyre::Result<()> { + match plan { + GroupType::Unspecified => { + bail!("GroupType must be specified; got Unspecified") + } + GroupType::Groupby => Ok(()), + GroupType::Rollup => { + bail!("Rollup not yet supported") + } + GroupType::Cube => { + bail!("Cube not yet supported") + } + GroupType::Pivot => { + bail!("Pivot not yet supported") + } + GroupType::GroupingSets => { + bail!("GroupingSets not yet supported") + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs new file mode 100644 index 0000000000..3096b7f313 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/project.rs @@ -0,0 +1,26 @@ +//! Project operation for selecting and manipulating columns from a dataset +//! +//! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases, +//! and creates new columns from expressions. 
Example: `df.select(col("id").alias("my_number"))` + +use daft_logical_plan::LogicalPlanBuilder; +use eyre::bail; +use spark_connect::Project; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn project(project: Project) -> eyre::Result { + let Project { input, expressions } = project; + + let Some(input) = input else { + bail!("Project input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?; + + let plan = plan.select(daft_exprs)?; + + Ok(plan) +} diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs new file mode 100644 index 0000000000..e11fef26cb --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/range.rs @@ -0,0 +1,55 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{ensure, Context}; +use spark_connect::Range; + +pub fn range(range: Range) -> eyre::Result { + #[cfg(not(feature = "python"))] + { + use eyre::bail; + bail!("Range operations require Python feature to be enabled"); + } + + #[cfg(feature = "python")] + { + use daft_scan::python::pylib::ScanOperatorHandle; + use pyo3::prelude::*; + let Range { + start, + end, + step, + num_partitions, + } = range; + + let partitions = num_partitions.unwrap_or(1); + + ensure!(partitions > 0, "num_partitions must be greater than 0"); + + let start = start.unwrap_or(0); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let plan = Python::with_gil(|py| { + let range_module = PyModule::import_bound(py, "daft.io._range") + .wrap_err("Failed to import range module")?; + + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; + + let range = range + .call1((start, end, step, partitions)) + .wrap_err("Failed to create range scan operator")? + .to_object(py); + + let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; + + Ok(plan) + } +} diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs index de28a587fc..1b242428d2 100644 --- a/src/daft-connect/src/translation/schema.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -1,54 +1,39 @@ -use eyre::bail; use spark_connect::{ - data_type::{Kind, Long, Struct, StructField}, - relation::RelType, + data_type::{Kind, Struct, StructField}, DataType, Relation, }; use tracing::warn; +use crate::translation::{to_logical_plan, to_spark_datatype}; + #[tracing::instrument(skip_all)] pub fn relation_to_schema(input: Relation) -> eyre::Result { if input.common.is_some() { warn!("We do not currently look at common fields"); } - let result = match input - .rel_type - .ok_or_else(|| tonic::Status::internal("rel_type is None"))? - { - RelType::Range(spark_connect::Range { num_partitions, .. 
}) => { - if num_partitions.is_some() { - warn!("We do not currently support num_partitions"); - } - - let long = Long { - type_variation_reference: 0, - }; - - let id_field = StructField { - name: "id".to_string(), - data_type: Some(DataType { - kind: Some(Kind::Long(long)), - }), - nullable: false, - metadata: None, - }; - - let fields = vec![id_field]; - - let strct = Struct { - fields, - type_variation_reference: 0, - }; - - DataType { - kind: Some(Kind::Struct(strct)), - } - } - other => { - bail!("Unsupported relation type: {other:?}"); - } - }; - - Ok(result) + let plan = to_logical_plan(input)?; + + let result = plan.schema(); + + let fields: eyre::Result> = result + .fields + .iter() + .map(|(name, field)| { + let field_type = to_spark_datatype(&field.dtype); + Ok(StructField { + name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? + data_type: Some(field_type), + nullable: true, // todo(correctness): is this correct? + metadata: None, // todo(completeness): might want to add metadata here + }) + }) + .collect(); + + Ok(DataType { + kind: Some(Kind::Struct(Struct { + fields: fields?, + type_variation_reference: 0, + })), + }) } diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index c9a2b3f2ac..7885c7fd83 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -128,6 +128,11 @@ impl<'a> InferDataType<'a> { Ok((DataType::Boolean, Some(d_type.clone()), d_type)) } + + (DataType::Utf8, DataType::Date) | (DataType::Date, DataType::Utf8) => { + // Date is logical, so we cast to intermediate type (date), then compare on the physical type (i32) + Ok((DataType::Boolean, Some(DataType::Date), DataType::Int32)) + } (s, o) if s.is_physical() && o.is_physical() => { Ok((DataType::Boolean, None, try_physical_supertype(s, o)?)) } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 8f0fba1fec..1d86442aef 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -366,6 +366,12 @@ pub trait Literal: Sized { fn literal_value(self) -> LiteralValue; } +impl Literal for IntervalValue { + fn literal_value(self) -> LiteralValue { + LiteralValue::Interval(self) + } +} + impl Literal for String { fn literal_value(self) -> LiteralValue { LiteralValue::Utf8(self) diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index 7a58e79ade..8adaae0616 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -16,6 +16,10 @@ impl Receiver { pub(crate) fn blocking_recv(&self) -> Option { self.0.recv().ok() } + + pub(crate) fn into_inner(self) -> loole::Receiver { + self.0 + } } pub(crate) fn create_channel(buffer_size: usize) -> (Sender, Receiver) { diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index bda9bfcd09..df22857519 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -18,7 +18,7 @@ use std::{ use common_error::{DaftError, DaftResult}; use common_runtime::RuntimeTask; use lazy_static::lazy_static; -pub use run::{run_local, NativeExecutor}; +pub use run::{run_local, ExecutionEngineResult, NativeExecutor}; use snafu::{futures::TryFutureExt, ResultExt, Snafu}; lazy_static! 
{ @@ -200,6 +200,8 @@ type Result = std::result::Result; #[cfg(feature = "python")] pub fn register_modules(parent: &Bound) -> PyResult<()> { - parent.add_class::()?; + use run::PyNativeExecutor; + + parent.add_class::()?; Ok(()) } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 039df990cd..caf69f156d 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -10,7 +10,10 @@ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_tracing::refresh_chrome_trace; use daft_local_plan::{translate, LocalPhysicalPlan}; +use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::MicroPartition; +use futures::{FutureExt, Stream}; +use loole::RecvFuture; use tokio_util::sync::CancellationToken; #[cfg(feature = "python")] use { @@ -44,32 +47,25 @@ impl LocalPartitionIterator { } } -#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub struct NativeExecutor { - local_physical_plan: Arc, - cancel: CancellationToken, -} - -impl Drop for NativeExecutor { - fn drop(&mut self) { - self.cancel.cancel(); - } +#[cfg_attr( + feature = "python", + pyclass(module = "daft.daft", name = "NativeExecutor") +)] +pub struct PyNativeExecutor { + executor: NativeExecutor, } #[cfg(feature = "python")] #[pymethods] -impl NativeExecutor { +impl PyNativeExecutor { #[staticmethod] pub fn from_logical_plan_builder( logical_plan_builder: &PyLogicalPlanBuilder, py: Python, ) -> PyResult { py.allow_threads(|| { - let logical_plan = logical_plan_builder.builder.build(); - let local_physical_plan = translate(&logical_plan)?; Ok(Self { - local_physical_plan, - cancel: CancellationToken::new(), + executor: NativeExecutor::from_logical_plan_builder(&logical_plan_builder.builder)?, }) }) } @@ -94,13 +90,9 @@ impl NativeExecutor { }) .collect(); let out = py.allow_threads(|| { - run_local( - &self.local_physical_plan, - native_psets, - cfg.config, - results_buffer_size, - self.cancel.clone(), - ) + self.executor + .run(native_psets, cfg.config, results_buffer_size) + .map(|res| res.into_iter()) })?; let iter = Box::new(out.map(|part| { part.map(|p| pyo3::Python::with_gil(|py| PyMicroPartition::from(p).into_py(py))) @@ -110,6 +102,45 @@ impl NativeExecutor { } } +pub struct NativeExecutor { + local_physical_plan: Arc, + cancel: CancellationToken, +} + +impl NativeExecutor { + pub fn from_logical_plan_builder( + logical_plan_builder: &LogicalPlanBuilder, + ) -> DaftResult { + let logical_plan = logical_plan_builder.build(); + let local_physical_plan = translate(&logical_plan)?; + Ok(Self { + local_physical_plan, + cancel: CancellationToken::new(), + }) + } + + pub fn run( + &self, + psets: HashMap>>, + cfg: Arc, + results_buffer_size: Option, + ) -> DaftResult { + run_local( + &self.local_physical_plan, + psets, + cfg, + results_buffer_size, + self.cancel.clone(), + ) + } +} + +impl Drop for NativeExecutor { + fn drop(&mut self) { + self.cancel.cancel(); + } +} + fn should_enable_explain_analyze() -> bool { let explain_var_name = "DAFT_DEV_ENABLE_EXPLAIN_ANALYZE"; if let Ok(val) = std::env::var(explain_var_name) @@ -121,13 +152,105 @@ fn should_enable_explain_analyze() -> bool { } } +pub struct ExecutionEngineReceiverIterator { + receiver: Receiver>, + handle: Option>>, +} + +impl Iterator for ExecutionEngineReceiverIterator { + type Item = DaftResult>; + + fn next(&mut self) -> Option { + match self.receiver.blocking_recv() { + Some(part) => Some(Ok(part)), + None => { + if self.handle.is_some() { + let 
join_result = self + .handle + .take() + .unwrap() + .join() + .expect("Execution engine thread panicked"); + match join_result { + Ok(()) => None, + Err(e) => Some(Err(e)), + } + } else { + None + } + } + } + } +} + +pub struct ExecutionEngineReceiverStream { + receive_fut: RecvFuture>, + handle: Option>>, +} + +impl Stream for ExecutionEngineReceiverStream { + type Item = DaftResult>; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.receive_fut.poll_unpin(cx) { + std::task::Poll::Ready(Ok(part)) => std::task::Poll::Ready(Some(Ok(part))), + std::task::Poll::Ready(Err(_)) => { + if self.handle.is_some() { + let join_result = self + .handle + .take() + .unwrap() + .join() + .expect("Execution engine thread panicked"); + match join_result { + Ok(()) => std::task::Poll::Ready(None), + Err(e) => std::task::Poll::Ready(Some(Err(e))), + } + } else { + std::task::Poll::Ready(None) + } + } + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +pub struct ExecutionEngineResult { + handle: std::thread::JoinHandle>, + receiver: Receiver>, +} + +impl ExecutionEngineResult { + pub fn into_stream(self) -> impl Stream>> { + ExecutionEngineReceiverStream { + receive_fut: self.receiver.into_inner().recv_async(), + handle: Some(self.handle), + } + } +} + +impl IntoIterator for ExecutionEngineResult { + type Item = DaftResult>; + type IntoIter = ExecutionEngineReceiverIterator; + + fn into_iter(self) -> Self::IntoIter { + ExecutionEngineReceiverIterator { + receiver: self.receiver, + handle: Some(self.handle), + } + } +} + pub fn run_local( physical_plan: &LocalPhysicalPlan, psets: HashMap>>, cfg: Arc, results_buffer_size: Option, cancel: CancellationToken, -) -> DaftResult>> + Send>> { +) -> DaftResult { refresh_chrome_trace(); let pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); @@ -188,38 +311,8 @@ pub fn run_local( }) }); - struct ReceiverIterator { - receiver: Receiver>, - handle: Option>>, - } - - impl Iterator for ReceiverIterator { - type Item = DaftResult>; - - fn next(&mut self) -> Option { - match self.receiver.blocking_recv() { - Some(part) => Some(Ok(part)), - None => { - if self.handle.is_some() { - let join_result = self - .handle - .take() - .unwrap() - .join() - .expect("Execution engine thread panicked"); - match join_result { - Ok(()) => None, - Err(e) => Some(Err(e)), - } - } else { - None - } - } - } - } - } - Ok(Box::new(ReceiverIterator { + Ok(ExecutionEngineResult { + handle, receiver: rx, - handle: Some(handle), - })) + }) } diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index c251537ff2..c945c80203 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -589,6 +589,46 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + pub fn optimize(&self) -> DaftResult { + let default_optimizer_config: OptimizerConfig = Default::default(); + let optimizer_config = OptimizerConfig { + enable_actor_pool_projections: self + .config + .as_ref() + .map(|planning_cfg| planning_cfg.enable_actor_pool_projections) + .unwrap_or(default_optimizer_config.enable_actor_pool_projections), + ..default_optimizer_config + }; + let optimizer = Optimizer::new(optimizer_config); + + // Run LogicalPlan optimizations + let unoptimized_plan = self.build(); + let optimized_plan = optimizer.optimize( + unoptimized_plan, + |new_plan, rule_batch, 
pass, transformed, seen| { + if transformed { + log::debug!( + "Rule batch {:?} transformed plan on pass {}, and produced {} plan:\n{}", + rule_batch, + pass, + if seen { "an already seen" } else { "a new" }, + new_plan.repr_ascii(true), + ); + } else { + log::debug!( + "Rule batch {:?} did NOT transform plan on pass {} for plan:\n{}", + rule_batch, + pass, + new_plan.repr_ascii(true), + ); + } + }, + )?; + + let builder = Self::new(optimized_plan, self.config.clone()); + Ok(builder) + } + pub fn build(&self) -> Arc { self.plan.clone() } @@ -918,39 +958,7 @@ impl PyLogicalPlanBuilder { /// Optimize the underlying logical plan, returning a new plan builder containing the optimized plan. pub fn optimize(&self, py: Python) -> PyResult { - py.allow_threads(|| { - // Create optimizer - let default_optimizer_config: OptimizerConfig = Default::default(); - let optimizer_config = OptimizerConfig { enable_actor_pool_projections: self.builder.config.as_ref().map(|planning_cfg| planning_cfg.enable_actor_pool_projections).unwrap_or(default_optimizer_config.enable_actor_pool_projections), ..default_optimizer_config }; - let optimizer = Optimizer::new(optimizer_config); - - // Run LogicalPlan optimizations - let unoptimized_plan = self.builder.build(); - let optimized_plan = optimizer.optimize( - unoptimized_plan, - |new_plan, rule_batch, pass, transformed, seen| { - if transformed { - log::debug!( - "Rule batch {:?} transformed plan on pass {}, and produced {} plan:\n{}", - rule_batch, - pass, - if seen { "an already seen" } else { "a new" }, - new_plan.repr_ascii(true), - ); - } else { - log::debug!( - "Rule batch {:?} did NOT transform plan on pass {} for plan:\n{}", - rule_batch, - pass, - new_plan.repr_ascii(true), - ); - } - }, - )?; - - let builder = LogicalPlanBuilder::new(optimized_plan, self.builder.config.clone()); - Ok(builder.into()) - }) + py.allow_threads(|| Ok(self.builder.optimize()?.into())) } pub fn repr_ascii(&self, simple: bool) -> PyResult { diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 28148d8bbf..758d8c1920 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -72,6 +72,7 @@ mod tests { Schema::new(vec![ Field::new("text", DataType::Utf8), Field::new("id", DataType::Int32), + Field::new("val", DataType::Int32), ]) .unwrap(), ); @@ -138,6 +139,7 @@ mod tests { #[case::slice("select list_utf8[0:2] from tbl1")] #[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")] #[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")] + #[case::join_with_filter("select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0")] #[case::from("select tbl2.text from tbl2")] #[case::using("select tbl2.text from tbl2 join tbl3 using (id)")] #[case( @@ -301,6 +303,34 @@ mod tests { Ok(()) } + #[rstest] + fn test_join_with_filter( + mut planner: SQLPlanner, + tbl_2: LogicalPlanRef, + tbl_3: LogicalPlanRef, + ) -> SQLPlannerResult<()> { + let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0"; + let plan = planner.plan_sql(&sql)?; + + let expected = LogicalPlanBuilder::new(tbl_2, None) + .filter(col("val").gt(lit(0 as i64)))? + .join_with_null_safe_equal( + tbl_3, + vec![col("id")], + vec![col("id")], + Some(vec![false]), + JoinType::Inner, + None, + None, + Some("tbl3."), + true, + )? + .select(vec![col("*")])? 
+ .build(); + assert_eq!(plan, expected); + Ok(()) + } + #[rstest] #[case::abs("select abs(i32) as abs from tbl1")] #[case::ceil("select ceil(i32) as ceil from tbl1")] diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 1eb1169b5a..a8e7a90c7b 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1,5 +1,4 @@ use std::{ - borrow::Cow, cell::{Ref, RefCell, RefMut}, collections::{HashMap, HashSet}, rc::Rc, @@ -11,8 +10,9 @@ use daft_core::prelude::*; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode}, - has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, - OuterReferenceColumn, Subquery, + has_agg, lit, literals_to_series, null_lit, + optimization::conjuct, + AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, @@ -297,9 +297,8 @@ impl<'a> SQLPlanner<'a> { // FROM/JOIN let from = selection.clone().from; - let rel = self.plan_from(&from)?; - let schema = rel.schema(); - self.current_relation = Some(rel); + self.plan_from(&from)?; + let schema = self.relation_opt().unwrap().schema(); // SELECT let mut projections = Vec::with_capacity(selection.projection.len()); @@ -347,6 +346,12 @@ impl<'a> SQLPlanner<'a> { let has_aggs = projections.iter().any(has_agg) || !groupby_exprs.is_empty(); if has_aggs { + let having = selection + .having + .as_ref() + .map(|h| self.plan_expr(h)) + .transpose()?; + self.plan_aggregate_query( &projections, &schema, @@ -354,6 +359,7 @@ impl<'a> SQLPlanner<'a> { groupby_exprs, query, &projection_schema, + having, )?; } else { self.plan_non_agg_query(projections, schema, has_orderby, query, projection_schema)?; @@ -464,6 +470,7 @@ impl<'a> SQLPlanner<'a> { Ok(()) } + #[allow(clippy::too_many_arguments)] fn plan_aggregate_query( &mut self, projections: &Vec>, @@ -472,6 +479,7 @@ impl<'a> SQLPlanner<'a> { groupby_exprs: Vec>, query: &Query, projection_schema: &Schema, + having: Option>, ) -> Result<(), PlannerError> { let mut final_projection = Vec::with_capacity(projections.len()); let mut aggs = Vec::with_capacity(projections.len()); @@ -500,6 +508,15 @@ impl<'a> SQLPlanner<'a> { final_projection.push(p.clone()); } } + + if let Some(having) = &having { + if has_agg(having) { + let having = having.alias(having.semantic_id(schema).id); + + aggs.push(having); + } + } + let groupby_exprs = groupby_exprs .into_iter() .map(|e| { @@ -631,7 +648,7 @@ impl<'a> SQLPlanner<'a> { } let rel = self.relation_mut(); - rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?; + rel.inner = rel.inner.aggregate(aggs.clone(), groupby_exprs)?; let has_orderby_before_projection = !orderbys_before_projection.is_empty(); let has_orderby_after_projection = !orderbys_after_projection.is_empty(); @@ -650,6 +667,16 @@ impl<'a> SQLPlanner<'a> { )?; } + if let Some(having) = having { + // if it's an agg, it's already resolved during .agg, so we just reference the column name + let having = if has_agg(&having) { + col(having.semantic_id(schema).id) + } else { + having + }; + rel.inner = rel.inner.filter(having)?; + } + // apply the final projection rel.inner = rel.inner.select(final_projection)?; @@ -661,6 +688,7 @@ impl<'a> SQLPlanner<'a> { orderbys_after_projection_nulls_first, )?; } + Ok(()) } @@ -727,15 +755,17 @@ impl<'a> SQLPlanner<'a> { Ok((exprs, desc, nulls_first)) } - fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult { + /// Plans the FROM clause of a query and populates 
self.current_relation and self.table_map + /// Should only be called once per query. + fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult<()> { if from.len() > 1 { let mut from_iter = from.iter(); let first = from_iter.next().unwrap(); - let mut rel = self.new_with_context().plan_relation(&first.relation)?; + let mut rel = self.plan_relation(&first.relation)?; self.table_map.insert(rel.get_name(), rel.clone()); for tbl in from_iter { - let right = self.new_with_context().plan_relation(&tbl.relation)?; + let right = self.plan_relation(&tbl.relation)?; self.table_map.insert(right.get_name(), right.clone()); let right_join_prefix = Some(format!("{}.", right.get_name())); @@ -743,129 +773,145 @@ impl<'a> SQLPlanner<'a> { rel.inner .cross_join(right.inner, None, right_join_prefix.as_deref())?; } - return Ok(rel); + self.current_relation = Some(rel); + return Ok(()); } let from = from.iter().next().unwrap(); - fn collect_idents( - left: &[Ident], - right: &[Ident], - left_rel: &Relation, - right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec)> { - let (left, right) = match (left, right) { - // both are fully qualified: `join on a.x = b.y` - ([tbl_a, Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => { - if left_rel.get_name() == tbl_b.value && right_rel.get_name() == tbl_a.value { - (col_b.clone(), col_a.clone()) - } else { - (col_a.clone(), col_b.clone()) - } + macro_rules! return_non_ident_errors { + ($e:expr) => { + if !matches!( + $e, + PlannerError::ColumnNotFound { .. } | PlannerError::TableNotFound { .. } + ) { + return Err($e); } - // only one is fully qualified: `join on x = b.y` - ([Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => { - if tbl_b.value == right_rel.get_name() { - (col_a.clone(), col_b.clone()) - } else if tbl_b.value == left_rel.get_name() { - (col_b.clone(), col_a.clone()) - } else { - unsupported_sql_err!("Could not determine which table the identifiers belong to") - } + }; + } + + #[allow(clippy::too_many_arguments)] + fn process_join_on( + sql_expr: &sqlparser::ast::Expr, + left_planner: &SQLPlanner, + right_planner: &SQLPlanner, + left_on: &mut Vec, + right_on: &mut Vec, + null_eq_nulls: &mut Vec, + left_filters: &mut Vec, + right_filters: &mut Vec, + ) -> SQLPlannerResult<()> { + // check if join expression is actually a filter on one of the tables + match ( + left_planner.plan_expr(sql_expr), + right_planner.plan_expr(sql_expr), + ) { + (Ok(_), Ok(_)) => { + return Err(PlannerError::invalid_operation(format!( + "Ambiguous reference to column name in join: {}", + sql_expr + ))); } - // only one is fully qualified: `join on a.x = y` - ([tbl_a, Ident{value: col_a, ..}], [Ident{value: col_b, ..}]) => { - // find out which one the qualified identifier belongs to - // we assume the other identifier belongs to the other table - if tbl_a.value == left_rel.get_name() { - (col_a.clone(), col_b.clone()) - } else if tbl_a.value == right_rel.get_name() { - (col_b.clone(), col_a.clone()) - } else { - unsupported_sql_err!("Could not determine which table the identifiers belong to") - } + (Ok(expr), _) => { + left_filters.push(expr); + return Ok(()); } - // neither are fully qualified: `join on x = y` - ([left], [right]) => { - let left = ident_to_str(left); - let right = ident_to_str(right); - - // we don't know which table the identifiers belong to, so we need to check both - let left_schema = left_rel.schema(); - let right_schema = right_rel.schema(); - - // if the left side is in the left schema, then we assume the right side is 
in the right schema - if left_schema.get_field(&left).is_ok() { - (left, right) - // if the right side is in the left schema, then we assume the left side is in the right schema - } else if right_schema.get_field(&left).is_ok() { - (right, left) - } else { - unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left); - } + (_, Ok(expr)) => { + right_filters.push(expr); + return Ok(()); + } + (Err(left_err), Err(right_err)) => { + return_non_ident_errors!(left_err); + return_non_ident_errors!(right_err); + } + } + match sql_expr { + // join key + sqlparser::ast::Expr::BinaryOp { + left, + right, + op: op @ BinaryOperator::Eq, } - _ => unsupported_sql_err!( - "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", - left.len(), - right.len() - ), - }; - Ok((vec![col(left)], vec![col(right)])) - } + | sqlparser::ast::Expr::BinaryOp { + left, + right, + op: op @ BinaryOperator::Spaceship, + } => { + let null_equals_null = *op == BinaryOperator::Spaceship; - fn process_join_on( - expression: &sqlparser::ast::Expr, - left_rel: &Relation, - right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec, Vec)> { - if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { - match *op { - BinaryOperator::Eq | BinaryOperator::Spaceship => { - let null_equals_null = *op == BinaryOperator::Spaceship; - - let left = get_idents_vec(left)?; - let right = get_idents_vec(right)?; - - collect_idents(&left, &right, left_rel, right_rel) - .map(|(left, right)| (left, right, vec![null_equals_null])) - } - BinaryOperator::And => { - let (mut left_i, mut right_i, mut null_equals_nulls_i) = - process_join_on(left, left_rel, right_rel)?; - let (mut left_j, mut right_j, mut null_equals_nulls_j) = - process_join_on(right, left_rel, right_rel)?; - left_i.append(&mut left_j); - right_i.append(&mut right_j); - null_equals_nulls_i.append(&mut null_equals_nulls_j); - Ok((left_i, right_i, null_equals_nulls_i)) - } - _ => { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{}'", op); + let mut last_error = None; + + for (left, right) in [(left, right), (right, left)] { + let left_expr = left_planner.plan_expr(left); + let right_expr = right_planner.plan_expr(right); + + if let Ok(left_expr) = &left_expr && let Ok(right_expr) = &right_expr { + left_on.push(left_expr.clone()); + right_on.push(right_expr.clone()); + null_eq_nulls.push(null_equals_null); + + return Ok(()) + } + + for expr_result in [left_expr, right_expr] { + if let Err(e) = expr_result { + return_non_ident_errors!(e); + + last_error = Some(e); + } + } } + + Err(last_error.unwrap()) } - } else if let sqlparser::ast::Expr::Nested(expr) = expression { - process_join_on(expr, left_rel, right_rel) - } else { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); + // multiple expressions + sqlparser::ast::Expr::BinaryOp { + left, + right, + op: BinaryOperator::And, + } => { + process_join_on(left, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?; + process_join_on(right, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?; + + Ok(()) + } + // nested expression + sqlparser::ast::Expr::Nested(expr) => process_join_on( + expr, + left_planner, + right_planner, + left_on, + right_on, + null_eq_nulls, + left_filters, + right_filters, + ), + _ => 
unsupported_sql_err!("JOIN clauses support '=' constraints and filter predicates combined with 'AND'; found expression = {:?}", sql_expr) } } let relation = from.relation.clone(); - let mut left_rel = self.new_with_context().plan_relation(&relation)?; - self.table_map.insert(left_rel.get_name(), left_rel.clone()); + let left_rel = self.plan_relation(&relation)?; + self.current_relation = Some(left_rel.clone()); + self.table_map.insert(left_rel.get_name(), left_rel); for join in &from.joins { use sqlparser::ast::{ JoinConstraint, JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}, }; - let right_rel = self.new_with_context().plan_relation(&join.relation)?; - self.table_map - .insert(right_rel.get_name(), right_rel.clone()); + let right_rel = self.plan_relation(&join.relation)?; let right_rel_name = right_rel.get_name(); let right_join_prefix = Some(format!("{right_rel_name}.")); + // construct a planner with the right table to use for expr planning + let mut right_planner = self.new_with_context(); + right_planner.current_relation = Some(right_rel.clone()); + right_planner + .table_map + .insert(right_rel.get_name(), right_rel.clone()); + let (join_type, constraint) = match &join.join_operator { Inner(constraint) => (JoinType::Inner, constraint), LeftOuter(constraint) => (JoinType::Left, constraint), @@ -877,37 +923,66 @@ impl<'a> SQLPlanner<'a> { _ => unsupported_sql_err!("Unsupported join type: {:?}", join.join_operator), }; - let (left_on, right_on, null_eq_null, keep_join_keys) = match &constraint { + let mut left_on = Vec::new(); + let mut right_on = Vec::new(); + let mut left_filters = Vec::new(); + let mut right_filters = Vec::new(); + + let (keep_join_keys, null_eq_nulls) = match &constraint { JoinConstraint::On(expr) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - (left_on, right_on, Some(null_equals_nulls), true) + let mut null_eq_nulls = Vec::new(); + + process_join_on( + expr, + self, + &right_planner, + &mut left_on, + &mut right_on, + &mut null_eq_nulls, + &mut left_filters, + &mut right_filters, + )?; + + (true, Some(null_eq_nulls)) } JoinConstraint::Using(idents) => { - let on = idents + left_on = idents .iter() .map(|i| col(i.value.clone())) .collect::>(); - (on.clone(), on, None, false) + right_on.clone_from(&left_on); + + (false, None) } JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"), JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"), }; - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, + let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone(); + if let Some(left_predicate) = conjuct(left_filters) { + left_plan = left_plan.filter(left_predicate)?; + } + + let mut right_plan = right_rel.inner.clone(); + if let Some(right_predicate) = conjuct(right_filters) { + right_plan = right_plan.filter(right_predicate)?; + } + + self.relation_mut().inner = left_plan.join_with_null_safe_equal( + right_plan, left_on, right_on, - null_eq_null, + null_eq_nulls, join_type, None, None, right_join_prefix.as_deref(), keep_join_keys, )?; + self.table_map.insert(right_rel_name, right_rel); } - Ok(left_rel) + Ok(()) } fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { @@ -1999,9 +2074,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult if !selection.sort_by.is_empty() { unsupported_sql_err!("SORT BY"); } - if selection.having.is_some() { - 
diff --git a/tests/connect/__init__.py b/tests/connect/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/connect/conftest.py b/tests/connect/conftest.py
new file mode 100644
index 0000000000..60c5ae9986
--- /dev/null
+++ b/tests/connect/conftest.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark_session():
+    """
+    Fixture to create and clean up a Spark session.
+
+    This fixture is available to all test files and creates a single
+    Spark session for the entire test suite run.
+    """
+    from daft.daft import connect_start
+
+    # Start Daft Connect server
+    server = connect_start()
+
+    url = f"sc://localhost:{server.port()}"
+
+    # Initialize Spark Connect session
+    session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate()
+
+    yield session
+
+    # Cleanup
+    server.shutdown()
+    session.stop()
diff --git a/tests/connect/test_alias.py b/tests/connect/test_alias.py
new file mode 100644
index 0000000000..94efb35fc2
--- /dev/null
+++ b/tests/connect/test_alias.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_alias(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Simply rename the 'id' column to 'my_number'
+    df_renamed = df.select(col("id").alias("my_number"))
+
+    # Verify the alias was set correctly
+    assert df_renamed.schema != df.schema, "Schema should be changed after alias"
+
+    # Verify the data is unchanged but column name is different
+    df_rows = df.collect()
+    df_renamed_rows = df_renamed.collect()
+    assert [row.id for row in df_rows] == [
+        row.my_number for row in df_renamed_rows
+    ], "Data should be unchanged after alias"
diff --git a/tests/connect/test_collect.py b/tests/connect/test_collect.py
new file mode 100644
index 0000000000..0a9387dd0b
--- /dev/null
+++ b/tests/connect/test_collect.py
@@ -0,0 +1,14 @@
+from __future__ import annotations
+
+
+def test_range_collect(spark_session):
+    # Create a range using Spark
+    # For example, creating a range from 0 to 9
+    spark_range = spark_session.range(10)  # Creates DataFrame with numbers 0 to 9
+
+    # Collect the data
+    collected_rows = spark_range.collect()
+
+    # Verify the collected data has expected values
+    assert len(collected_rows) == 10, "Should have 10 rows"
+    assert [row["id"] for row in collected_rows] == list(range(10)), "Should contain values 0-9"
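The session-scoped fixture above leans on the new connect API surface: `connect_start()` called without an explicit address binds a free port, and `ConnectionHandle.port()` reports which one. A minimal sketch of the same flow outside pytest; the app name and variable names here are illustrative only.

```python
from daft.daft import connect_start
from pyspark.sql import SparkSession

# Start a Daft Connect server on a free port and read the port back.
server = connect_start()
url = f"sc://localhost:{server.port()}"

# Point a Spark Connect client at it, run a trivial query, then clean up.
spark = SparkSession.builder.appName("daft-connect-sketch").remote(url).getOrCreate()
try:
    print(spark.range(5).collect())
finally:
    spark.stop()
    server.shutdown()
```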
diff --git a/tests/connect/test_config_simple.py b/tests/connect/test_config_simple.py
index de65c7c0f2..9a472f24e2 100644
--- a/tests/connect/test_config_simple.py
+++ b/tests/connect/test_config_simple.py
@@ -1,29 +1,5 @@
 from __future__ import annotations
 
-import time
-
-import pytest
-from pyspark.sql import SparkSession
-
-
-@pytest.fixture
-def spark_session():
-    """Fixture to create and clean up a Spark session."""
-    from daft.daft import connect_start
-
-    # Start Daft Connect server
-    server = connect_start("sc://localhost:50051")
-
-    # Initialize Spark Connect session
-    session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
-
-    yield session
-
-    # Cleanup
-    server.shutdown()
-    session.stop()
-    time.sleep(2)  # Allow time for session cleanup
-
 
 def test_set_operation(spark_session):
     """Test the Set operation with various data types and edge cases."""
diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py
index 86f348470e..b277d38481 100644
--- a/tests/connect/test_range_simple.py
+++ b/tests/connect/test_range_simple.py
@@ -1,26 +1,5 @@
 from __future__ import annotations
 
-import pytest
-from pyspark.sql import SparkSession
-
-
-@pytest.fixture
-def spark_session():
-    """Fixture to create and clean up a Spark session."""
-    from daft.daft import connect_start
-
-    # Start Daft Connect server
-    server = connect_start("sc://localhost:50051")
-
-    # Initialize Spark Connect session
-    session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
-
-    yield session
-
-    # Cleanup
-    server.shutdown()
-    session.stop()
-
 
 def test_range_operation(spark_session):
     # Create a range using Spark
diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py
index 1a26b8cffc..52b70b4a46 100644
--- a/tests/dataframe/test_temporals.py
+++ b/tests/dataframe/test_temporals.py
@@ -2,7 +2,7 @@
 
 import itertools
 import tempfile
-from datetime import datetime, timedelta, timezone
+from datetime import date, datetime, timedelta, timezone
 
 import pyarrow as pa
 import pytest
@@ -465,3 +465,20 @@ def test_intervals(op, expected):
     expected = {"datetimes": expected}
 
     assert actual == expected
+
+
+@pytest.mark.parametrize(
+    "value",
+    [
+        date(2020, 1, 1),  # explicit date
+        "2020-01-01",  # implicit coercion
+    ],
+)
+def test_date_comparison(value):
+    date_df = daft.from_pydict({"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]})
+    date_df = date_df.with_column("date", col("date_str").str.to_date("%Y-%m-%d"))
+    actual = date_df.filter(col("date") == value).select("date").to_pydict()
+
+    expected = {"date": [date(2020, 1, 1)]}
+
+    assert actual == expected
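The new `test_date_comparison` above checks that a Date column can be compared against both a `datetime.date` value and a date-formatted string. A small standalone sketch of the same behaviour, using made-up data:

```python
from datetime import date

import daft
from daft import col

df = daft.from_pydict({"date_str": ["2020-01-01", "2020-01-02"]})
df = df.with_column("date", col("date_str").str.to_date("%Y-%m-%d"))

# Both predicates should select the same single row; the string operand is
# coerced to a date before the comparison.
by_date = df.filter(col("date") == date(2020, 1, 1)).select("date").to_pydict()
by_str = df.filter(col("date") == "2020-01-01").select("date").to_pydict()
assert by_date == by_str == {"date": [date(2020, 1, 1)]}
```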
diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py
index 16a98fadf0..1b9aeccc8e 100644
--- a/tests/io/test_s3_credentials_refresh.py
+++ b/tests/io/test_s3_credentials_refresh.py
@@ -34,9 +34,8 @@ def test_s3_credentials_refresh(aws_log_file: io.IOBase):
     server_url = f"http://{host}:{port}"
 
     bucket_name = "mybucket"
-    file_name = "test.parquet"
-
-    s3_file_path = f"s3://{bucket_name}/{file_name}"
+    input_file_path = f"s3://{bucket_name}/input.parquet"
+    output_file_path = f"s3://{bucket_name}/output.parquet"
 
     old_env = os.environ.copy()
     # Set required AWS environment variables before starting server.
@@ -98,21 +97,28 @@ def get_credentials():
     )
 
     df = daft.from_pydict({"a": [1, 2, 3]})
-    df.write_parquet(s3_file_path, io_config=static_config)
+    df.write_parquet(input_file_path, io_config=static_config)
 
-    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    df = daft.read_parquet(input_file_path, io_config=dynamic_config)
     assert count_get_credentials == 1
 
     df.collect()
     assert count_get_credentials == 1
 
-    df = daft.read_parquet(s3_file_path, io_config=dynamic_config)
+    df = daft.read_parquet(input_file_path, io_config=dynamic_config)
     assert count_get_credentials == 1
 
     time.sleep(1)
 
     df.collect()
     assert count_get_credentials == 2
+    df.write_parquet(output_file_path, io_config=dynamic_config)
+    assert count_get_credentials == 2
+
+    df2 = daft.read_parquet(output_file_path, io_config=static_config)
+
+    assert df.to_arrow() == df2.to_arrow()
+
     # Shutdown moto server.
     stop_process(process)
     # Restore old set of environment variables.
diff --git a/tests/sql/test_aggs.py b/tests/sql/test_aggs.py
index 6d64878070..9e69742b88 100644
--- a/tests/sql/test_aggs.py
+++ b/tests/sql/test_aggs.py
@@ -1,5 +1,8 @@
+import pytest
+
 import daft
 from daft import col
+from daft.sql import SQLCatalog
 
 
 def test_aggs_sql():
@@ -41,3 +44,63 @@ def test_aggs_sql():
     )
 
     assert actual == expected
+
+
+@pytest.mark.parametrize(
+    "agg,cond,expected",
+    [
+        ("sum(values)", "sum(values) > 10", {"values": [20.5, 29.5]}),
+        ("sum(values)", "values > 10", {"values": [20.5, 29.5]}),
+        ("sum(values) as sum_v", "sum(values) > 10", {"sum_v": [20.5, 29.5]}),
+        ("sum(values) as sum_v", "sum_v > 10", {"sum_v": [20.5, 29.5]}),
+        ("count(*) as cnt", "cnt > 2", {"cnt": [3, 5]}),
+        ("count(*) as cnt", "count(*) > 2", {"cnt": [3, 5]}),
+        ("count(*)", "count(*) > 2", {"count": [3, 5]}),
+        ("count(*) as cnt", "sum(values) > 10", {"cnt": [3, 5]}),
+        ("sum(values), count(*)", "id > 1", {"values": [10.0, 29.5], "count": [2, 5]}),
+    ],
+)
+def test_having(agg, cond, expected):
+    df = daft.from_pydict(
+        {
+            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
+            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
+        }
+    )
+    catalog = SQLCatalog({"df": df})
+
+    actual = daft.sql(
+        f"""
+        SELECT
+            {agg},
+        from df
+        group by id
+        having {cond}
+        """,
+        catalog,
+    ).to_pydict()
+
+    assert actual == expected
+
+
+def test_having_non_grouped():
+    df = daft.from_pydict(
+        {
+            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
+            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
+            "floats": [0.01, 0.011, 0.01047, 0.02, 0.019, 0.018, 0.017, 0.016, 0.015, 0.014],
+        }
+    )
+    catalog = SQLCatalog({"df": df})
+
+    actual = daft.sql(
+        """
+        SELECT
+            count(*) ,
+        from df
+        having sum(values) > 40
+        """,
+        catalog,
+    ).to_pydict()
+
+    assert actual == {"count": [10]}
diff --git a/tests/sql/test_temporal_exprs.py b/tests/sql/test_temporal_exprs.py
index d475850839..9067e6b3d1 100644
--- a/tests/sql/test_temporal_exprs.py
+++ b/tests/sql/test_temporal_exprs.py
@@ -88,3 +88,11 @@ def test_extract():
     """).collect()
 
     assert actual.to_pydict() == expected.to_pydict()
+
+
+def test_date_comparison():
+    date_df = daft.from_pydict({"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]})
+    date_df = date_df.with_column("date", daft.col("date_str").str.to_date("%Y-%m-%d"))
+    expected = date_df.filter(daft.col("date") == "2020-01-01").select("date").to_pydict()
+    actual = daft.sql("select date from date_df where date == '2020-01-01'").to_pydict()
+    assert actual == expected
diff --git a/xyz b/xyz
deleted file mode 100644
index e69de29bb2..0000000000
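The `test_having` cases above exercise the newly permitted HAVING clause (the planner no longer rejects it in `check_select_features`). A minimal sketch of the usage pattern with made-up data; mirroring the tests, the HAVING condition may reference either the aggregate expression or its alias.

```python
import daft
from daft.sql import SQLCatalog

df = daft.from_pydict({"id": [1, 1, 2, 2, 2], "values": [1.0, 2.0, 3.0, 4.0, 5.0]})
catalog = SQLCatalog({"df": df})

actual = daft.sql(
    """
    SELECT
        sum(values) as sum_v,
    from df
    group by id
    having sum_v > 3
    """,
    catalog,
).to_pydict()

# Only the id=2 group (sum 12.0) should survive the HAVING condition.
print(actual)  # expected: {"sum_v": [12.0]}
```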