From 3f07a859167c926365f3295e5c8758a2f2fc252a Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Tue, 28 Nov 2023 13:47:55 -0800 Subject: [PATCH] [CHORE] Bring in TreeNode and Refactor Expression Traversal to use TreeNode (#1676) * Bring in `TreeNode` / `TreeNodeVisitor` / `TreeNodeRewriter` from Datafusion * Implement `RequiredColumns` and `ColumnExpressionRewriter` --- Cargo.lock | 9 + Cargo.toml | 1 + daft/daft.pyi | 1 - daft/expressions/expressions.py | 3 - src/common/treenode/Cargo.toml | 12 ++ src/common/treenode/src/lib.rs | 315 +++++++++++++++++++++++++++++++ src/daft-dsl/Cargo.toml | 3 +- src/daft-dsl/src/lib.rs | 2 + src/daft-dsl/src/optimization.rs | 224 +++++----------------- src/daft-dsl/src/python.rs | 6 - src/daft-dsl/src/treenode.rs | 95 ++++++++++ 11 files changed, 485 insertions(+), 186 deletions(-) create mode 100644 src/common/treenode/Cargo.toml create mode 100644 src/common/treenode/src/lib.rs create mode 100644 src/daft-dsl/src/treenode.rs diff --git a/Cargo.lock b/Cargo.lock index 5c205b3e95..672fec9c08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -919,6 +919,14 @@ dependencies = [ "serde_json", ] +[[package]] +name = "common-treenode" +version = "0.1.10" +dependencies = [ + "common-error", + "pyo3", +] + [[package]] name = "concurrent-queue" version = "2.3.0" @@ -1191,6 +1199,7 @@ dependencies = [ "bincode", "common-error", "common-io-config", + "common-treenode", "daft-core", "daft-io", "indexmap 2.1.0", diff --git a/Cargo.toml b/Cargo.toml index 7823f3791f..7fdf514a0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ tikv-jemallocator = {version = "0.5.4", features = ["disable_initial_exec_tls"]} members = [ "src/common/error", "src/common/io-config", + "src/common/treenode", "src/daft-core", "src/daft-io", "src/daft-parquet", diff --git a/daft/daft.pyi b/daft/daft.pyi index 8960b54d25..3de0124309 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -689,7 +689,6 @@ class PyExpr: def _input_mapping(self) -> str | None: ... def _required_columns(self) -> set[str]: ... def _is_column(self) -> bool: ... - def _replace_column_with_expression(self, column: str, new_expr: PyExpr) -> PyExpr: ... def alias(self, name: str) -> PyExpr: ... def cast(self, dtype: PyDataType) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 39ffc46bf7..cb3769082b 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -398,9 +398,6 @@ def _required_columns(self) -> set[builtins.str]: def _is_column(self) -> bool: return self._expr._is_column() - def _replace_column_with_expression(self, column: builtins.str, new_expr: Expression) -> Expression: - return Expression._from_pyexpr(self._expr._replace_column_with_expression(column, new_expr._expr)) - SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace") diff --git a/src/common/treenode/Cargo.toml b/src/common/treenode/Cargo.toml new file mode 100644 index 0000000000..b8c07684cf --- /dev/null +++ b/src/common/treenode/Cargo.toml @@ -0,0 +1,12 @@ +[dependencies] +common-error = {path = "../error", default-features = false} +pyo3 = {workspace = true, optional = true} + +[features] +default = ["python"] +python = ["common-error/python"] + +[package] +edition = {workspace = true} +name = "common-treenode" +version = {workspace = true} diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs new file mode 100644 index 0000000000..e91ddddddd --- /dev/null +++ b/src/common/treenode/src/lib.rs @@ -0,0 +1,315 @@ +// This module is borrowed from Datafusion under the Apache 2.0 license +// https://github.com/apache/arrow-datafusion/blob/2a692446f46ef96f48eb9ba19231e9576be9ff5a/datafusion/common/src/tree_node.rs + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides common traits for visiting or rewriting tree +//! data structures easily. + +use common_error::DaftResult; + +type Result = DaftResult; + +/// Defines a visitable and rewriteable a tree node. +pub trait TreeNode: Sized { + /// Use preorder to iterate the node on the tree so that we can + /// stop fast for some cases. + /// + /// The `op` closure can be used to collect some info from the + /// tree node or do some checking for the tree node. + fn apply(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + match op(self)? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + }; + + self.apply_children(&mut |node| node.apply(op)) + } + + /// Visit the tree node using the given [TreeNodeVisitor] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// post_visit(ChildNode1) + /// pre_visit(ChildNode2) + /// post_visit(ChildNode2) + /// post_visit(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is post_visit + /// called on that node. Details see [`TreeNodeVisitor`] + /// + /// If using the default [`TreeNodeVisitor::post_visit`] that does + /// nothing, [`Self::apply`] should be preferred. + fn visit>(&self, visitor: &mut V) -> Result { + match visitor.pre_visit(self)? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + }; + + match self.apply_children(&mut |node| node.visit(visitor))? { + VisitRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + + visitor.post_visit(self) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. + /// When `op` does not apply to a given node, it is left unchanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + self.transform_up(op) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down_mut(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that node will be visited, nor is mutate + /// called on that node + /// + /// If using the default [`TreeNodeRewriter::pre_visit`] which + /// returns `true`, [`Self::transform`] should be preferred. + fn rewrite>(self, rewriter: &mut R) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + + /// Apply the closure `F` to the node's children + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result; + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result; +} + +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. +/// +/// [`TreeNodeVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `TreeNode` +/// tree and makes it easier to add new types of tree node and +/// algorithms. +/// +/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] +/// and [`TreeNodeVisitor::post_visit`] are invoked recursively +/// on an node tree. +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no +/// children of that tree node are visited, nor is post_visit +/// called on that tree node +/// +/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// siblings of that tree node are visited, nor is post_visit +/// called on its parent tree node +/// +/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// children of that tree node are visited. +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type N: TreeNode; + + /// Invoked before any children of `node` are visited. + fn pre_visit(&mut self, node: &Self::N) -> Result; + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(VisitRecursion::Continue) + } +} + +/// Trait for potentially recursively transform an [`TreeNode`] node +/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is +/// invoked recursively on all nodes of a tree. +pub trait TreeNodeRewriter: Sized { + /// The node type which is rewritable. + type N: TreeNode; + + /// Invoked before (Preorder) any children of `node` are rewritten / + /// visited. Default implementation returns `Ok(Recursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after (Postorder) all children of `node` have been mutated and + /// returns a potentially modified node. + fn mutate(&mut self, node: Self::N) -> Result; +} + +/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`]. +#[derive(Debug)] +pub enum RewriteRecursion { + /// Continue rewrite this node tree. + Continue, + /// Call 'op' immediately and return. + Mutate, + /// Do not rewrite the children of this node. + Stop, + /// Keep recursive but skip apply op on this node + Skip, +} + +/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. +#[derive(Debug)] +pub enum VisitRecursion { + /// Continue the visit to this node tree. + Continue, + /// Keep recursive but skip applying op on the children + Skip, + /// Stop the visit to this node tree. + Stop, +} + +pub enum Transformed { + /// The item was transformed / rewritten somehow + Yes(T), + /// The item was not transformed + No(T), +} + +impl Transformed { + pub fn into(self) -> T { + match self { + Transformed::Yes(t) => t, + Transformed::No(t) => t, + } + } + + pub fn into_pair(self) -> (T, bool) { + match self { + Transformed::Yes(t) => (t, true), + Transformed::No(t) => (t, false), + } + } +} diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index 8d7489338e..002719ca01 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -2,6 +2,7 @@ bincode = {workspace = true} common-error = {path = "../common/error", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-treenode = {path = "../common/treenode", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-io = {path = "../daft-io", default-features = false} indexmap = {workspace = true} @@ -12,7 +13,7 @@ serde_json = {workspace = true} [features] default = ["python"] -python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python", "daft-io/python", "common-io-config/python"] +python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python", "daft-io/python", "common-io-config/python", "common-treenode/python"] [package] edition = {workspace = true} diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index daac00e109..a38ab5d46e 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(let_chains)] mod arithmetic; mod expr; pub mod functions; @@ -7,6 +8,7 @@ pub mod optimization; mod pyobject; #[cfg(feature = "python")] pub mod python; +mod treenode; pub use expr::binary_op; pub use expr::col; pub use expr::{AggExpr, Expr, ExprRef, Operator}; diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index ff1d808382..d0d7b524fd 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -1,48 +1,33 @@ use std::collections::HashMap; -use super::expr::{AggExpr, Expr}; +use common_error::DaftResult; +use common_treenode::{ + RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, +}; -pub fn get_required_columns(e: &Expr) -> Vec { - // Returns all the column names required by this expression - match e { - Expr::Alias(child, _) => get_required_columns(child), - Expr::Agg(agg) => match agg { - AggExpr::Count(child, ..) - | AggExpr::Sum(child) - | AggExpr::Mean(child) - | AggExpr::Min(child) - | AggExpr::Max(child) - | AggExpr::List(child) - | AggExpr::Concat(child) => get_required_columns(child), - }, - Expr::BinaryOp { left, right, .. } => { - let mut req_cols = get_required_columns(left); - req_cols.extend(get_required_columns(right)); - req_cols - } - Expr::Cast(child, _) => get_required_columns(child), - Expr::Column(name) => vec![name.to_string()], - Expr::Function { inputs, .. } => { - let child_required_columns: Vec> = - inputs.iter().map(get_required_columns).collect(); - child_required_columns.concat() - } - Expr::Not(child) => get_required_columns(child), - Expr::IsNull(child) => get_required_columns(child), - Expr::Literal(_) => vec![], - Expr::IfElse { - if_true, - if_false, - predicate, - } => { - let mut req_cols = get_required_columns(if_true); - req_cols.extend(get_required_columns(if_false)); - req_cols.extend(get_required_columns(predicate)); - req_cols - } +use super::expr::Expr; + +struct RequiredColumnVisitor { + required: Vec, +} + +impl TreeNodeVisitor for RequiredColumnVisitor { + type N = Expr; + fn pre_visit(&mut self, node: &Self::N) -> DaftResult { + if let Expr::Column(name) = node { + self.required.push(name.as_ref().into()); + }; + Ok(VisitRecursion::Continue) } } +pub fn get_required_columns(e: &Expr) -> Vec { + let mut visitor = RequiredColumnVisitor { required: vec![] }; + e.visit(&mut visitor) + .expect("Error occurred when visiting for required columns"); + visitor.required +} + pub fn requires_computation(e: &Expr) -> bool { // Returns whether or not this expression runs any computation on the underlying data match e { @@ -58,144 +43,33 @@ pub fn requires_computation(e: &Expr) -> bool { } } -pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap) -> Expr { - // Constructs a new deep-copied Expr which is `expr` but with all occurrences of Column(column_name) recursively - // replaced with `new_expr` for all column_name -> new_expr mappings in replace_map. - match expr { - // BASE CASE: found a matching column - Expr::Column(name) => match replace_map.get(&name.to_string()) { - Some(replacement) => replacement.clone(), - None => expr.clone(), - }, - - // BASE CASE: reached non-matching leaf node - Expr::Literal(_) => expr.clone(), +struct ColumnExpressionRewriter<'a> { + mapping: &'a HashMap, +} - // RECURSIVE CASE: recursively replace for matching column - Expr::Alias(child, name) => Expr::Alias( - replace_columns_with_expressions(child, replace_map).into(), - (*name).clone(), - ), - Expr::Agg(agg) => match agg { - AggExpr::Count(child, mode) => Expr::Agg(AggExpr::Count( - replace_columns_with_expressions(child, replace_map).into(), - *mode, - )), - AggExpr::Sum(child) => Expr::Agg(AggExpr::Sum( - replace_columns_with_expressions(child, replace_map).into(), - )), - AggExpr::Mean(child) => Expr::Agg(AggExpr::Mean( - replace_columns_with_expressions(child, replace_map).into(), - )), - AggExpr::Min(child) => Expr::Agg(AggExpr::Min( - replace_columns_with_expressions(child, replace_map).into(), - )), - AggExpr::Max(child) => Expr::Agg(AggExpr::Max( - replace_columns_with_expressions(child, replace_map).into(), - )), - AggExpr::List(child) => Expr::Agg(AggExpr::List( - replace_columns_with_expressions(child, replace_map).into(), - )), - AggExpr::Concat(child) => Expr::Agg(AggExpr::List( - replace_columns_with_expressions(child, replace_map).into(), - )), - }, - Expr::BinaryOp { left, right, op } => Expr::BinaryOp { - op: *op, - left: replace_columns_with_expressions(left, replace_map).into(), - right: replace_columns_with_expressions(right, replace_map).into(), - }, - Expr::Cast(child, name) => Expr::Cast( - replace_columns_with_expressions(child, replace_map).into(), - (*name).clone(), - ), - Expr::Function { inputs, func } => Expr::Function { - func: func.clone(), - inputs: inputs - .iter() - .map(|e| replace_columns_with_expressions(e, replace_map)) - .collect(), - }, - Expr::Not(child) => replace_columns_with_expressions(child, replace_map), - Expr::IsNull(child) => replace_columns_with_expressions(child, replace_map), - Expr::IfElse { - if_true, - if_false, - predicate, - } => Expr::IfElse { - if_true: replace_columns_with_expressions(if_true, replace_map).into(), - if_false: replace_columns_with_expressions(if_false, replace_map).into(), - predicate: replace_columns_with_expressions(predicate, replace_map).into(), - }, +impl<'a> TreeNodeRewriter for ColumnExpressionRewriter<'a> { + type N = Expr; + fn pre_visit(&mut self, node: &Self::N) -> DaftResult { + if let Expr::Column(name) = node && self.mapping.contains_key(name.as_ref()) { + Ok(RewriteRecursion::Continue) + } else { + Ok(RewriteRecursion::Skip) + } + } + fn mutate(&mut self, node: Self::N) -> DaftResult { + if let Expr::Column(ref name) = node && let Some(tgt) = self.mapping.get(name.as_ref()){ + Ok(tgt.clone()) + } else { + Ok(node) + } } } -pub fn replace_column_with_expression(expr: &Expr, column_name: &str, new_expr: &Expr) -> Expr { - // Constructs a new deep-copied Expr which is `expr` but with all occurrences of Column(column_name) recursively - // replaced with `new_expr` - match expr { - // BASE CASE: found a matching column - Expr::Column(name) if name.as_ref().eq(column_name) => new_expr.clone(), - - // BASE CASE: reached non-matching leaf node - Expr::Column(..) => expr.clone(), - Expr::Literal(_) => expr.clone(), - - // RECURSIVE CASE: recursively replace for matching column - Expr::Alias(child, name) => Expr::Alias( - replace_column_with_expression(child, column_name, new_expr).into(), - (*name).clone(), - ), - Expr::Agg(agg) => match agg { - AggExpr::Count(child, mode) => Expr::Agg(AggExpr::Count( - replace_column_with_expression(child, column_name, new_expr).into(), - *mode, - )), - AggExpr::Sum(child) => Expr::Agg(AggExpr::Sum( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - AggExpr::Mean(child) => Expr::Agg(AggExpr::Mean( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - AggExpr::Min(child) => Expr::Agg(AggExpr::Min( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - AggExpr::Max(child) => Expr::Agg(AggExpr::Max( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - AggExpr::List(child) => Expr::Agg(AggExpr::List( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - AggExpr::Concat(child) => Expr::Agg(AggExpr::List( - replace_column_with_expression(child, column_name, new_expr).into(), - )), - }, - Expr::BinaryOp { left, right, op } => Expr::BinaryOp { - op: *op, - left: replace_column_with_expression(left, column_name, new_expr).into(), - right: replace_column_with_expression(right, column_name, new_expr).into(), - }, - Expr::Cast(child, name) => Expr::Cast( - replace_column_with_expression(child, column_name, new_expr).into(), - (*name).clone(), - ), - Expr::Function { inputs, func } => Expr::Function { - func: func.clone(), - inputs: inputs - .iter() - .map(|e| replace_column_with_expression(e, column_name, new_expr)) - .collect(), - }, - Expr::Not(child) => replace_column_with_expression(child, column_name, new_expr), - Expr::IsNull(child) => replace_column_with_expression(child, column_name, new_expr), - Expr::IfElse { - if_true, - if_false, - predicate, - } => Expr::IfElse { - if_true: replace_column_with_expression(if_true, column_name, new_expr).into(), - if_false: replace_column_with_expression(if_false, column_name, new_expr).into(), - predicate: replace_column_with_expression(predicate, column_name, new_expr).into(), - }, - } +pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap) -> Expr { + let mut column_rewriter = ColumnExpressionRewriter { + mapping: replace_map, + }; + expr.clone() + .rewrite(&mut column_rewriter) + .expect("Error occurred when rewriting column expressions") } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index eb88d0ee76..1a04d28c46 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -116,12 +116,6 @@ impl PyExpr { Ok(matches!(self.expr, Expr::Column(..))) } - pub fn _replace_column_with_expression(&self, column: &str, new_expr: &Self) -> PyResult { - Ok(PyExpr { - expr: optimization::replace_column_with_expression(&self.expr, column, &new_expr.expr), - }) - } - pub fn alias(&self, name: &str) -> PyResult { Ok(self.expr.alias(name).into()) } diff --git a/src/daft-dsl/src/treenode.rs b/src/daft-dsl/src/treenode.rs new file mode 100644 index 0000000000..7895ebf42d --- /dev/null +++ b/src/daft-dsl/src/treenode.rs @@ -0,0 +1,95 @@ +use common_error::DaftResult; +use common_treenode::{TreeNode, VisitRecursion}; + +use crate::Expr; + +impl TreeNode for Expr { + fn apply_children(&self, op: &mut F) -> DaftResult + where + F: FnMut(&Self) -> DaftResult, + { + use Expr::*; + let children = match self { + Alias(expr, _) | Cast(expr, _) | Not(expr) | IsNull(expr) => { + vec![expr.as_ref()] + } + Agg(agg_expr) => { + use crate::AggExpr::*; + match agg_expr { + Count(expr, ..) + | Sum(expr) + | Mean(expr) + | Min(expr) + | Max(expr) + | List(expr) + | Concat(expr) => vec![expr.as_ref()], + } + } + BinaryOp { op: _, left, right } => vec![left.as_ref(), right.as_ref()], + Column(_) | Literal(_) => vec![], + Function { func: _, inputs } => inputs.iter().collect::>(), + IfElse { + if_true, + if_false, + predicate, + } => vec![if_true.as_ref(), if_false.as_ref(), predicate.as_ref()], + }; + for child in children.into_iter() { + match op(child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children(self, transform: F) -> DaftResult + where + F: FnMut(Self) -> DaftResult, + { + let mut transform = transform; + + use Expr::*; + Ok(match self { + Alias(expr, name) => Alias(transform(expr.as_ref().clone())?.into(), name), + Column(_) | Literal(_) => self, + Cast(expr, dtype) => Cast(transform(expr.as_ref().clone())?.into(), dtype), + Agg(agg_expr) => { + use crate::AggExpr::*; + match agg_expr { + Count(expr, mode) => transform(expr.as_ref().clone())?.count(mode), + Sum(expr) => transform(expr.as_ref().clone())?.sum(), + Mean(expr) => transform(expr.as_ref().clone())?.mean(), + Min(expr) => transform(expr.as_ref().clone())?.min(), + Max(expr) => transform(expr.as_ref().clone())?.max(), + List(expr) => transform(expr.as_ref().clone())?.agg_list(), + Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(), + } + } + Not(expr) => Not(transform(expr.as_ref().clone())?.into()), + IsNull(expr) => IsNull(transform(expr.as_ref().clone())?.into()), + IfElse { + if_true, + if_false, + predicate, + } => Expr::IfElse { + if_true: transform(if_true.as_ref().clone())?.into(), + if_false: transform(if_false.as_ref().clone())?.into(), + predicate: transform(predicate.as_ref().clone())?.into(), + }, + BinaryOp { op, left, right } => Expr::BinaryOp { + op, + left: transform(left.as_ref().clone())?.into(), + right: transform(right.as_ref().clone())?.into(), + }, + Function { func, inputs } => Function { + func, + inputs: inputs + .into_iter() + .map(transform) + .collect::>>()?, + }, + }) + } +}