Skip to content

Commit

Permalink
fix: preserve expression names when replacing placeholders (apache#12126
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jonahgao authored and berkaysynnada committed Aug 26, 2024
1 parent 83d3c5a commit b0292a2
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 67 deletions.
30 changes: 29 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ mod tests {
use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};

use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
Expand Down Expand Up @@ -3699,4 +3699,32 @@ mod tests {
assert!(result.is_err());
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/12065
#[tokio::test]
async fn filtered_aggr_with_param_values() -> Result<()> {
let cfg = SessionConfig::new()
.set("datafusion.sql_parser.dialect", "PostgreSQL".into());
let ctx = SessionContext::new_with_config(cfg);
register_aggregate_csv(&ctx, "table1").await?;

let df = ctx
.sql("select count (c2) filter (where c3 > $1) from table1")
.await?
.with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)]));

let df_results = df?.collect().await?;
assert_batches_eq!(
&[
"+------------------------------------------------+",
"| count(table1.c2) FILTER (WHERE table1.c3 > $1) |",
"+------------------------------------------------+",
"| 54 |",
"+------------------------------------------------+",
],
&df_results
);

Ok(())
}
}
51 changes: 51 additions & 0 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,57 @@ where
expr.alias_if_changed(original_name)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
use_alias: bool,
}

/// If the name of an expression is remembered, it will be preserved when
/// rewriting the expression
pub struct SavedName(Option<String>);

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
}
}

/// Create a new NamePreserver for rewriting the `expr`s in `Projection`
///
/// This will use aliases
pub fn new_for_projection() -> Self {
Self { use_alias: true }
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
} else {
None
};

Ok(SavedName(original_name))
}
}

impl SavedName {
/// Ensures the name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let Self(original_name) = self;
match original_name {
Some(name) => expr.alias_if_changed(name),
None => Ok(expr),
}
}
}

#[cfg(test)]
mod test {
use std::ops::Add;
Expand Down
23 changes: 14 additions & 9 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use super::dml::CopyTo;
use super::DdlStatement;
use crate::builder::{change_redundant_column, unnest_with_options};
use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction};
use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols};
use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver};
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
Expand Down Expand Up @@ -1339,15 +1339,20 @@ impl LogicalPlan {
) -> Result<LogicalPlan> {
self.transform_up_with_subqueries(|plan| {
let schema = Arc::clone(plan.schema());
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
e.infer_placeholder_types(&schema)?.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
let value = param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::Literal(value)))
} else {
Ok(Transformed::no(e))
}
})
let original_name = name_preserver.save(&e)?;
let transformed_expr =
e.infer_placeholder_types(&schema)?.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
let value = param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::Literal(value)))
} else {
Ok(Transformed::no(e))
}
})?;
// Preserve name to avoid breaking column references to this expression
transformed_expr.map_data(|expr| original_name.restore(expr))
})
})
.map(|res| res.data)
Expand Down
55 changes: 4 additions & 51 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator};

use log::{debug, trace};

/// Re-export of `NamesPreserver` for backwards compatibility,
/// as it was initially placed here and then moved elsewhere.
pub use datafusion_expr::expr_rewriter::NamePreserver;

/// Convenience rule for writing optimizers: recursively invoke
/// optimize on plan's children and then return a node of the same
/// type. Useful for optimizer rules which want to leave the type
Expand Down Expand Up @@ -294,54 +298,3 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
expr_utils::merge_schema(inputs)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
use_alias: bool,
}

/// If the name of an expression is remembered, it will be preserved when
/// rewriting the expression
pub struct SavedName(Option<String>);

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
}
}

/// Create a new NamePreserver for rewriting the `expr`s in `Projection`
///
/// This will use aliases
pub fn new_for_projection() -> Self {
Self { use_alias: true }
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
} else {
None
};

Ok(SavedName(original_name))
}
}

impl SavedName {
/// Ensures the name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let Self(original_name) = self;
match original_name {
Some(name) => expr.alias_if_changed(name),
None => Ok(expr),
}
}
}
15 changes: 9 additions & 6 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3813,7 +3813,7 @@ fn test_prepare_statement_to_plan_params_as_constants() {
///////////////////
// replace params with values
let param_values = vec![ScalarValue::Int32(Some(10))];
let expected_plan = "Projection: Int32(10)\n EmptyRelation";
let expected_plan = "Projection: Int32(10) AS $1\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);

Expand All @@ -3829,7 +3829,8 @@ fn test_prepare_statement_to_plan_params_as_constants() {
///////////////////
// replace params with values
let param_values = vec![ScalarValue::Int32(Some(10))];
let expected_plan = "Projection: Int64(1) + Int32(10)\n EmptyRelation";
let expected_plan =
"Projection: Int64(1) + Int32(10) AS Int64(1) + $1\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);

Expand All @@ -3848,7 +3849,9 @@ fn test_prepare_statement_to_plan_params_as_constants() {
ScalarValue::Int32(Some(10)),
ScalarValue::Float64(Some(10.0)),
];
let expected_plan = "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation";
let expected_plan =
"Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2\
\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}
Expand Down Expand Up @@ -4063,7 +4066,7 @@ fn test_prepare_statement_insert_infer() {
\n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \
CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \
CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\
\n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))";
\n Values: (UInt32(1) AS $1, Utf8(\"Alan\") AS $2, Utf8(\"Turing\") AS $3)";
let plan = plan.replace_params_with_values(&param_values).unwrap();

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
Expand Down Expand Up @@ -4144,7 +4147,7 @@ fn test_prepare_statement_to_plan_multi_params() {
ScalarValue::from("xyz"),
];
let expected_plan =
"Projection: person.id, person.age, Utf8(\"xyz\")\
"Projection: person.id, person.age, Utf8(\"xyz\") AS $6\
\n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\
\n TableScan: person";

Expand Down Expand Up @@ -4213,7 +4216,7 @@ fn test_prepare_statement_to_plan_value_list() {
let expected_plan = "Projection: *\
\n SubqueryAlias: t\
\n Projection: column1 AS num, column2 AS letter\
\n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))";
\n Values: (Int64(1), Utf8(\"a\") AS $1), (Int64(2), Utf8(\"b\") AS $2)";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}
Expand Down

0 comments on commit b0292a2

Please sign in to comment.