Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
desmondcheongzx committed Dec 20, 2024
1 parent b2d9c2d commit 4e6d79f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,54 @@ use crate::{
LogicalPlan, LogicalPlanBuilder, LogicalPlanRef,
};

// TODO(desmond): In the future these trees should keep track of current cost estimates.
/// A JoinOrderTree is a tree that describes a join order between relations, which can range from left-deep trees
/// to bushy trees. The relations in a JoinOrderTree contain IDs instead of logical plan references. An ID's
/// corresponding logical plan reference can be found by consulting the JoinAdjList that was used to produce the
/// given JoinOrderTree.
///
/// TODO(desmond): In the future these trees should keep track of current cost estimates.
#[derive(Clone, Debug)]
pub(super) enum JoinOrderTree {
Relation(usize), // (id).
Join(Box<JoinOrderTree>, Box<JoinOrderTree>, Vec<usize>), // (subtree, subtree, nodes involved).
Relation(usize), // (ID).
Join(Box<JoinOrderTree>, Box<JoinOrderTree>), // (subtree, subtree).
}

impl JoinOrderTree {
pub(super) fn join(self: Box<Self>, right: Box<JoinOrderTree>) -> Box<Self> {
let mut nodes = self.nodes();
nodes.append(&mut right.nodes());
Box::new(JoinOrderTree::Join(self, right, nodes))
}

pub(super) fn nodes(&self) -> Vec<usize> {
match self {
Self::Relation(id) => vec![*id],
Self::Join(_, _, nodes) => nodes.clone(),
}
Box::new(JoinOrderTree::Join(self, right))
}

// Helper function that checks if the join order tree contains a given id.
pub(super) fn contains(&self, target_id: usize) -> bool {
match self {
Self::Relation(id) => *id == target_id,
Self::Join(left, right, _) => left.contains(target_id) || right.contains(target_id),
Self::Join(left, right) => left.contains(target_id) || right.contains(target_id),
}
}

/// Returns an iterator over the relation IDs held by the `Relation` leaves of this tree.
///
/// The returned iterator borrows the tree; its traversal stack starts out containing only `self`.
pub(super) fn iter(&self) -> JoinOrderTreeIterator<'_> {
    JoinOrderTreeIterator { stack: vec![self] }
}
}

/// Iterator that yields the relation IDs found in a borrowed `JoinOrderTree`.
pub(super) struct JoinOrderTreeIterator<'a> {
// Subtrees that still have to be visited; the last element is visited next.
stack: Vec<&'a JoinOrderTree>,
}

impl<'a> Iterator for JoinOrderTreeIterator<'a> {
    type Item = usize;

    /// Yields the next relation ID, or `None` once every pending subtree has been visited.
    ///
    /// Join nodes are expanded lazily: the left child is pushed before the right child, so the
    /// right subtree's IDs come off the stack (and are yielded) before the left subtree's.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // An empty stack means the traversal is finished; `?` returns `None` in that case.
            match self.stack.pop()? {
                JoinOrderTree::Relation(id) => break Some(*id),
                JoinOrderTree::Join(lhs, rhs) => {
                    self.stack.push(lhs);
                    self.stack.push(rhs);
                }
            }
        }
    }
}

Expand Down Expand Up @@ -130,7 +151,7 @@ impl JoinAdjList {
}
}

pub(super) fn get_plan_id(&mut self, plan: &LogicalPlanRef) -> usize {
pub(super) fn get_or_create_plan_id(&mut self, plan: &LogicalPlanRef) -> usize {
let ptr = Arc::as_ptr(plan);
if let Some(id) = self.plan_to_id.get(&ptr) {
*id
Expand All @@ -143,14 +164,12 @@ impl JoinAdjList {
}
}

fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) {
// TODO(desmond): We should also keep track of projections that we need to do.
let join_condition = JoinCondition {
left_on: left.final_name.clone(),
right_on: right.final_name.clone(),
};
let left_id = self.get_plan_id(&left.plan);
let right_id = self.get_plan_id(&right.plan);
fn add_join_condition(
&mut self,
left_id: usize,
right_id: usize,
join_condition: JoinCondition,
) {
if let Some(neighbors) = self.edges.get_mut(&left_id) {
if let Some(join_conditions) = neighbors.get_mut(&right_id) {
join_conditions.push(join_condition);
Expand All @@ -164,16 +183,26 @@ impl JoinAdjList {
}
}

/// Records a directed edge from `left` to `right`.
///
/// Both plans are resolved to IDs through `get_or_create_plan_id` (left first, then right), and a
/// `JoinCondition` built from the two nodes' `final_name`s is passed to `add_join_condition`.
fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) {
    let left_id = self.get_or_create_plan_id(&left.plan);
    let right_id = self.get_or_create_plan_id(&right.plan);
    self.add_join_condition(
        left_id,
        right_id,
        JoinCondition {
            left_on: left.final_name.clone(),
            right_on: right.final_name.clone(),
        },
    );
}

/// Connects `node1` and `node2` in both directions.
///
/// The `node1 -> node2` edge is added first, followed by the `node2 -> node1` edge.
pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) {
    for (from, to) in [(&node1, &node2), (&node2, &node1)] {
        self.add_unidirectional_edge(from, to);
    }
}

pub(super) fn connected(&self, left_nodes: &Vec<usize>, right_nodes: &Vec<usize>) -> bool {
for left_node in left_nodes {
if let Some(neighbors) = self.edges.get(left_node) {
for right_node in right_nodes {
if let Some(_) = neighbors.get(right_node) {
pub(super) fn connected_join_trees(&self, left: &JoinOrderTree, right: &JoinOrderTree) -> bool {
for left_node in left.iter() {
if let Some(neighbors) = self.edges.get(&left_node) {
for right_node in right.iter() {
if let Some(_) = neighbors.get(&right_node) {
return true;
}
}
Expand Down Expand Up @@ -271,14 +300,21 @@ impl JoinGraph {
false
}

/// Looks up the logical plan stored under `id` in the adjacency list's `id_to_plan` map.
///
/// # Panics
///
/// Panics if `id_to_plan` has no entry for `id`.
fn get_node_by_id(&self, id: usize) -> &LogicalPlanRef {
    let plans_by_id = &self.adj_list.id_to_plan;
    plans_by_id
        .get(&id)
        .expect("Tried to retrieve a plan from the join graph with an invalid ID")
}

/// Helper function that loosely checks if a given edge (represented by a simple string)
/// exists in the current graph.
pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool {
let mut edge_strings = HashSet::new();
for (left_id, neighbors) in &self.adj_list.edges {
for (right_id, join_conds) in neighbors {
let left = self.adj_list.id_to_plan.get(left_id).unwrap();
let right = self.adj_list.id_to_plan.get(right_id).unwrap();
let left = self.get_node_by_id(*left_id);
let right = self.get_node_by_id(*right_id);
for join_cond in join_conds {
edge_strings.insert(format!(
"{}({}) <-> {}({})",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[cfg(test)]
mod join_graph;
#[cfg(test)]
mod naive_join_order;
mod naive_left_deep_join_order;
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer};

pub(crate) struct NaiveJoinOrderer {}
pub(crate) struct NaiveLeftDeepJoinOrderer {}

impl NaiveJoinOrderer {
impl NaiveLeftDeepJoinOrderer {
fn extend_order(
graph: &JoinGraph,
current_order: Box<JoinOrderTree>,
Expand All @@ -13,10 +13,7 @@ impl NaiveJoinOrderer {
}
for (index, candidate_node_id) in available.iter().enumerate() {
let right = Box::new(JoinOrderTree::Relation(*candidate_node_id));
if graph
.adj_list
.connected(&current_order.nodes(), &right.nodes())
{
if graph.adj_list.connected_join_trees(&current_order, &right) {
let new_order = current_order.join(right);
available.remove(index);
return Self::extend_order(graph, new_order, available);
Expand All @@ -26,7 +23,7 @@ impl NaiveJoinOrderer {
}
}

impl JoinOrderer for NaiveJoinOrderer {
impl JoinOrderer for NaiveLeftDeepJoinOrderer {
fn order(&self, graph: &JoinGraph) -> Box<JoinOrderTree> {
let available: Vec<usize> = (1..graph.adj_list.max_id).collect();
// Take a starting order of the node with id 0.
Expand All @@ -39,9 +36,9 @@ impl JoinOrderer for NaiveJoinOrderer {
mod tests {
use common_scan_info::Pushdowns;
use daft_schema::{dtype::DataType, field::Field};
use rand::{seq::SliceRandom, Rng};
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};

use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveJoinOrderer};
use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveLeftDeepJoinOrderer};
use crate::{
optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode},
test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size},
Expand Down Expand Up @@ -99,7 +96,7 @@ mod tests {
(1, 2), // node_b <-> node_c
(2, 3), // node_c <-> node_d
];
create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {});
create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {});
}

pub struct UnionFind {
Expand Down Expand Up @@ -139,7 +136,7 @@ mod tests {
}

fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> {
let mut rng = rand::thread_rng();
let mut rng = StdRng::seed_from_u64(0);
let mut edges = Vec::new();
let mut uf = UnionFind::create(num_nodes);

Expand Down Expand Up @@ -176,6 +173,6 @@ mod tests {
.map(|i| format!("node_{}", i))
.collect();
let edges = create_random_connected_graph(NUM_RANDOM_NODES);
create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {});
create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {});
}
}

0 comments on commit 4e6d79f

Please sign in to comment.