From 4e6d79ff1e756ea595a91a6fc27186c179cac99d Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Fri, 20 Dec 2024 16:18:49 +0800 Subject: [PATCH] Address comments --- .../rules/reorder_joins/join_graph.rs | 96 +++++++++++++------ .../optimization/rules/reorder_joins/mod.rs | 2 +- ...order.rs => naive_left_deep_join_order.rs} | 21 ++-- 3 files changed, 76 insertions(+), 43 deletions(-) rename src/daft-logical-plan/src/optimization/rules/reorder_joins/{naive_join_order.rs => naive_left_deep_join_order.rs} (90%) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index d9c65e6e3b..e7cd907657 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -13,33 +13,54 @@ use crate::{ LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; -// TODO(desmond): In the future these trees should keep track of current cost estimates. +/// A JoinOrderTree is a tree that describes a join order between relations, which can range from left deep trees +/// to bushy trees. The relations in a JoinOrderTree contain IDs instead of logical plan references. An ID's +/// corresponding logical plan reference can be found by consulting the JoinAdjList that was used to produce the +/// given JoinOrderTree. +/// +/// TODO(desmond): In the future these trees should keep track of current cost estimates. #[derive(Clone, Debug)] pub(super) enum JoinOrderTree { - Relation(usize), // (id). - Join(Box, Box, Vec), // (subtree, subtree, nodes involved). + Relation(usize), // (ID). + Join(Box, Box), // (subtree, subtree). 
} impl JoinOrderTree { pub(super) fn join(self: Box, right: Box) -> Box { - let mut nodes = self.nodes(); - nodes.append(&mut right.nodes()); - Box::new(JoinOrderTree::Join(self, right, nodes)) - } - - pub(super) fn nodes(&self) -> Vec { - match self { - Self::Relation(id) => vec![*id], - Self::Join(_, _, nodes) => nodes.clone(), - } + Box::new(JoinOrderTree::Join(self, right)) } // Helper function that checks if the join order tree contains a given id. pub(super) fn contains(&self, target_id: usize) -> bool { match self { Self::Relation(id) => *id == target_id, - Self::Join(left, right, _) => left.contains(target_id) || right.contains(target_id), + Self::Join(left, right) => left.contains(target_id) || right.contains(target_id), + } + } + + pub(super) fn iter(&self) -> JoinOrderTreeIterator { + JoinOrderTreeIterator { stack: vec![self] } + } +} + +pub(super) struct JoinOrderTreeIterator<'a> { + stack: Vec<&'a JoinOrderTree>, +} + +impl<'a> Iterator for JoinOrderTreeIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + while let Some(node) = self.stack.pop() { + match node { + JoinOrderTree::Relation(id) => return Some(*id), + JoinOrderTree::Join(left, right) => { + self.stack.push(left); + self.stack.push(right); + } + } } + None } } @@ -130,7 +151,7 @@ impl JoinAdjList { } } - pub(super) fn get_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { + pub(super) fn get_or_create_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { let ptr = Arc::as_ptr(plan); if let Some(id) = self.plan_to_id.get(&ptr) { *id @@ -143,14 +164,12 @@ impl JoinAdjList { } } - fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { - // TODO(desmond): We should also keep track of projections that we need to do. 
- let join_condition = JoinCondition { - left_on: left.final_name.clone(), - right_on: right.final_name.clone(), - }; - let left_id = self.get_plan_id(&left.plan); - let right_id = self.get_plan_id(&right.plan); + fn add_join_condition( + &mut self, + left_id: usize, + right_id: usize, + join_condition: JoinCondition, + ) { if let Some(neighbors) = self.edges.get_mut(&left_id) { if let Some(join_conditions) = neighbors.get_mut(&right_id) { join_conditions.push(join_condition); @@ -164,16 +183,26 @@ impl JoinAdjList { } } + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { + let join_condition = JoinCondition { + left_on: left.final_name.clone(), + right_on: right.final_name.clone(), + }; + let left_id = self.get_or_create_plan_id(&left.plan); + let right_id = self.get_or_create_plan_id(&right.plan); + self.add_join_condition(left_id, right_id, join_condition); + } + pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { self.add_unidirectional_edge(&node1, &node2); self.add_unidirectional_edge(&node2, &node1); } - pub(super) fn connected(&self, left_nodes: &Vec, right_nodes: &Vec) -> bool { - for left_node in left_nodes { - if let Some(neighbors) = self.edges.get(left_node) { - for right_node in right_nodes { - if let Some(_) = neighbors.get(right_node) { + pub(super) fn connected_join_trees(&self, left: &JoinOrderTree, right: &JoinOrderTree) -> bool { + for left_node in left.iter() { + if let Some(neighbors) = self.edges.get(&left_node) { + for right_node in right.iter() { + if let Some(_) = neighbors.get(&right_node) { return true; } } @@ -271,14 +300,21 @@ impl JoinGraph { false } + fn get_node_by_id(&self, id: usize) -> &LogicalPlanRef { + self.adj_list + .id_to_plan + .get(&id) + .expect("Tried to retrieve a plan from the join graph with an invalid ID") + } + /// Helper function that loosely checks if a given edge (represented by a simple string) /// exists in the current graph. 
pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { let mut edge_strings = HashSet::new(); for (left_id, neighbors) in &self.adj_list.edges { for (right_id, join_conds) in neighbors { - let left = self.adj_list.id_to_plan.get(left_id).unwrap(); - let right = self.adj_list.id_to_plan.get(right_id).unwrap(); + let left = self.get_node_by_id(*left_id); + let right = self.get_node_by_id(*right_id); for join_cond in join_conds { edge_strings.insert(format!( "{}({}) <-> {}({})", diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index c8644b620e..762d58a4a8 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,4 +1,4 @@ #[cfg(test)] mod join_graph; #[cfg(test)] -mod naive_join_order; +mod naive_left_deep_join_order; diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs similarity index 90% rename from src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs rename to src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs index 0fba6e4fa3..c0b3ab6634 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs @@ -1,8 +1,8 @@ use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer}; -pub(crate) struct NaiveJoinOrderer {} +pub(crate) struct NaiveLeftDeepJoinOrderer {} -impl NaiveJoinOrderer { +impl NaiveLeftDeepJoinOrderer { fn extend_order( graph: &JoinGraph, current_order: Box, @@ -13,10 +13,7 @@ impl NaiveJoinOrderer { } for (index, candidate_node_id) in available.iter().enumerate() { let right = 
Box::new(JoinOrderTree::Relation(*candidate_node_id)); - if graph - .adj_list - .connected(&current_order.nodes(), &right.nodes()) - { + if graph.adj_list.connected_join_trees(&current_order, &right) { let new_order = current_order.join(right); available.remove(index); return Self::extend_order(graph, new_order, available); @@ -26,7 +23,7 @@ impl NaiveJoinOrderer { } } -impl JoinOrderer for NaiveJoinOrderer { +impl JoinOrderer for NaiveLeftDeepJoinOrderer { fn order(&self, graph: &JoinGraph) -> Box { let available: Vec = (1..graph.adj_list.max_id).collect(); // Take a starting order of the node with id 0. @@ -39,9 +36,9 @@ impl JoinOrderer for NaiveJoinOrderer { mod tests { use common_scan_info::Pushdowns; use daft_schema::{dtype::DataType, field::Field}; - use rand::{seq::SliceRandom, Rng}; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; - use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveJoinOrderer}; + use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveLeftDeepJoinOrderer}; use crate::{ optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode}, test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size}, @@ -99,7 +96,7 @@ mod tests { (1, 2), // node_b <-> node_c (2, 3), // node_c <-> node_d ]; - create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); } pub struct UnionFind { @@ -139,7 +136,7 @@ mod tests { } fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> { - let mut rng = rand::thread_rng(); + let mut rng = StdRng::seed_from_u64(0); let mut edges = Vec::new(); let mut uf = UnionFind::create(num_nodes); @@ -176,6 +173,6 @@ mod tests { .map(|i| format!("node_{}", i)) .collect(); let edges = create_random_connected_graph(NUM_RANDOM_NODES); - create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); } }