Skip to content

Commit

Permalink
[CHORE]: defer Expr subquery error until eval (#3272)
Browse files Browse the repository at this point in the history
Co-authored-by: Kevin Wang <[email protected]>
  • Loading branch information
universalmind303 and kevinzwang authored Nov 13, 2024
1 parent a1991e5 commit bfe6560
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 21 deletions.
91 changes: 88 additions & 3 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,59 @@ use crate::{
optimization::{get_required_columns, requires_computation},
};

/// Abstraction over a query plan that can be embedded in an expression as a subquery.
///
/// Implementors must also be `Debug`, `Display`, `Send` and `Sync`; [`Subquery`] stores
/// them behind an `Arc<dyn SubqueryPlan>`.
pub trait SubqueryPlan: std::fmt::Debug + std::fmt::Display + Send + Sync {
/// Returns this plan as `&dyn Any` (presumably so callers can downcast to the concrete plan type).
fn as_any(&self) -> &dyn std::any::Any;
/// Returns a static name for this plan.
fn name(&self) -> &'static str;
/// Returns the schema of this plan.
fn schema(&self) -> SchemaRef;
}

/// Expression-level handle to a subquery, wrapping a shared [`SubqueryPlan`] trait object.
#[derive(Display, Debug, Clone)]
pub struct Subquery {
/// The wrapped plan, type-erased and reference-counted (cloning only bumps the `Arc` count).
pub plan: Arc<dyn SubqueryPlan>,
}

impl Subquery {
    /// Wraps a concrete plan in a `Subquery`, erasing its type behind `Arc<dyn SubqueryPlan>`.
    pub fn new<T: SubqueryPlan + 'static>(plan: T) -> Self {
        let plan: Arc<dyn SubqueryPlan> = Arc::new(plan);
        Self { plan }
    }

    /// Schema reported by the wrapped plan.
    pub fn schema(&self) -> SchemaRef {
        let Self { plan } = self;
        plan.schema()
    }

    /// Static name reported by the wrapped plan.
    pub fn name(&self) -> &'static str {
        let Self { plan } = self;
        plan.name()
    }
}

impl Serialize for Subquery {
fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
Err(serde::ser::Error::custom("Subquery cannot be serialized"))
}
}

impl<'de> Deserialize<'de> for Subquery {
fn deserialize<D: serde::Deserializer<'de>>(_: D) -> Result<Self, D::Error> {
Err(serde::de::Error::custom("Subquery cannot be deserialized"))
}
}

impl PartialEq for Subquery {
    /// Two subqueries compare equal when their plans report the same name and the same schema.
    ///
    /// NOTE(review): the plan contents are not compared, so distinct plans that happen to
    /// share a name and schema are treated as equal — confirm this is intended.
    fn eq(&self, other: &Self) -> bool {
        // Compare the names first; schemas are only fetched when the names match.
        if self.plan.name() != other.plan.name() {
            return false;
        }
        self.plan.schema() == other.plan.schema()
    }
}

// Marker impl: declares the manual `PartialEq` (plan name + schema comparison) to be a full equivalence relation.
impl Eq for Subquery {}

impl std::hash::Hash for Subquery {
    /// Feeds the plan's name and then its schema into `state` — the same two
    /// properties that the `PartialEq` impl compares.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let plan = &self.plan;
        std::hash::Hash::hash(plan.name(), state);
        std::hash::Hash::hash(&plan.schema(), state);
    }
}

/// Shared, reference-counted handle to an [`Expr`].
pub type ExprRef = Arc<Expr>;

#[derive(Display, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -95,6 +148,11 @@ pub enum Expr {

#[display("{_0}")]
ScalarFunction(ScalarFunction),

#[display("{_0}")]
Subquery(Subquery),
#[display("{_0}, {_1}")]
InSubquery(ExprRef, Subquery),
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
Expand Down Expand Up @@ -582,6 +640,9 @@ impl Expr {
pub fn gt_eq(self: ExprRef, other: ExprRef) -> ExprRef {
binary_op(Operator::GtEq, self, other)
}
pub fn in_subquery(self: ExprRef, subquery: Subquery) -> ExprRef {
Self::InSubquery(self, subquery).into()
}

pub fn semantic_id(&self, schema: &Schema) -> FieldID {
match self {
Expand Down Expand Up @@ -647,6 +708,8 @@ impl Expr {
// Agg: Separate path.
Self::Agg(agg_expr) => agg_expr.semantic_id(schema),
Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema),

Self::Subquery(..) | Self::InSubquery(..) => todo!("semantic_id for subquery"),
}
}

Expand All @@ -655,13 +718,15 @@ impl Expr {
// No children.
Self::Column(..) => vec![],
Self::Literal(..) => vec![],
Self::Subquery(..) => vec![],

// One child.
Self::Not(expr)
| Self::IsNull(expr)
| Self::NotNull(expr)
| Self::Cast(expr, ..)
| Self::Alias(expr, ..) => {
| Self::Alias(expr, ..)
| Self::InSubquery(expr, _) => {
vec![expr.clone()]
}
Self::Agg(agg_expr) => agg_expr.children(),
Expand All @@ -688,7 +753,7 @@ impl Expr {
pub fn with_new_children(&self, children: Vec<ExprRef>) -> Self {
match self {
// no children
Self::Column(..) | Self::Literal(..) => {
Self::Column(..) | Self::Literal(..) | Self::Subquery(..) => {
assert!(children.is_empty(), "Should have no children");
self.clone()
}
Expand All @@ -708,6 +773,10 @@ impl Expr {
children.first().expect("Should have 1 child").clone(),
dtype.clone(),
),
Self::InSubquery(_, subquery) => Self::InSubquery(
children.first().expect("Should have 1 child").clone(),
subquery.clone(),
),
// 2 children
Self::BinaryOp { op, .. } => Self::BinaryOp {
op: *op,
Expand Down Expand Up @@ -909,6 +978,18 @@ impl Expr {
}
}
}
Self::Subquery(subquery) => {
let subquery_schema = subquery.schema();
if subquery_schema.len() != 1 {
return Err(DaftError::TypeError(format!(
"Expected subquery to return a single column but received {subquery_schema}",
)));
}
let (_, first_field) = subquery_schema.fields.first().unwrap();

Ok(first_field.clone())
}
Self::InSubquery(expr, _) => Ok(Field::new(expr.name(), DataType::Boolean)),
}
}

Expand Down Expand Up @@ -939,6 +1020,8 @@ impl Expr {
right: _,
} => left.name(),
Self::IfElse { if_true, .. } => if_true.name(),
Self::Subquery(subquery) => subquery.name(),
Self::InSubquery(expr, _) => expr.name(),
}
}

Expand Down Expand Up @@ -1011,7 +1094,9 @@ impl Expr {
| Expr::Between(..)
| Expr::Function { .. }
| Expr::FillNull(..)
| Expr::ScalarFunction { .. } => Err(io::Error::new(
| Expr::ScalarFunction { .. }
| Expr::Subquery(..)
| Expr::InSubquery(..) => Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported expression for SQL translation",
)),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod treenode;
pub use common_treenode;
pub use expr::{
binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType, Subquery, SubqueryPlan,
};
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
Expand Down
4 changes: 3 additions & 1 deletion src/daft-dsl/src/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ pub fn requires_computation(e: &Expr) -> bool {
| Expr::FillNull(..)
| Expr::IsIn { .. }
| Expr::Between { .. }
| Expr::IfElse { .. } => true,
| Expr::IfElse { .. }
| Expr::Subquery { .. }
| Expr::InSubquery { .. } => true,
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/daft-logical-plan/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use common_display::{tree::TreeDisplay, DisplayLevel};
impl TreeDisplay for crate::LogicalPlan {
fn display_as(&self, level: DisplayLevel) -> String {
match level {
DisplayLevel::Compact => self.name(),
DisplayLevel::Compact => self.name().to_string(),
DisplayLevel::Default | DisplayLevel::Verbose => self.multiline_display().join("\n"),
}
}

fn get_name(&self) -> String {
self.name()
self.name().to_string()
}

fn get_children(&self) -> Vec<&dyn TreeDisplay> {
Expand Down
23 changes: 18 additions & 5 deletions src/daft-logical-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{num::NonZeroUsize, sync::Arc};

use common_display::ascii::AsciiTreeDisplay;
use common_error::DaftError;
use daft_dsl::optimization::get_required_columns;
use daft_dsl::{optimization::get_required_columns, SubqueryPlan};
use daft_schema::schema::SchemaRef;
use indexmap::IndexSet;
use snafu::Snafu;
Expand Down Expand Up @@ -173,8 +173,8 @@ impl LogicalPlan {
}
}

pub fn name(&self) -> String {
let name = match self {
pub fn name(&self) -> &'static str {
match self {
Self::Source(..) => "Source",
Self::Project(..) => "Project",
Self::ActorPoolProject(..) => "ActorPoolProject",
Expand All @@ -194,8 +194,7 @@ impl LogicalPlan {
Self::Sink(..) => "Sink",
Self::Sample(..) => "Sample",
Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId",
};
name.to_string()
}
}

pub fn multiline_display(&self) -> Vec<String> {
Expand Down Expand Up @@ -327,6 +326,20 @@ impl LogicalPlan {
}
}

// Lets a `LogicalPlan` be used wherever a `SubqueryPlan` trait object is expected.
impl SubqueryPlan for LogicalPlan {
// Returns the plan itself as `&dyn Any`.
fn as_any(&self) -> &dyn std::any::Any {
self
}

// Delegates to `LogicalPlan::name`; in path form, the inherent method takes
// precedence over this trait method, so this does not recurse.
fn name(&self) -> &'static str {
Self::name(self)
}

// Delegates to `LogicalPlan::schema` (presumably the inherent method, as with `name` — confirm).
fn schema(&self) -> SchemaRef {
Self::schema(self)
}
}

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub(crate) enum Error {
Expand Down
11 changes: 10 additions & 1 deletion src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ fn replace_column_with_semantic_id(
Transformed::yes(new_expr.into())
} else {
match e.as_ref() {
Expr::Column(_) | Expr::Literal(_) => Transformed::no(e),
Expr::Column(_) | Expr::Literal(_) | Expr::Subquery(_) => Transformed::no(e),
Expr::Agg(agg_expr) => replace_column_with_semantic_id_aggexpr(
agg_expr.clone(),
subexprs_to_replace,
Expand Down Expand Up @@ -359,6 +359,15 @@ fn replace_column_with_semantic_id(
Transformed::yes(Expr::ScalarFunction(func).into())
}
}
Expr::InSubquery(expr, subquery) => {
let expr =
replace_column_with_semantic_id(expr.clone(), subexprs_to_replace, schema);
if !expr.transformed {
Transformed::no(e)
} else {
Transformed::yes(Expr::InSubquery(expr.data, subquery.clone()).into())
}
}
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/daft-logical-plan/src/partitioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ fn translate_clustering_spec_expr(
None => Err(()),
},
Expr::Literal(_) => Ok(clustering_spec_expr.clone()),
Expr::Subquery(_) => Ok(clustering_spec_expr.clone()),
Expr::Alias(child, name) => {
let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?;
Ok(newchild.alias(name.clone()))
Expand Down Expand Up @@ -309,6 +310,11 @@ fn translate_clustering_spec_expr(

Ok(newpred.if_else(newtrue, newfalse))
}
Expr::InSubquery(expr, subquery) => {
let expr = translate_clustering_spec_expr(expr, old_colname_to_new_colname)?;

Ok(expr.in_subquery(subquery.clone()))
}
// Cannot have agg exprs in clustering specs.
Expr::Agg(_) => Err(()),
}
Expand Down
24 changes: 21 additions & 3 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use daft_dsl::{
col,
functions::utf8::{ilike, like, to_date, to_datetime},
has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator,
Subquery,
};
use daft_functions::numeric::{ceil::ceil, floor::floor};
use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
Expand Down Expand Up @@ -1128,8 +1129,21 @@ impl SQLPlanner {
Ok(expr)
}
}
SQLExpr::InSubquery { .. } => {
unsupported_sql_err!("IN subquery")
SQLExpr::InSubquery {
expr,
subquery,
negated,
} => {
let expr = self.plan_expr(expr)?;
let mut this = Self::new(self.catalog.clone());
let subquery = this.plan_query(subquery)?.build();
let subquery = Subquery { plan: subquery };

if *negated {
Ok(expr.in_subquery(subquery).not())
} else {
Ok(expr.in_subquery(subquery))
}
}
SQLExpr::InUnnest { .. } => unsupported_sql_err!("IN UNNEST"),
SQLExpr::Between {
Expand Down Expand Up @@ -1282,7 +1296,11 @@ impl SQLPlanner {
)
}
SQLExpr::Exists { .. } => unsupported_sql_err!("EXISTS"),
SQLExpr::Subquery(_) => unsupported_sql_err!("SUBQUERY"),
SQLExpr::Subquery(subquery) => {
let mut this = Self::new(self.catalog.clone());
let subquery = this.plan_query(subquery)?.build();
Ok(Expr::Subquery(Subquery { plan: subquery }).arced())
}
SQLExpr::GroupingSets(_) => unsupported_sql_err!("GROUPING SETS"),
SQLExpr::Cube(_) => unsupported_sql_err!("CUBE"),
SQLExpr::Rollup(_) => unsupported_sql_err!("ROLLUP"),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rand = {workspace = true}
serde = {workspace = true}

[features]
python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "common-arrow-ffi/python", "common-display/python", "daft-image/python"]
python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "common-arrow-ffi/python", "common-display/python", "daft-image/python", "daft-logical-plan/python"]

[lints]
workspace = true
Expand Down
6 changes: 6 additions & 0 deletions src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ impl Table {
Ok(if_true_series.if_else(&if_false_series, &predicate_series)?)
}
},
Subquery(_subquery) => Err(DaftError::ComputeError(
"Subquery should be optimized away before evaluation. This indicates a bug in the query optimizer.".to_string(),
)),
InSubquery(_expr, _subquery) => Err(DaftError::ComputeError(
"IN <SUBQUERY> should be optimized away before evaluation. This indicates a bug in the query optimizer.".to_string(),
)),
}?;

if expected_field.name != series.field().name {
Expand Down
12 changes: 8 additions & 4 deletions src/parquet2/src/schema/io_message/from_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@
use parquet_format_safe::Type;
use types::PrimitiveLogicalType;

use super::super::types::{ParquetType, TimeUnit};
use super::super::*;
use crate::error::{Error, Result};
use crate::schema::types::{GroupConvertedType, PrimitiveConvertedType};
use super::super::{
types::{ParquetType, TimeUnit},
*,
};
use crate::{
error::{Error, Result},
schema::types::{GroupConvertedType, PrimitiveConvertedType},
};

fn is_logical_type(s: &str) -> bool {
matches!(
Expand Down

0 comments on commit bfe6560

Please sign in to comment.