Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: preserve expression names when replacing placeholders #12126

Merged
merged 3 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ mod tests {
use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name};

use arrow::array::{self, Int32Array};
use datafusion_common::{Constraint, Constraints, ScalarValue};
use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
Expand Down Expand Up @@ -3699,4 +3699,32 @@ mod tests {
assert!(result.is_err());
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/12065
#[tokio::test]
async fn filtered_aggr_with_param_values() -> Result<()> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the main branch, this test failed with the following error:

Error: Context("type_coercion", SchemaError(FieldNotFound { field: 
    Column { relation: None, name: "count(table1.c2) FILTER (WHERE table1.c3 > $1)" }, 

valid_fields: [
    Column { relation: None, name: "count(table1.c2) FILTER (WHERE table1.c3 > UInt64(10))" }
] }, Some("")))

let cfg = SessionConfig::new()
.set("datafusion.sql_parser.dialect", "PostgreSQL".into());
let ctx = SessionContext::new_with_config(cfg);
register_aggregate_csv(&ctx, "table1").await?;

let df = ctx
.sql("select count (c2) filter (where c3 > $1) from table1")
.await?
.with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)]));

let df_results = df?.collect().await?;
assert_batches_eq!(
&[
"+------------------------------------------------+",
"| count(table1.c2) FILTER (WHERE table1.c3 > $1) |",
"+------------------------------------------------+",
"| 54 |",
"+------------------------------------------------+",
],
&df_results
);

Ok(())
}
}
51 changes: 51 additions & 0 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,57 @@ where
expr.alias_if_changed(original_name)
}

/// Handles ensuring the name of rewritten expressions is not changed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @findepi here is an example of using Expr::Alias as described in #1468 (comment)

///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
/// When `true`, `save` records the expression's name (via `name_for_alias`)
/// so that `SavedName::restore` can re-apply it as an alias. `new` sets this
/// to `false` for `LogicalPlan::Filter` and `LogicalPlan::Join`.
use_alias: bool,
}

/// The name of an expression as captured by `NamePreserver::save` before the
/// expression is rewritten.
///
/// Holds `Some(name)` when a name was recorded and `None` when the preserver
/// was not using aliases; `restore` re-applies a recorded name to the
/// rewritten expression.
pub struct SavedName(Option<String>);

impl NamePreserver {
    /// Builds a `NamePreserver` for rewriting the expressions that belong to `plan`.
    ///
    /// Names are preserved through aliases for every plan node except
    /// `LogicalPlan::Filter` and `LogicalPlan::Join`.
    pub fn new(plan: &LogicalPlan) -> Self {
        let use_alias = match plan {
            LogicalPlan::Filter(_) | LogicalPlan::Join(_) => false,
            _ => true,
        };
        Self { use_alias }
    }

    /// Builds a `NamePreserver` for rewriting the expressions of a `Projection`.
    ///
    /// Aliases are always used in this mode.
    pub fn new_for_projection() -> Self {
        Self { use_alias: true }
    }

    /// Captures the current name of `expr` so it can be restored after a rewrite.
    ///
    /// When aliasing is disabled the returned `SavedName` is empty and
    /// `expr` is not inspected.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Expr::name_for_alias`.
    pub fn save(&self, expr: &Expr) -> Result<SavedName> {
        if !self.use_alias {
            return Ok(SavedName(None));
        }

        let name = expr.name_for_alias()?;
        Ok(SavedName(Some(name)))
    }
}

impl SavedName {
    /// Re-applies the saved name to the rewritten `expr`.
    ///
    /// If a name was saved, the result is `expr.alias_if_changed(name)`;
    /// if nothing was saved, `expr` is returned untouched.
    pub fn restore(self, expr: Expr) -> Result<Expr> {
        if let Some(name) = self.0 {
            expr.alias_if_changed(name)
        } else {
            Ok(expr)
        }
    }
}

#[cfg(test)]
mod test {
use std::ops::Add;
Expand Down
23 changes: 14 additions & 9 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use super::dml::CopyTo;
use super::DdlStatement;
use crate::builder::{change_redundant_column, unnest_with_options};
use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction};
use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols};
use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver};
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
Expand Down Expand Up @@ -1339,15 +1339,20 @@ impl LogicalPlan {
) -> Result<LogicalPlan> {
self.transform_up_with_subqueries(|plan| {
let schema = Arc::clone(plan.schema());
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
e.infer_placeholder_types(&schema)?.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
let value = param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::Literal(value)))
} else {
Ok(Transformed::no(e))
}
})
let original_name = name_preserver.save(&e)?;
let transformed_expr =
e.infer_placeholder_types(&schema)?.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
let value = param_values.get_placeholders_with_values(&id)?;
Ok(Transformed::yes(Expr::Literal(value)))
} else {
Ok(Transformed::no(e))
}
})?;
// Preserve name to avoid breaking column references to this expression
transformed_expr.map_data(|expr| original_name.restore(expr))
})
})
.map(|res| res.data)
Expand Down
55 changes: 4 additions & 51 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator};

use log::{debug, trace};

/// Re-export of `NamePreserver` for backwards compatibility,
/// as it was initially placed here and then moved elsewhere.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

pub use datafusion_expr::expr_rewriter::NamePreserver;

/// Convenience rule for writing optimizers: recursively invoke
/// optimize on plan's children and then return a node of the same
/// type. Useful for optimizer rules which want to leave the type
Expand Down Expand Up @@ -294,54 +298,3 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
expr_utils::merge_schema(inputs)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
/// When `true`, `save` records the expression's name (via `name_for_alias`)
/// so that `SavedName::restore` can re-apply it as an alias. `new` sets this
/// to `false` for `LogicalPlan::Filter` and `LogicalPlan::Join`.
use_alias: bool,
}

/// The name of an expression as captured by `NamePreserver::save` before the
/// expression is rewritten.
///
/// Holds `Some(name)` when a name was recorded and `None` when the preserver
/// was not using aliases; `restore` re-applies a recorded name to the
/// rewritten expression.
pub struct SavedName(Option<String>);

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
}
}

/// Create a new NamePreserver for rewriting the `expr`s in `Projection`
///
/// This will use aliases
pub fn new_for_projection() -> Self {
Self { use_alias: true }
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
} else {
None
};

Ok(SavedName(original_name))
}
}

impl SavedName {
    /// Restores the remembered name on the rewritten `expr`.
    ///
    /// With no remembered name, `expr` passes through unchanged; otherwise
    /// the result is `expr.alias_if_changed(name)`.
    pub fn restore(self, expr: Expr) -> Result<Expr> {
        match self.0 {
            None => Ok(expr),
            Some(saved) => expr.alias_if_changed(saved),
        }
    }
}
15 changes: 9 additions & 6 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3813,7 +3813,7 @@ fn test_prepare_statement_to_plan_params_as_constants() {
///////////////////
// replace params with values
let param_values = vec![ScalarValue::Int32(Some(10))];
let expected_plan = "Projection: Int32(10)\n EmptyRelation";
let expected_plan = "Projection: Int32(10) AS $1\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);

Expand All @@ -3829,7 +3829,8 @@ fn test_prepare_statement_to_plan_params_as_constants() {
///////////////////
// replace params with values
let param_values = vec![ScalarValue::Int32(Some(10))];
let expected_plan = "Projection: Int64(1) + Int32(10)\n EmptyRelation";
let expected_plan =
"Projection: Int64(1) + Int32(10) AS Int64(1) + $1\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);

Expand All @@ -3848,7 +3849,9 @@ fn test_prepare_statement_to_plan_params_as_constants() {
ScalarValue::Int32(Some(10)),
ScalarValue::Float64(Some(10.0)),
];
let expected_plan = "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation";
let expected_plan =
"Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2\
\n EmptyRelation";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}
Expand Down Expand Up @@ -4063,7 +4066,7 @@ fn test_prepare_statement_insert_infer() {
\n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \
CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \
CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\
\n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))";
\n Values: (UInt32(1) AS $1, Utf8(\"Alan\") AS $2, Utf8(\"Turing\") AS $3)";
let plan = plan.replace_params_with_values(&param_values).unwrap();

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
Expand Down Expand Up @@ -4144,7 +4147,7 @@ fn test_prepare_statement_to_plan_multi_params() {
ScalarValue::from("xyz"),
];
let expected_plan =
"Projection: person.id, person.age, Utf8(\"xyz\")\
"Projection: person.id, person.age, Utf8(\"xyz\") AS $6\
\n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\
\n TableScan: person";

Expand Down Expand Up @@ -4213,7 +4216,7 @@ fn test_prepare_statement_to_plan_value_list() {
let expected_plan = "Projection: *\
\n SubqueryAlias: t\
\n Projection: column1 AS num, column2 AS letter\
\n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))";
\n Values: (Int64(1), Utf8(\"a\") AS $1), (Int64(2), Utf8(\"b\") AS $2)";

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}
Expand Down