Skip to content

Commit

Permalink
Add Container trait to simplify Expr and LogicalPlan apply …
Browse files Browse the repository at this point in the history
…and map methods (#13467)

* Add `Container` trait and its blanket implementations, remove `map_until_stop_and_collect` macro, simplify apply and map logic with `Container`s where possible

* fix clippy

* rename `Container` to `TreeNodeContainer`

* add docs to containers

* clarify when we need a temporary `TreeNodeRefContainer`

* code and docs cleanup
  • Loading branch information
peter-toth authored Nov 20, 2024
1 parent 30ff48e commit aef232b
Show file tree
Hide file tree
Showing 9 changed files with 687 additions and 580 deletions.
363 changes: 313 additions & 50 deletions datafusion/common/src/tree_node.rs

Large diffs are not rendered by default.

36 changes: 35 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::HashNode;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::{
plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference,
Expand Down Expand Up @@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr {
}
}

/// A lone [`Expr`] acts as a container whose only element is the expression itself.
impl<'a> TreeNodeContainer<'a, Self> for Expr {
    /// Calls `f` exactly once with a reference to this expression and
    /// returns whatever `f` returns.
    fn apply_elements<F>(&'a self, mut f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Self) -> Result<TreeNodeRecursion>,
    {
        f(self)
    }

    /// Consumes this expression, passes it to `f` exactly once and returns
    /// the result of that call unchanged.
    fn map_elements<F>(self, mut f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Self) -> Result<Transformed<Self>>,
    {
        f(self)
    }
}

/// UNNEST expression.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Unnest {
Expand Down Expand Up @@ -653,6 +669,24 @@ impl Display for Sort {
}
}

/// A [`Sort`] exposes the elements of its `expr` field as its container elements.
impl<'a> TreeNodeContainer<'a, Expr> for Sort {
    /// Forwards `f` to the `expr` field and returns the recursion state it reports.
    fn apply_elements<F>(&'a self, f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>,
    {
        self.expr.apply_elements(f)
    }

    /// Rewrites the `expr` field with `f` and rebuilds the `Sort` around the
    /// rewritten expression. Any error from the rewrite is propagated.
    fn map_elements<F>(self, f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Expr) -> Result<Transformed<Expr>>,
    {
        let rewritten = self.expr.map_elements(f)?;
        // `self.expr` was moved out above; struct-update syntax carries the
        // remaining fields over unchanged.
        Ok(rewritten.update_data(|expr| Self { expr, ..self }))
    }
}

/// Aggregate function
///
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
Expand Down
50 changes: 49 additions & 1 deletion datafusion/expr/src/logical_plan/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ use std::{

use crate::expr::Sort;
use arrow::datatypes::DataType;
use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference};
use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion};
use datafusion_common::{
Constraints, DFSchemaRef, Result, SchemaReference, TableReference,
};
use sqlparser::ast::Ident;

/// Various types of DDL (CREATE / DROP) catalog manipulation
Expand Down Expand Up @@ -487,6 +490,28 @@ pub struct OperateFunctionArg {
pub data_type: DataType,
pub default_expr: Option<Expr>,
}

/// An [`OperateFunctionArg`] exposes the elements of its optional
/// `default_expr` field as its container elements.
impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg {
    /// Forwards `f` to the `default_expr` field and returns the recursion
    /// state it reports.
    fn apply_elements<F>(&'a self, f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>,
    {
        self.default_expr.apply_elements(f)
    }

    /// Rewrites the `default_expr` field with `f` and rebuilds the argument
    /// around the result. Any error from the rewrite is propagated.
    fn map_elements<F>(self, f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Expr) -> Result<Transformed<Expr>>,
    {
        let rewritten = self.default_expr.map_elements(f)?;
        // `self.default_expr` was moved out above; the remaining fields are
        // carried over unchanged by the struct-update syntax.
        Ok(rewritten.update_data(|default_expr| Self {
            default_expr,
            ..self
        }))
    }
}

#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct CreateFunctionBody {
/// LANGUAGE lang_name
Expand All @@ -497,6 +522,29 @@ pub struct CreateFunctionBody {
pub function_body: Option<Expr>,
}

/// A [`CreateFunctionBody`] exposes the elements of its optional
/// `function_body` field as its container elements.
impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody {
    /// Forwards `f` to the `function_body` field and returns the recursion
    /// state it reports.
    fn apply_elements<F>(&'a self, f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>,
    {
        self.function_body.apply_elements(f)
    }

    /// Rewrites the `function_body` field with `f` and rebuilds the body
    /// around the result. Any error from the rewrite is propagated.
    fn map_elements<F>(self, f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Expr) -> Result<Transformed<Expr>>,
    {
        let rewritten = self.function_body.map_elements(f)?;
        // `self.function_body` was moved out above; the remaining fields are
        // carried over unchanged by the struct-update syntax.
        Ok(rewritten.update_data(|function_body| Self {
            function_body,
            ..self
        }))
    }
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct DropFunction {
pub name: String,
Expand Down
20 changes: 19 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ use crate::{
};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
Expand Down Expand Up @@ -287,6 +289,22 @@ impl Default for LogicalPlan {
}
}

/// A lone [`LogicalPlan`] acts as a container whose only element is the plan itself.
impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan {
    /// Calls `f` exactly once with a reference to this plan and returns
    /// whatever `f` returns.
    fn apply_elements<F>(&'a self, mut f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Self) -> Result<TreeNodeRecursion>,
    {
        f(self)
    }

    /// Consumes this plan, passes it to `f` exactly once and returns the
    /// result of that call unchanged.
    fn map_elements<F>(self, mut f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Self) -> Result<Transformed<Self>>,
    {
        f(self)
    }
}

impl LogicalPlan {
/// Get a reference to the logical plan's schema
pub fn schema(&self) -> &DFSchemaRef {
Expand Down
51 changes: 1 addition & 50 deletions datafusion/expr/src/logical_plan/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TreeNodeIterator};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_common::{DFSchema, DFSchemaRef};
use std::fmt::{self, Display};
use std::sync::{Arc, OnceLock};

use super::tree_node::rewrite_arc;
use crate::{expr_vec_fmt, Expr, LogicalPlan};

/// Statements have an unchanging empty schema.
Expand Down Expand Up @@ -80,53 +78,6 @@ impl Statement {
}
}

/// Rewrites the input `LogicalPlan`s of this `Statement` using `f`.
///
/// For `Statement::Prepare` the `input` is passed through `rewrite_arc`
/// together with `f`, and the statement is rebuilt around the rewritten input
/// with `name` and `data_types` left as they were. Every other variant is
/// returned as-is, wrapped in `Transformed::no`.
pub(super) fn map_inputs<F>(self, f: F) -> Result<Transformed<Self>>
where
    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
{
    match self {
        Statement::Prepare(Prepare {
            input,
            name,
            data_types,
        }) => {
            let rewritten = rewrite_arc(input, f)?;
            Ok(rewritten.update_data(|input| {
                Statement::Prepare(Prepare {
                    input,
                    name,
                    data_types,
                })
            }))
        }
        other => Ok(Transformed::no(other)),
    }
}

/// Returns an iterator over all expressions in the current `Statement`.
///
/// `Statement::Execute` yields its `parameters`; every other variant yields
/// an empty iterator.
pub(super) fn expression_iter(&self) -> impl Iterator<Item = &Expr> {
    let exprs: &[Expr] = match self {
        Statement::Execute(Execute { parameters, .. }) => parameters.as_slice(),
        _ => &[],
    };
    exprs.iter()
}

/// Rewrites every expression held by this `Statement` using `f`.
///
/// For `Statement::Execute` the `parameters` are fed through
/// `map_until_stop_and_collect` with `f`, and the statement is rebuilt from
/// the collected result with `name` left as it was. Every other variant is
/// returned as-is, wrapped in `Transformed::no`.
pub(super) fn map_expressions<F>(self, f: F) -> Result<Transformed<Self>>
where
    F: FnMut(Expr) -> Result<Transformed<Expr>>,
{
    match self {
        Statement::Execute(Execute { name, parameters }) => {
            let rewritten = parameters.into_iter().map_until_stop_and_collect(f)?;
            Ok(rewritten.update_data(|parameters| {
                Statement::Execute(Execute { name, parameters })
            }))
        }
        other => Ok(Transformed::no(other)),
    }
}

/// Return a `format`able structure with a human readable
/// description of this LogicalPlan node per node, not including
/// children.
Expand Down
Loading

0 comments on commit aef232b

Please sign in to comment.