Skip to content

Commit

Permalink
[CHORE] Bring in TreeNode and Refactor Expression Traversal to use Tr…
Browse files Browse the repository at this point in the history
…eeNode (#1676)

* Bring in `TreeNode` / `TreeNodeVisitor` / `TreeNodeRewriter` from
Datafusion
* Implement `RequiredColumns` and `ColumnExpressionRewriter`
  • Loading branch information
samster25 authored Nov 28, 2023
1 parent c3080c9 commit 3f07a85
Show file tree
Hide file tree
Showing 11 changed files with 485 additions and 186 deletions.
9 changes: 9 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ tikv-jemallocator = {version = "0.5.4", features = ["disable_initial_exec_tls"]}
members = [
"src/common/error",
"src/common/io-config",
"src/common/treenode",
"src/daft-core",
"src/daft-io",
"src/daft-parquet",
Expand Down
1 change: 0 additions & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,6 @@ class PyExpr:
def _input_mapping(self) -> str | None: ...
def _required_columns(self) -> set[str]: ...
def _is_column(self) -> bool: ...
def _replace_column_with_expression(self, column: str, new_expr: PyExpr) -> PyExpr: ...
def alias(self, name: str) -> PyExpr: ...
def cast(self, dtype: PyDataType) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
Expand Down
3 changes: 0 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,6 @@ def _required_columns(self) -> set[builtins.str]:
def _is_column(self) -> bool:
return self._expr._is_column()

def _replace_column_with_expression(self, column: builtins.str, new_expr: Expression) -> Expression:
return Expression._from_pyexpr(self._expr._replace_column_with_expression(column, new_expr._expr))


SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace")

Expand Down
12 changes: 12 additions & 0 deletions src/common/treenode/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[dependencies]
common-error = {path = "../error", default-features = false}
pyo3 = {workspace = true, optional = true}

[features]
default = ["python"]
python = ["common-error/python"]

[package]
edition = {workspace = true}
name = "common-treenode"
version = {workspace = true}
315 changes: 315 additions & 0 deletions src/common/treenode/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
// This module is borrowed from Datafusion under the Apache 2.0 license
// https://github.com/apache/arrow-datafusion/blob/2a692446f46ef96f48eb9ba19231e9576be9ff5a/datafusion/common/src/tree_node.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides common traits for visiting or rewriting tree
//! data structures easily.
use common_error::DaftResult;

/// Module-local shorthand for [`DaftResult`]; every fallible signature in this module returns it.
type Result<T> = DaftResult<T>;

/// A tree node that can be walked by reference and rebuilt by value.
///
/// Implementors supply only [`TreeNode::apply_children`] and
/// [`TreeNode::map_children`]; every other method is a provided traversal
/// built on top of those two.
pub trait TreeNode: Sized {
    /// Preorder walk: calls `op` on this node first, then on its children.
    ///
    /// The value returned by `op` steers the walk:
    /// * [`VisitRecursion::Continue`] - descend into this node's children.
    /// * [`VisitRecursion::Skip`] - do not descend; report `Continue` to the
    ///   caller so the rest of the walk can go on.
    /// * [`VisitRecursion::Stop`] - do not descend; report `Stop` to the caller.
    ///
    /// An `Err` returned by `op` is propagated immediately.
    fn apply<F>(&self, op: &mut F) -> Result<VisitRecursion>
    where
        F: FnMut(&Self) -> Result<VisitRecursion>,
    {
        match op(self)? {
            VisitRecursion::Continue => self.apply_children(&mut |child| child.apply(op)),
            VisitRecursion::Skip => Ok(VisitRecursion::Continue),
            VisitRecursion::Stop => Ok(VisitRecursion::Stop),
        }
    }

    /// Depth-first walk driven by a [`TreeNodeVisitor`].
    ///
    /// For a tree such as
    /// ```text
    /// ParentNode
    ///    left: ChildNode1
    ///    right: ChildNode2
    /// ```
    ///
    /// the callbacks fire in this order:
    /// ```text
    /// pre_visit(ParentNode)
    /// pre_visit(ChildNode1)
    /// post_visit(ChildNode1)
    /// pre_visit(ChildNode2)
    /// post_visit(ChildNode2)
    /// post_visit(ParentNode)
    /// ```
    ///
    /// * An `Err` from either callback ends the walk immediately.
    /// * `pre_visit` returning [`VisitRecursion::Skip`] or
    ///   [`VisitRecursion::Stop`] means neither the children nor `post_visit`
    ///   are run for that node; `Skip` reports `Continue` to the caller while
    ///   `Stop` reports `Stop`.
    /// * The same mapping is applied to the outcome of visiting the children:
    ///   anything other than `Continue` bypasses `post_visit` for this node.
    ///
    /// If the visitor keeps the default [`TreeNodeVisitor::post_visit`], which
    /// does nothing, prefer [`Self::apply`].
    fn visit<V: TreeNodeVisitor<N = Self>>(&self, visitor: &mut V) -> Result<VisitRecursion> {
        // Phase 1: the node itself decides whether we descend at all.
        match visitor.pre_visit(self)? {
            VisitRecursion::Continue => {}
            VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
            VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
        }

        // Phase 2: descend, then run `post_visit` only if the children all continued.
        let children_outcome = self.apply_children(&mut |child| child.visit(visitor))?;
        match children_outcome {
            VisitRecursion::Continue => visitor.post_visit(self),
            VisitRecursion::Skip => Ok(VisitRecursion::Continue),
            VisitRecursion::Stop => Ok(VisitRecursion::Stop),
        }
    }

    /// Recursively applies `op` to the whole tree; shorthand for
    /// [`Self::transform_up`] (postorder).
    ///
    /// Nodes for which `op` reports [`Transformed::No`] are kept as returned.
    fn transform<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Transformed<Self>>,
    {
        self.transform_up(op)
    }

    /// Preorder rewrite: applies `op` to this node, then recurses into the
    /// children of whatever node `op` produced.
    fn transform_down<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Transformed<Self>>,
    {
        op(self)?
            .into()
            .map_children(|child| child.transform_down(op))
    }

    /// Preorder rewrite like [`Self::transform_down`], but accepting a
    /// mutable (`FnMut`) operation.
    fn transform_down_mut<F>(self, op: &mut F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Transformed<Self>>,
    {
        op(self)?
            .into()
            .map_children(|child| child.transform_down_mut(op))
    }

    /// Postorder rewrite: rewrites all children first, then applies `op` to
    /// the node rebuilt from them.
    fn transform_up<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Transformed<Self>>,
    {
        let children_done = self.map_children(|child| child.transform_up(op))?;
        Ok(op(children_done)?.into())
    }

    /// Postorder rewrite like [`Self::transform_up`], but accepting a
    /// mutable (`FnMut`) operation.
    fn transform_up_mut<F>(self, op: &mut F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Transformed<Self>>,
    {
        let children_done = self.map_children(|child| child.transform_up_mut(op))?;
        Ok(op(children_done)?.into())
    }

    /// Depth-first rewrite driven by a [`TreeNodeRewriter`].
    ///
    /// For a tree such as
    /// ```text
    /// ParentNode
    ///    left: ChildNode1
    ///    right: ChildNode2
    /// ```
    ///
    /// and a rewriter whose `pre_visit` always answers
    /// [`RewriteRecursion::Continue`], the callbacks fire in this order:
    /// ```text
    /// pre_visit(ParentNode)
    /// pre_visit(ChildNode1)
    /// mutate(ChildNode1)
    /// pre_visit(ChildNode2)
    /// mutate(ChildNode2)
    /// mutate(ParentNode)
    /// ```
    ///
    /// The answer of `pre_visit` selects what happens to each node:
    /// * [`RewriteRecursion::Continue`] - rewrite the children, then `mutate` the node.
    /// * [`RewriteRecursion::Skip`] - rewrite the children, but do not `mutate` the node.
    /// * [`RewriteRecursion::Mutate`] - `mutate` the node right away without
    ///   recursing into its children.
    /// * [`RewriteRecursion::Stop`] - return the node untouched.
    ///
    /// An `Err` from any callback ends the rewrite immediately.
    ///
    /// If the rewriter keeps the default [`TreeNodeRewriter::pre_visit`], which
    /// always answers `Continue`, prefer [`Self::transform`].
    fn rewrite<R: TreeNodeRewriter<N = Self>>(self, rewriter: &mut R) -> Result<Self> {
        match rewriter.pre_visit(&self)? {
            RewriteRecursion::Mutate => rewriter.mutate(self),
            RewriteRecursion::Stop => Ok(self),
            RewriteRecursion::Continue => {
                let rebuilt = self.map_children(|child| child.rewrite(rewriter))?;
                rewriter.mutate(rebuilt)
            }
            RewriteRecursion::Skip => self.map_children(|child| child.rewrite(rewriter)),
        }
    }

    /// Calls `op` on this node's children (required method).
    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
    where
        F: FnMut(&Self) -> Result<VisitRecursion>;

    /// Rebuilds this node with `transform` applied to its children (required
    /// method). The caller decides the traversal direction (preorder or
    /// postorder) through the closure it passes in.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>;
}

/// Implements the [visitor
/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s.
///
/// [`TreeNodeVisitor`] keeps an algorithm separate from the code that
/// traverses the structure of the `TreeNode` tree, which makes it easier
/// to add new types of tree node and new algorithms.
///
/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`]
/// and [`TreeNodeVisitor::post_visit`] are invoked recursively
/// on a node tree.
///
/// If an [`Err`] result is returned, recursion is stopped
/// immediately.
///
/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that tree node are visited, nor is post_visit
/// called on that tree node.
///
/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
/// siblings of that tree node are visited, nor is post_visit
/// called on its parent tree node.
/// NOTE(review): this relies on the node's `TreeNode::apply_children`
/// propagating `Stop` to its caller; that implementation is not visible here.
///
/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
/// children of that tree node are visited and post_visit is not called
/// on that tree node, but the walk continues with the rest of the tree.
pub trait TreeNodeVisitor: Sized {
/// The node type which is visitable.
type N: TreeNode;

/// Invoked before any children of `node` are visited. The returned
/// [`VisitRecursion`] decides whether the children are visited at all.
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;

/// Invoked after all children of `node` are visited. The default
/// implementation does nothing and returns [`VisitRecursion::Continue`].
fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
Ok(VisitRecursion::Continue)
}
}

/// Trait for recursively transforming a [`TreeNode`] tree.
///
/// When passed to [`TreeNode::rewrite`], [`TreeNodeRewriter::pre_visit`] is
/// consulted for each node reached, and its [`RewriteRecursion`] answer decides
/// whether [`TreeNodeRewriter::mutate`] is invoked on that node and whether its
/// children are rewritten.
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
type N: TreeNode;

/// Invoked before (Preorder) any children of `node` are rewritten /
/// visited. The default implementation returns `Ok(RewriteRecursion::Continue)`,
/// i.e. rewrite the children and then mutate the node.
fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}

/// Returns a potentially modified node.
///
/// With [`RewriteRecursion::Continue`] this is invoked after (Postorder) all
/// children of `node` have been rewritten; with [`RewriteRecursion::Mutate`]
/// it is invoked immediately, without the children being rewritten first.
fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
}

/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`].
///
/// Returned by [`TreeNodeRewriter::pre_visit`] for each node before its
/// children are considered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteRecursion {
    /// Rewrite the children of this node, then call `mutate` on the node itself.
    Continue,
    /// Call `mutate` on this node immediately and return, without recursing
    /// into its children.
    Mutate,
    /// Return this node as-is: neither the node nor its children are rewritten.
    Stop,
    /// Keep recursing into the children, but do not call `mutate` on this node.
    Skip,
}

/// Controls how the [`TreeNode`] recursion should proceed for
/// [`TreeNode::apply`] and [`TreeNode::visit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VisitRecursion {
    /// Continue the visit into the children of this node.
    Continue,
    /// Do not visit the children of this node, but keep walking the rest of
    /// the tree.
    Skip,
    /// Stop the visit: the children of this node are not visited and `Stop`
    /// is reported to the caller.
    Stop,
}

/// Result of applying a rewrite operation to a tree node: the (possibly new)
/// node, tagged with whether the operation actually changed it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Transformed<T> {
    /// The item was transformed / rewritten somehow
    Yes(T),
    /// The item was not transformed
    No(T),
}

impl<T> Transformed<T> {
    /// Unwraps the inner item, discarding the "was it transformed" tag.
    ///
    /// This is an inherent method, so `x.into()` on a `Transformed<T>`
    /// resolves to it rather than to [`Into::into`].
    pub fn into(self) -> T {
        match self {
            Transformed::Yes(item) | Transformed::No(item) => item,
        }
    }

    /// Unwraps the inner item together with a flag that is `true` for
    /// [`Transformed::Yes`] and `false` for [`Transformed::No`].
    pub fn into_pair(self) -> (T, bool) {
        match self {
            Transformed::Yes(item) => (item, true),
            Transformed::No(item) => (item, false),
        }
    }
}
3 changes: 2 additions & 1 deletion src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
bincode = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-io-config = {path = "../common/io-config", default-features = false}
common-treenode = {path = "../common/treenode", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
indexmap = {workspace = true}
Expand All @@ -12,7 +13,7 @@ serde_json = {workspace = true}

[features]
default = ["python"]
python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python", "daft-io/python", "common-io-config/python"]
python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python", "daft-io/python", "common-io-config/python", "common-treenode/python"]

[package]
edition = {workspace = true}
Expand Down
Loading

0 comments on commit 3f07a85

Please sign in to comment.