From 991eb99913df69b18529fa59f6427a82941b8725 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Tue, 28 Nov 2023 16:22:52 -0800 Subject: [PATCH] [CHORE] Favor traversal over visitors (#1677) --- src/daft-dsl/src/optimization.rs | 65 +++++++++----------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index d0d7b524fd..1f99bd2931 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -1,31 +1,19 @@ use std::collections::HashMap; -use common_error::DaftResult; -use common_treenode::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, -}; +use common_treenode::{Transformed, TreeNode, VisitRecursion}; use super::expr::Expr; -struct RequiredColumnVisitor { - required: Vec, -} - -impl TreeNodeVisitor for RequiredColumnVisitor { - type N = Expr; - fn pre_visit(&mut self, node: &Self::N) -> DaftResult { - if let Expr::Column(name) = node { - self.required.push(name.as_ref().into()); - }; - Ok(VisitRecursion::Continue) - } -} - pub fn get_required_columns(e: &Expr) -> Vec { - let mut visitor = RequiredColumnVisitor { required: vec![] }; - e.visit(&mut visitor) - .expect("Error occurred when visiting for required columns"); - visitor.required + let mut cols = vec![]; + e.apply(&mut |expr| { + if let Expr::Column(name) = expr { + cols.push(name.as_ref().into()); + } + Ok(VisitRecursion::Continue) + }) + .expect("Error occurred when visiting for required columns"); + cols } pub fn requires_computation(e: &Expr) -> bool { @@ -43,33 +31,14 @@ pub fn requires_computation(e: &Expr) -> bool { } } -struct ColumnExpressionRewriter<'a> { - mapping: &'a HashMap, -} - -impl<'a> TreeNodeRewriter for ColumnExpressionRewriter<'a> { - type N = Expr; - fn pre_visit(&mut self, node: &Self::N) -> DaftResult { - if let Expr::Column(name) = node && self.mapping.contains_key(name.as_ref()) { - Ok(RewriteRecursion::Continue) - } else { - Ok(RewriteRecursion::Skip) - } - } - fn mutate(&mut self, node: Self::N) -> DaftResult { - if let Expr::Column(ref name) = node && let Some(tgt) = self.mapping.get(name.as_ref()){ - Ok(tgt.clone()) - } else { - Ok(node) - } - } -} - pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap) -> Expr { - let mut column_rewriter = ColumnExpressionRewriter { - mapping: replace_map, - }; expr.clone() - .rewrite(&mut column_rewriter) + .transform(&|e| { + if let Expr::Column(ref name) = e && let Some(tgt) = replace_map.get(name.as_ref()) { + Ok(Transformed::Yes(tgt.clone())) + } else { + Ok(Transformed::No(e)) + } + }) .expect("Error occurred when rewriting column expressions") }