From bfe6560c918565b75378e2a5c77198c46f7b67dd Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 13 Nov 2024 15:03:51 -0800 Subject: [PATCH] [CHORE]: defer Expr subquery error until eval (#3272) Co-authored-by: Kevin Wang --- src/daft-dsl/src/expr/mod.rs | 91 ++++++++++++++++++- src/daft-dsl/src/lib.rs | 2 +- src/daft-dsl/src/optimization.rs | 4 +- src/daft-logical-plan/src/display.rs | 4 +- src/daft-logical-plan/src/logical_plan.rs | 23 ++++- src/daft-logical-plan/src/ops/project.rs | 11 ++- src/daft-logical-plan/src/partitioning.rs | 6 ++ src/daft-sql/src/planner.rs | 24 ++++- src/daft-table/Cargo.toml | 2 +- src/daft-table/src/lib.rs | 6 ++ .../src/schema/io_message/from_message.rs | 12 ++- 11 files changed, 164 insertions(+), 21 deletions(-) diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 8ec9402de2..1dc42837e6 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -36,6 +36,59 @@ use crate::{ optimization::{get_required_columns, requires_computation}, }; +pub trait SubqueryPlan: std::fmt::Debug + std::fmt::Display + Send + Sync { + fn as_any(&self) -> &dyn std::any::Any; + fn name(&self) -> &'static str; + fn schema(&self) -> SchemaRef; +} + +#[derive(Display, Debug, Clone)] +pub struct Subquery { + pub plan: Arc, +} + +impl Subquery { + pub fn new(plan: T) -> Self { + Self { + plan: Arc::new(plan), + } + } + + pub fn schema(&self) -> SchemaRef { + self.plan.schema() + } + pub fn name(&self) -> &'static str { + self.plan.name() + } +} + +impl Serialize for Subquery { + fn serialize(&self, _: S) -> Result { + Err(serde::ser::Error::custom("Subquery cannot be serialized")) + } +} + +impl<'de> Deserialize<'de> for Subquery { + fn deserialize>(_: D) -> Result { + Err(serde::de::Error::custom("Subquery cannot be deserialized")) + } +} + +impl PartialEq for Subquery { + fn eq(&self, other: &Self) -> bool { + self.plan.name() == other.plan.name() && self.plan.schema() == other.plan.schema() + } +} + +impl Eq for Subquery {} + +impl std::hash::Hash for Subquery { + fn hash(&self, state: &mut H) { + self.plan.name().hash(state); + self.plan.schema().hash(state); + } +} + pub type ExprRef = Arc; #[derive(Display, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -95,6 +148,11 @@ pub enum Expr { #[display("{_0}")] ScalarFunction(ScalarFunction), + + #[display("{_0}")] + Subquery(Subquery), + #[display("{_0}, {_1}")] + InSubquery(ExprRef, Subquery), } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)] @@ -582,6 +640,9 @@ impl Expr { pub fn gt_eq(self: ExprRef, other: ExprRef) -> ExprRef { binary_op(Operator::GtEq, self, other) } + pub fn in_subquery(self: ExprRef, subquery: Subquery) -> ExprRef { + Self::InSubquery(self, subquery).into() + } pub fn semantic_id(&self, schema: &Schema) -> FieldID { match self { @@ -647,6 +708,8 @@ impl Expr { // Agg: Separate path. Self::Agg(agg_expr) => agg_expr.semantic_id(schema), Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), + + Self::Subquery(..) | Self::InSubquery(..) => todo!("semantic_id for subquery"), } } @@ -655,13 +718,15 @@ impl Expr { // No children. Self::Column(..) => vec![], Self::Literal(..) => vec![], + Self::Subquery(..) => vec![], // One child. Self::Not(expr) | Self::IsNull(expr) | Self::NotNull(expr) | Self::Cast(expr, ..) - | Self::Alias(expr, ..) => { + | Self::Alias(expr, ..) + | Self::InSubquery(expr, _) => { vec![expr.clone()] } Self::Agg(agg_expr) => agg_expr.children(), @@ -688,7 +753,7 @@ impl Expr { pub fn with_new_children(&self, children: Vec) -> Self { match self { // no children - Self::Column(..) | Self::Literal(..) => { + Self::Column(..) | Self::Literal(..) | Self::Subquery(..) => { assert!(children.is_empty(), "Should have no children"); self.clone() } @@ -708,6 +773,10 @@ impl Expr { children.first().expect("Should have 1 child").clone(), dtype.clone(), ), + Self::InSubquery(_, subquery) => Self::InSubquery( + children.first().expect("Should have 1 child").clone(), + subquery.clone(), + ), // 2 children Self::BinaryOp { op, .. } => Self::BinaryOp { op: *op, @@ -909,6 +978,18 @@ impl Expr { } } } + Self::Subquery(subquery) => { + let subquery_schema = subquery.schema(); + if subquery_schema.len() != 1 { + return Err(DaftError::TypeError(format!( + "Expected subquery to return a single column but received {subquery_schema}", + ))); + } + let (_, first_field) = subquery_schema.fields.first().unwrap(); + + Ok(first_field.clone()) + } + Self::InSubquery(expr, _) => Ok(Field::new(expr.name(), DataType::Boolean)), } } @@ -939,6 +1020,8 @@ impl Expr { right: _, } => left.name(), Self::IfElse { if_true, .. } => if_true.name(), + Self::Subquery(subquery) => subquery.name(), + Self::InSubquery(expr, _) => expr.name(), } } @@ -1011,7 +1094,9 @@ impl Expr { | Expr::Between(..) | Expr::Function { .. } | Expr::FillNull(..) - | Expr::ScalarFunction { .. } => Err(io::Error::new( + | Expr::ScalarFunction { .. } + | Expr::Subquery(..) + | Expr::InSubquery(..) => Err(io::Error::new( io::ErrorKind::Other, "Unsupported expression for SQL translation", )), diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 65af123fed..2de36436b1 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -16,7 +16,7 @@ mod treenode; pub use common_treenode; pub use expr::{ binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr, - ApproxPercentileParams, Expr, ExprRef, Operator, SketchType, + ApproxPercentileParams, Expr, ExprRef, Operator, SketchType, Subquery, SubqueryPlan, }; pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index adafd94b78..68d360cb00 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -33,7 +33,9 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::FillNull(..) | Expr::IsIn { .. } | Expr::Between { .. } - | Expr::IfElse { .. } => true, + | Expr::IfElse { .. } + | Expr::Subquery { .. } + | Expr::InSubquery { .. } => true, } } diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 88ba77787a..26c7470fa7 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -5,13 +5,13 @@ use common_display::{tree::TreeDisplay, DisplayLevel}; impl TreeDisplay for crate::LogicalPlan { fn display_as(&self, level: DisplayLevel) -> String { match level { - DisplayLevel::Compact => self.name(), + DisplayLevel::Compact => self.name().to_string(), DisplayLevel::Default | DisplayLevel::Verbose => self.multiline_display().join("\n"), } } fn get_name(&self) -> String { - self.name() + self.name().to_string() } fn get_children(&self) -> Vec<&dyn TreeDisplay> { diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 71f6d5bc96..594c07e1a9 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -2,7 +2,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use common_display::ascii::AsciiTreeDisplay; use common_error::DaftError; -use daft_dsl::optimization::get_required_columns; +use daft_dsl::{optimization::get_required_columns, SubqueryPlan}; use daft_schema::schema::SchemaRef; use indexmap::IndexSet; use snafu::Snafu; @@ -173,8 +173,8 @@ impl LogicalPlan { } } - pub fn name(&self) -> String { - let name = match self { + pub fn name(&self) -> &'static str { + match self { Self::Source(..) => "Source", Self::Project(..) => "Project", Self::ActorPoolProject(..) => "ActorPoolProject", @@ -194,8 +194,7 @@ impl LogicalPlan { Self::Sink(..) => "Sink", Self::Sample(..) => "Sample", Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId", - }; - name.to_string() + } } pub fn multiline_display(&self) -> Vec { @@ -327,6 +326,20 @@ impl LogicalPlan { } } +impl SubqueryPlan for LogicalPlan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + Self::name(self) + } + + fn schema(&self) -> SchemaRef { + Self::schema(self) + } +} + #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] pub(crate) enum Error { diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index a770838b54..163acba2b9 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -199,7 +199,7 @@ fn replace_column_with_semantic_id( Transformed::yes(new_expr.into()) } else { match e.as_ref() { - Expr::Column(_) | Expr::Literal(_) => Transformed::no(e), + Expr::Column(_) | Expr::Literal(_) | Expr::Subquery(_) => Transformed::no(e), Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr( agg_expr.clone(), subexprs_to_replace, @@ -359,6 +359,15 @@ fn replace_column_with_semantic_id( Transformed::yes(Expr::ScalarFunction(func).into()) } } + Expr::InSubquery(expr, subquery) => { + let expr = + replace_column_with_semantic_id(expr.clone(), subexprs_to_replace, schema); + if !expr.transformed { + Transformed::no(e) + } else { + Transformed::yes(Expr::InSubquery(expr.data, subquery.clone()).into()) + } + } } } } diff --git a/src/daft-logical-plan/src/partitioning.rs b/src/daft-logical-plan/src/partitioning.rs index 32ff7f4801..c33a90823d 100644 --- a/src/daft-logical-plan/src/partitioning.rs +++ b/src/daft-logical-plan/src/partitioning.rs @@ -236,6 +236,7 @@ fn translate_clustering_spec_expr( None => Err(()), }, Expr::Literal(_) => Ok(clustering_spec_expr.clone()), + Expr::Subquery(_) => Ok(clustering_spec_expr.clone()), Expr::Alias(child, name) => { let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?; Ok(newchild.alias(name.clone())) @@ -309,6 +310,11 @@ fn translate_clustering_spec_expr( Ok(newpred.if_else(newtrue, newfalse)) } + Expr::InSubquery(expr, subquery) => { + let expr = translate_clustering_spec_expr(expr, old_colname_to_new_colname)?; + + Ok(expr.in_subquery(subquery.clone())) + } // Cannot have agg exprs in clustering specs. Expr::Agg(_) => Err(()), } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 1dbf6512a2..805f11ad47 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -9,6 +9,7 @@ use daft_dsl::{ col, functions::utf8::{ilike, like, to_date, to_datetime}, has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, + Subquery, }; use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; @@ -1128,8 +1129,21 @@ impl SQLPlanner { Ok(expr) } } - SQLExpr::InSubquery { .. } => { - unsupported_sql_err!("IN subquery") + SQLExpr::InSubquery { + expr, + subquery, + negated, + } => { + let expr = self.plan_expr(expr)?; + let mut this = Self::new(self.catalog.clone()); + let subquery = this.plan_query(subquery)?.build(); + let subquery = Subquery { plan: subquery }; + + if *negated { + Ok(expr.in_subquery(subquery).not()) + } else { + Ok(expr.in_subquery(subquery)) + } } SQLExpr::InUnnest { .. } => unsupported_sql_err!("IN UNNEST"), SQLExpr::Between { @@ -1282,7 +1296,11 @@ impl SQLPlanner { ) } SQLExpr::Exists { .. } => unsupported_sql_err!("EXISTS"), - SQLExpr::Subquery(_) => unsupported_sql_err!("SUBQUERY"), + SQLExpr::Subquery(subquery) => { + let mut this = Self::new(self.catalog.clone()); + let subquery = this.plan_query(subquery)?.build(); + Ok(Expr::Subquery(Subquery { plan: subquery }).arced()) + } SQLExpr::GroupingSets(_) => unsupported_sql_err!("GROUPING SETS"), SQLExpr::Cube(_) => unsupported_sql_err!("CUBE"), SQLExpr::Rollup(_) => unsupported_sql_err!("ROLLUP"), diff --git a/src/daft-table/Cargo.toml b/src/daft-table/Cargo.toml index 1c9d25b986..2a68b4124a 100644 --- a/src/daft-table/Cargo.toml +++ b/src/daft-table/Cargo.toml @@ -16,7 +16,7 @@ rand = {workspace = true} serde = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "common-arrow-ffi/python", "common-display/python", "daft-image/python"] +python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "common-arrow-ffi/python", "common-display/python", "daft-image/python", "daft-logical-plan/python"] [lints] workspace = true diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index ede9e7e830..adf490e9a5 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -590,6 +590,12 @@ impl Table { Ok(if_true_series.if_else(&if_false_series, &predicate_series)?) } }, + Subquery(_subquery) => Err(DaftError::ComputeError( + "Subquery should be optimized away before evaluation. This indicates a bug in the query optimizer.".to_string(), + )), + InSubquery(_expr, _subquery) => Err(DaftError::ComputeError( + "IN should be optimized away before evaluation. This indicates a bug in the query optimizer.".to_string(), + )), }?; if expected_field.name != series.field().name { diff --git a/src/parquet2/src/schema/io_message/from_message.rs b/src/parquet2/src/schema/io_message/from_message.rs index 7de7da4edc..3aaedba2cc 100644 --- a/src/parquet2/src/schema/io_message/from_message.rs +++ b/src/parquet2/src/schema/io_message/from_message.rs @@ -45,10 +45,14 @@ use parquet_format_safe::Type; use types::PrimitiveLogicalType; -use super::super::types::{ParquetType, TimeUnit}; -use super::super::*; -use crate::error::{Error, Result}; -use crate::schema::types::{GroupConvertedType, PrimitiveConvertedType}; +use super::super::{ + types::{ParquetType, TimeUnit}, + *, +}; +use crate::{ + error::{Error, Result}, + schema::types::{GroupConvertedType, PrimitiveConvertedType}, +}; fn is_logical_type(s: &str) -> bool { matches!(