Skip to content

Commit

Permalink
wip part filter translation
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Dec 1, 2023
1 parent 33bdfcb commit 020be51
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 46 deletions.
29 changes: 29 additions & 0 deletions src/daft-dsl/src/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::collections::HashMap;

use common_treenode::{Transformed, TreeNode, VisitRecursion};

use crate::Operator;

use super::expr::Expr;

pub fn get_required_columns(e: &Expr) -> Vec<String> {
Expand Down Expand Up @@ -42,3 +44,30 @@ pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap<Strin
})
.expect("Error occurred when rewriting column expressions")
}

/// Splits a predicate into its top-level AND-ed conjuncts.
///
/// For example, `a AND (b AND c)` yields `[a, b, c]`; a predicate with no
/// top-level `AND` is returned as a single-element vector. Aliases wrapping
/// an `AND` are looked through (see `_split_conjuction`).
pub fn split_conjuction(expr: &Expr) -> Vec<&Expr> {
    let mut conjuncts = Vec::new();
    _split_conjuction(expr, &mut conjuncts);
    conjuncts
}

/// Recursive worker for [`split_conjuction`]: walks the expression tree,
/// descending through `AND` nodes and aliases, and pushes every leaf
/// conjunct into `out_exprs` in left-to-right order.
fn _split_conjuction<'a>(expr: &'a Expr, out_exprs: &mut Vec<&'a Expr>) {
    if let Expr::BinaryOp {
        op: Operator::And,
        left,
        right,
    } = expr
    {
        // An AND node is not itself a conjunct; recurse into both sides.
        _split_conjuction(left, out_exprs);
        _split_conjuction(right, out_exprs);
    } else if let Expr::Alias(aliased, ..) = expr {
        // Look through aliases so `(a AND b).alias("x")` still splits.
        _split_conjuction(aliased, out_exprs);
    } else {
        // Any other expression is a leaf conjunct.
        out_exprs.push(expr);
    }
}

/// Combines a sequence of expressions into a single predicate by AND-ing
/// them together, e.g. `[a, b, c]` becomes `a AND b AND c`.
///
/// Returns `None` when the input is empty (there is no neutral `Expr` to
/// return), and the expression itself for a single-element input.
///
/// Generalized from `Vec<Expr>` to any `IntoIterator<Item = Expr>` so
/// callers can pass iterators without collecting first; `conjuct(vec)`
/// still compiles unchanged.
///
/// NOTE(review): the name keeps the existing (misspelled) public
/// identifier `conjuct` for compatibility with current callers.
pub fn conjuct<T: IntoIterator<Item = Expr>>(exprs: T) -> Option<Expr> {
    exprs.into_iter().reduce(|acc, expr| acc.and(&expr))
}
1 change: 0 additions & 1 deletion src/daft-plan/src/optimization/rules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod push_down_filter;
mod push_down_limit;
mod push_down_projection;
mod rule;
mod utils;

pub use drop_repartition::DropRepartition;
pub use push_down_filter::PushDownFilter;
Expand Down
10 changes: 7 additions & 3 deletions src/daft-plan/src/optimization/rules/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use common_error::DaftResult;
use daft_dsl::{
col,
optimization::{get_required_columns, replace_columns_with_expressions},
optimization::{get_required_columns, replace_columns_with_expressions, conjuct, split_conjuction},
Expr,
};
use daft_scan::ScanExternalInfo;
Expand All @@ -18,7 +18,6 @@ use crate::{
};

use super::{
utils::{conjuct, split_conjuction},
ApplyOrder, OptimizerRule, Transformed,
};

Expand Down Expand Up @@ -55,7 +54,12 @@ impl OptimizerRule for PushDownFilter {
let pred = &filter.predicate;

let new_filters = if let Some(filters) = &ext_info.pushdowns().filters {
// TODO ONLY PUSH IF DOESNT EXIST
let pred_sem_id = pred.semantic_id(&output_schema);
for f in filters.iter() {
if f.semantic_id(&output_schema) == pred_sem_id {
return Ok(Transformed::No(plan));
}
}
let mut filters = filters.as_ref().clone();
filters.push(pred.clone().into());
filters
Expand Down
28 changes: 0 additions & 28 deletions src/daft-plan/src/optimization/rules/utils.rs

This file was deleted.

46 changes: 33 additions & 13 deletions src/daft-scan/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::{HashMap, HashSet}, sync::Arc};

use common_error::DaftResult;
use daft_dsl::{
col,
common_treenode::{Transformed, TreeNode},
common_treenode::{Transformed, TreeNode, VisitRecursion},
functions::{partitioning, FunctionExpr},
null_lit, Expr, Operator,
null_lit, Expr, Operator, optimization::{split_conjuction, conjuct},
};

use crate::{PartitionField, PartitionTransform};
Expand Down Expand Up @@ -36,7 +36,7 @@ fn apply_partitioning_expr(expr: Expr, tfm: PartitionTransform) -> Option<Expr>
pub fn rewrite_predicate_for_partitioning(
predicate: Expr,
pfields: &[PartitionField],
) -> DaftResult<Expr> {
) -> DaftResult<Vec<Expr>> {
if pfields.is_empty() {
todo!("no predicate")
}
Expand All @@ -52,27 +52,47 @@ pub fn rewrite_predicate_for_partitioning(
map
};

predicate.transform(&|expr| {
let with_part_cols = predicate.transform(&|expr| {
use Operator::*;
match expr {
Expr::BinaryOp {
op,
ref left, ref right } if matches!(op, Eq | NotEq | Lt | LtEq | Gt | GtEq)=> {

if let Expr::Column(col_name) = left.as_ref() {
if let Some(pfield) = source_to_pfield.get(col_name.as_ref()) {
if let Some(tfm) = pfield.transform && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), tfm) {
return Ok(Transformed::Yes(Expr::BinaryOp { op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() }));
}
if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) {
if let Some(tfm) = pfield.transform && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), tfm) {
return Ok(Transformed::Yes(Expr::BinaryOp { op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() }));
}
Ok(Transformed::No(expr))
} else if let Expr::Column(col_name) = right.as_ref() {
} else if let Expr::Column(col_name) = right.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) {
if let Some(tfm) = pfield.transform && let Some(new_expr) = apply_partitioning_expr(left.as_ref().clone(), tfm) {
return Ok(Transformed::Yes(Expr::BinaryOp { op, left: new_expr.into(), right: col(pfield.field.name.as_str()).into() }));
}
Ok(Transformed::No(expr))
} else {
Ok(Transformed::No(expr))
}
},
Expr::IsNull(ref expr) if let Expr::Column(col_name) = expr.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) => {
Ok(Transformed::Yes(Expr::IsNull(col(pfield.field.name.as_str()).into())))
},
_ => Ok(Transformed::No(expr))
}
})
})?;

let p_keys = HashSet::<&str>::from_iter(pfields.iter().map(|p| p.field.name.as_ref()));

let split = split_conjuction(&with_part_cols);
let filtered = split.into_iter().filter(|p| {
let mut keep = true;
p.apply(&mut |e| {
if let Expr::Column(col_name) = e && !p_keys.contains(col_name.as_ref()) {
keep = false;

}
Ok(VisitRecursion::Continue)
}).unwrap();
keep
}).cloned().collect::<Vec<_>>();

Ok(filtered)
}
6 changes: 5 additions & 1 deletion src/daft-scan/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ partitioning_keys:\n",
p.as_ref().clone(),
self.partitioning_keys.as_slice(),
)?;
println!("before {} after {}", p, transformed);
println!("before {}", p);
for t in transformed {
println!(" {t}");
}

}
}

Expand Down
3 changes: 3 additions & 0 deletions tests/integration/iceberg/test_table_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def test_daft_iceberg_table_show(table_name, local_iceberg_catalog):
def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog):
tab = local_iceberg_catalog.load_table(f"default.{table_name}")
df = daft.read_iceberg(tab)
import ipdb

ipdb.set_trace()
df.collect()
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
Expand Down

0 comments on commit 020be51

Please sign in to comment.