Skip to content

Commit

Permalink
[CHORE] Favor traversal over visitors (#1677)
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 authored Nov 29, 2023
1 parent 3f07a85 commit 991eb99
Showing 1 changed file with 17 additions and 48 deletions.
65 changes: 17 additions & 48 deletions src/daft-dsl/src/optimization.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
use std::collections::HashMap;

use common_error::DaftResult;
use common_treenode::{
RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion,
};
use common_treenode::{Transformed, TreeNode, VisitRecursion};

use super::expr::Expr;

struct RequiredColumnVisitor {
required: Vec<String>,
}

impl TreeNodeVisitor for RequiredColumnVisitor {
type N = Expr;
fn pre_visit(&mut self, node: &Self::N) -> DaftResult<VisitRecursion> {
if let Expr::Column(name) = node {
self.required.push(name.as_ref().into());
};
Ok(VisitRecursion::Continue)
}
}

pub fn get_required_columns(e: &Expr) -> Vec<String> {
let mut visitor = RequiredColumnVisitor { required: vec![] };
e.visit(&mut visitor)
.expect("Error occurred when visiting for required columns");
visitor.required
let mut cols = vec![];
e.apply(&mut |expr| {
if let Expr::Column(name) = expr {
cols.push(name.as_ref().into());
}
Ok(VisitRecursion::Continue)
})
.expect("Error occurred when visiting for required columns");
cols
}

pub fn requires_computation(e: &Expr) -> bool {
Expand All @@ -43,33 +31,14 @@ pub fn requires_computation(e: &Expr) -> bool {
}
}

struct ColumnExpressionRewriter<'a> {
mapping: &'a HashMap<String, Expr>,
}

impl<'a> TreeNodeRewriter for ColumnExpressionRewriter<'a> {
type N = Expr;
fn pre_visit(&mut self, node: &Self::N) -> DaftResult<RewriteRecursion> {
if let Expr::Column(name) = node && self.mapping.contains_key(name.as_ref()) {
Ok(RewriteRecursion::Continue)
} else {
Ok(RewriteRecursion::Skip)
}
}
fn mutate(&mut self, node: Self::N) -> DaftResult<Self::N> {
if let Expr::Column(ref name) = node && let Some(tgt) = self.mapping.get(name.as_ref()){
Ok(tgt.clone())
} else {
Ok(node)
}
}
}

pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap<String, Expr>) -> Expr {
let mut column_rewriter = ColumnExpressionRewriter {
mapping: replace_map,
};
expr.clone()
.rewrite(&mut column_rewriter)
.transform(&|e| {
if let Expr::Column(ref name) = e && let Some(tgt) = replace_map.get(name.as_ref()) {
Ok(Transformed::Yes(tgt.clone()))
} else {
Ok(Transformed::No(e))
}
})
.expect("Error occurred when rewriting column expressions")
}

0 comments on commit 991eb99

Please sign in to comment.