From bc17ae90837980480ef35b716c208d0f7c871f7f Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Mon, 14 Oct 2024 10:51:45 -0400 Subject: [PATCH 1/5] core: use union-find to merge groups Also improve tpch sqlplannertest run from ~89s to ~9s. Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 38 +-- optd-core/src/cascades/memo/disjoint_set.rs | 243 ++++++++++++++++++++ optd-sqlplannertest/tests/tpch.planner.sql | 116 +++++----- 3 files changed, 321 insertions(+), 76 deletions(-) create mode 100644 optd-core/src/cascades/memo/disjoint_set.rs diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index a194f71c..2503e697 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -1,3 +1,5 @@ +mod disjoint_set; + use std::{ collections::{hash_map::Entry, HashMap, HashSet}, fmt::Display, @@ -5,6 +7,7 @@ use std::{ }; use anyhow::{bail, Result}; +use disjoint_set::{DisjointSet, UnionFind}; use itertools::Itertools; use std::any::Any; @@ -78,7 +81,7 @@ pub struct Memo { expr_node_to_expr_id: HashMap, ExprId>, groups: HashMap, group_expr_counter: usize, - merged_groups: HashMap, + disjoint_groups: DisjointSet, property_builders: Arc<[Box>]>, } @@ -90,7 +93,7 @@ impl Memo { expr_node_to_expr_id: HashMap::new(), groups: HashMap::new(), group_expr_counter: 0, - merged_groups: HashMap::new(), + disjoint_groups: DisjointSet::new(), property_builders, } } @@ -99,6 +102,7 @@ impl Memo { fn next_group_id(&mut self) -> ReducedGroupId { let id = self.group_expr_counter; self.group_expr_counter += 1; + self.disjoint_groups.add(GroupId(id)); ReducedGroupId(id) } @@ -118,18 +122,20 @@ impl Memo { return group_a; } + let [rep, other] = self + .disjoint_groups + .union(&group_a.as_group_id(), &group_b.as_group_id()) + .unwrap(); + // Copy all expressions from group a to group b - let group_a_exprs = self.get_all_exprs_in_group(group_a.as_group_id()); - for expr_id in group_a_exprs { + let other_exprs = self.get_all_exprs_in_group(other); + for expr_id in other_exprs { let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap(); - self.add_expr_to_group(expr_id, group_b, expr_node.as_ref().clone()); + self.add_expr_to_group(expr_id, ReducedGroupId(rep.0), expr_node.as_ref().clone()); } - self.merged_groups - .insert(group_a.as_group_id(), group_b.as_group_id()); - // Remove all expressions from group a (so we don't accidentally access it) - self.clear_exprs_in_group(group_a); + self.clear_exprs_in_group(ReducedGroupId(other.0)); group_b } @@ -145,11 +151,9 @@ impl Memo { self.expr_id_to_group_id[&expr_id] } - fn get_reduced_group_id(&self, mut group_id: GroupId) -> ReducedGroupId { - while let Some(next_group_id) = self.merged_groups.get(&group_id) { - group_id = *next_group_id; - } - ReducedGroupId(group_id.0) + fn get_reduced_group_id(&self, group_id: GroupId) -> ReducedGroupId { + let reduced = self.disjoint_groups.find(&group_id).unwrap(); + ReducedGroupId(reduced.0) } /// Add or get an expression into the memo, returns the group id and the expr id. If `GroupId` is `None`, @@ -164,10 +168,8 @@ impl Memo { self.merge_group(grp_a, grp_b); }; - let (group_id, expr_id) = self.add_new_group_expr_inner( - rel_node, - add_to_group_id.map(|x| self.get_reduced_group_id(x)), - ); + let add_to_group_id = add_to_group_id.map(|x| self.get_reduced_group_id(x)); + let (group_id, expr_id) = self.add_new_group_expr_inner(rel_node, add_to_group_id); (group_id.as_group_id(), expr_id) } diff --git a/optd-core/src/cascades/memo/disjoint_set.rs b/optd-core/src/cascades/memo/disjoint_set.rs new file mode 100644 index 00000000..7989fc3d --- /dev/null +++ b/optd-core/src/cascades/memo/disjoint_set.rs @@ -0,0 +1,243 @@ +use std::{ + collections::HashMap, + fmt::Debug, + hash::Hash, + ops::{Deref, DerefMut}, + sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, +}; + +/// A data structure for efficiently maintaining disjoint sets of `T`. +pub struct DisjointSet { + /// Mapping from node to its parent. + /// + /// # Design + /// We use a `mutex` instead of reader-writer lock so that + /// we always need write permission to perform `path compression` + /// during "finds". + /// + /// Alternatively, we could do no path compression at `find`, + /// and only do path compression when we were doing union. + node_parents: Arc>>, + /// Number of disjoint sets. + num_sets: AtomicUsize, +} + +pub trait UnionFind +where + T: Ord, +{ + /// Unions the set containing `a` and the set containing `b`. + /// Returns the new representative followed by the other node, + /// or `None` if one the node is not present. + /// + /// The smaller representative is selected as the new representative. + fn union(&self, a: &T, b: &T) -> Option<[T; 2]>; + + /// Gets the representative node of the set that `node` is in. + /// Path compression is performed while finding the representative. + fn find_path_compress(&self, node: &T) -> Option; + + /// Gets the representative node of the set that `node` is in. + fn find(&self, node: &T) -> Option; +} + +impl DisjointSet +where + T: Ord + Hash + Copy, +{ + pub fn new() -> Self { + DisjointSet { + node_parents: Arc::new(RwLock::new(HashMap::new())), + num_sets: AtomicUsize::new(0), + } + } + + pub fn size(&self) -> usize { + let g = self.node_parents.read().unwrap(); + g.len() + } + + pub fn num_sets(&self) -> usize { + self.num_sets.load(std::sync::atomic::Ordering::Relaxed) + } + + pub fn add(&mut self, node: T) { + use std::collections::hash_map::Entry; + + let mut g = self.node_parents.write().unwrap(); + if let Entry::Vacant(entry) = g.entry(node) { + entry.insert(node); + drop(g); + self.num_sets + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + } + + fn get_parent(g: &impl Deref>, node: &T) -> Option { + g.get(node).copied() + } + + fn set_parent(g: &mut impl DerefMut>, node: T, parent: T) { + g.insert(node, parent); + } + + /// Recursively find the parent of the `node` until reaching the representative of the set. + /// A node is the representative if the its parent is the node itself. + /// + /// We utilize "path compression" to shorten the height of parent forest. + fn find_path_compress_inner( + g: &mut RwLockWriteGuard<'_, HashMap>, + node: &T, + ) -> Option { + let mut parent = Self::get_parent(g, node)?; + + if *node != parent { + parent = Self::find_path_compress_inner(g, &parent)?; + + // Path compression. + Self::set_parent(g, *node, parent); + } + + Some(parent) + } + + /// Recursively find the parent of the `node` until reaching the representative of the set. + /// A node is the representative if the its parent is the node itself. + fn find_inner(g: &RwLockReadGuard<'_, HashMap>, node: &T) -> Option { + let mut parent = Self::get_parent(g, node)?; + + if *node != parent { + parent = Self::find_inner(g, &parent)?; + } + + Some(parent) + } +} + +impl UnionFind for DisjointSet +where + T: Ord + Hash + Copy + Debug, +{ + fn union(&self, a: &T, b: &T) -> Option<[T; 2]> { + use std::cmp::Ordering; + + // Gets the represenatives for set containing `a`. + let a_rep = self.find_path_compress(&a)?; + + // Gets the represenatives for set containing `b`. + let b_rep = self.find_path_compress(&b)?; + + let mut g = self.node_parents.write().unwrap(); + + // Node with smaller value becomes the representative. + let res = match a_rep.cmp(&b_rep) { + Ordering::Less => { + Self::set_parent(&mut g, b_rep, a_rep); + self.num_sets + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + [a_rep, b_rep] + } + Ordering::Greater => { + Self::set_parent(&mut g, a_rep, b_rep); + self.num_sets + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + [b_rep, a_rep] + } + Ordering::Equal => [a_rep, b_rep], + }; + Some(res) + } + + /// See [`Self::find_inner`] for implementation detail. + fn find_path_compress(&self, node: &T) -> Option { + let mut g = self.node_parents.write().unwrap(); + Self::find_path_compress_inner(&mut g, node) + } + + /// See [`Self::find_inner`] for implementation detail. + fn find(&self, node: &T) -> Option { + let g = self.node_parents.read().unwrap(); + Self::find_inner(&g, node) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + fn minmax(v1: T, v2: T) -> [T; 2] + where + T: Ord, + { + if v1 <= v2 { + [v1, v2] + } else { + [v2, v1] + } + } + + fn test_union_find(inputs: Vec) + where + T: Ord + Hash + Copy + Debug, + { + let mut set = DisjointSet::new(); + + for input in inputs.iter() { + set.add(*input); + } + + for input in inputs.iter() { + let rep = set.find(input); + assert_eq!( + rep, + Some(*input), + "representive should be node itself for singleton" + ); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 10); + + for input in inputs.iter() { + set.union(input, input).unwrap(); + let rep = set.find(input); + assert_eq!( + rep, + Some(*input), + "representive should be node itself for singleton" + ); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 10); + + for (x, y) in inputs.iter().zip(inputs.iter().rev()) { + let y_rep = set.find(&y).unwrap(); + let [rep, other] = set.union(x, y).expect(&format!( + "union should be successful between {:?} and {:?}", + x, y, + )); + if rep != other { + assert_eq!([rep, other], minmax(*x, y_rep)); + } + } + + for (x, y) in inputs.iter().zip(inputs.iter().rev()) { + let rep = set.find(x); + + let expected = x.min(y); + assert_eq!(rep, Some(*expected)); + } + assert_eq!(set.size(), 10); + assert_eq!(set.num_sets(), 5); + } + + #[test] + fn test_union_find_i32() { + test_union_find(Vec::from_iter(0..10)); + } + + #[test] + fn test_union_find_group() { + test_union_find(Vec::from_iter((0..10).map(|i| crate::cascades::GroupId(i)))); + } +} diff --git a/optd-sqlplannertest/tests/tpch.planner.sql b/optd-sqlplannertest/tests/tpch.planner.sql index 8bf88051..e7a86ef9 100644 --- a/optd-sqlplannertest/tests/tpch.planner.sql +++ b/optd-sqlplannertest/tests/tpch.planner.sql @@ -627,29 +627,29 @@ PhysicalSort │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1(i64) } │ └── #23 ├── groups: [ #41 ] - └── PhysicalHashJoin { join_type: Inner, left_keys: [ #19, #3 ], right_keys: [ #0, #3 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ] } - │ ├── PhysicalScan { table: customer } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalFilter - │ │ ├── cond:And - │ │ │ ├── Geq - │ │ │ │ ├── #4 - │ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" } - │ │ │ └── Lt - │ │ │ ├── #4 - │ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" } - │ │ └── PhysicalScan { table: orders } - │ └── PhysicalScan { table: lineitem } - └── PhysicalHashJoin { join_type: Inner, left_keys: [ #9 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } - │ ├── PhysicalScan { table: supplier } - │ └── PhysicalScan { table: nation } - └── PhysicalFilter - ├── cond:Eq - │ ├── #1 - │ └── "Asia" - └── PhysicalScan { table: region } + └── PhysicalHashJoin { join_type: Inner, left_keys: [ #42 ], right_keys: [ #0 ] } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #36 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #19, #3 ], right_keys: [ #0, #3 ] } + │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ] } + │ │ │ ├── PhysicalScan { table: customer } + │ │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ │ ├── PhysicalFilter + │ │ │ │ ├── cond:And + │ │ │ │ │ ├── Geq + │ │ │ │ │ │ ├── #4 + │ │ │ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" } + │ │ │ │ │ └── Lt + │ │ │ │ │ ├── #4 + │ │ │ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" } + │ │ │ │ └── PhysicalScan { table: orders } + │ │ │ └── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: supplier } + │ └── PhysicalScan { table: nation } + └── PhysicalFilter + ├── cond:Eq + │ ├── #1 + │ └── "Asia" + └── PhysicalScan { table: region } */ -- TPC-H Q6 @@ -864,12 +864,12 @@ PhysicalSort ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #2 ] } │ │ ├── PhysicalScan { table: supplier } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #17 ], right_keys: [ #0 ] } - │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ │ │ ├── PhysicalFilter { cond: Between { expr: #10, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } - │ │ │ │ └── PhysicalScan { table: lineitem } - │ │ │ └── PhysicalScan { table: orders } - │ │ └── PhysicalScan { table: customer } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ ├── PhysicalFilter { cond: Between { expr: #10, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } + │ │ │ └── PhysicalScan { table: lineitem } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: orders } + │ │ └── PhysicalScan { table: customer } │ └── PhysicalScan { table: nation } └── PhysicalScan { table: nation } */ @@ -1033,14 +1033,14 @@ PhysicalSort │ │ │ │ │ └── "ECONOMY ANODIZED STEEL" │ │ │ │ └── PhysicalScan { table: part } │ │ │ └── PhysicalScan { table: supplier } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #17 ], right_keys: [ #0 ] } - │ │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ │ │ ├── PhysicalScan { table: lineitem } - │ │ │ └── PhysicalFilter { cond: Between { expr: #4, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } - │ │ │ └── PhysicalScan { table: orders } - │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } - │ │ ├── PhysicalScan { table: customer } - │ │ └── PhysicalScan { table: nation } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] } + │ │ ├── PhysicalFilter { cond: Between { expr: #4, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } } } + │ │ │ └── PhysicalScan { table: orders } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } + │ │ ├── PhysicalScan { table: customer } + │ │ └── PhysicalScan { table: nation } │ └── PhysicalScan { table: nation } └── PhysicalFilter ├── cond:Eq @@ -1167,16 +1167,16 @@ PhysicalSort │ ├── #35 │ └── #20 └── PhysicalHashJoin { join_type: Inner, left_keys: [ #12 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } - │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } - │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } - │ │ │ └── PhysicalScan { table: part } - │ │ └── PhysicalScan { table: supplier } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } - │ │ ├── PhysicalScan { table: lineitem } - │ │ └── PhysicalScan { table: partsupp } - │ └── PhysicalScan { table: orders } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #16 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } + │ │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } + │ │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } + │ │ │ │ └── PhysicalScan { table: part } + │ │ │ └── PhysicalScan { table: supplier } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: partsupp } + │ └── PhysicalScan { table: orders } └── PhysicalScan { table: nation } */ @@ -1298,16 +1298,16 @@ PhysicalSort │ ├── #35 │ └── #20 └── PhysicalHashJoin { join_type: Inner, left_keys: [ #12 ], right_keys: [ #0 ] } - ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } - │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } - │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } - │ │ │ └── PhysicalScan { table: part } - │ │ └── PhysicalScan { table: supplier } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } - │ │ ├── PhysicalScan { table: lineitem } - │ │ └── PhysicalScan { table: partsupp } - │ └── PhysicalScan { table: orders } + ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #16 ], right_keys: [ #0 ] } + │ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #9, #0 ], right_keys: [ #2, #1 ] } + │ │ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true } + │ │ │ ├── PhysicalFilter { cond: Like { expr: #1, pattern: "%green%", negated: false, case_insensitive: false } } + │ │ │ │ └── PhysicalScan { table: part } + │ │ │ └── PhysicalScan { table: supplier } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2, #1 ], right_keys: [ #1, #0 ] } + │ │ ├── PhysicalScan { table: lineitem } + │ │ └── PhysicalScan { table: partsupp } + │ └── PhysicalScan { table: orders } └── PhysicalScan { table: nation } */ @@ -2057,7 +2057,7 @@ PhysicalProjection ├── cond:And │ ├── Eq │ │ ├── #2 - │ │ └── #0 + │ │ └── #4 │ └── Lt │ ├── Cast { cast_to: Decimal128(30, 15), expr: #0 } │ └── #3 From 3b83135af88cb040dfd2f7183bdc9e1ae8912362 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Mon, 14 Oct 2024 14:56:53 -0400 Subject: [PATCH 2/5] remove ReducedGroupId Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 95 ++++++++++------------------------ 1 file changed, 27 insertions(+), 68 deletions(-) diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 2503e697..5def21ce 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -2,7 +2,6 @@ mod disjoint_set; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, - fmt::Display, sync::Arc, }; @@ -60,28 +59,15 @@ pub(crate) struct Group { pub(crate) properties: Arc<[Box]>, } -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -struct ReducedGroupId(usize); - -impl ReducedGroupId { - pub fn as_group_id(self) -> GroupId { - GroupId(self.0) - } -} - -impl Display for ReducedGroupId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - pub struct Memo { expr_id_to_group_id: HashMap, expr_id_to_expr_node: HashMap>, expr_node_to_expr_id: HashMap, ExprId>, - groups: HashMap, + /// Stores the mapping from "representative" group id to group. + groups: HashMap, group_expr_counter: usize, - disjoint_groups: DisjointSet, + /// Keeps track of disjoint sets of group ids. + disjoint_group_ids: DisjointSet, property_builders: Arc<[Box>]>, } @@ -93,17 +79,16 @@ impl Memo { expr_node_to_expr_id: HashMap::new(), groups: HashMap::new(), group_expr_counter: 0, - disjoint_groups: DisjointSet::new(), + disjoint_group_ids: DisjointSet::new(), property_builders, } } /// Get the next group id. Group id and expr id shares the same counter, so as to make it easier to debug... - fn next_group_id(&mut self) -> ReducedGroupId { - let id = self.group_expr_counter; + fn next_group_id(&mut self) -> GroupId { + let id = GroupId(self.group_expr_counter); self.group_expr_counter += 1; - self.disjoint_groups.add(GroupId(id)); - ReducedGroupId(id) + id } /// Get the next expr id. Group id and expr id shares the same counter, so as to make it easier to debug... @@ -113,47 +98,32 @@ impl Memo { ExprId(id) } - fn merge_group_inner( - &mut self, - group_a: ReducedGroupId, - group_b: ReducedGroupId, - ) -> ReducedGroupId { + pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId { if group_a == group_b { return group_a; } - let [rep, other] = self - .disjoint_groups - .union(&group_a.as_group_id(), &group_b.as_group_id()) - .unwrap(); + let [rep, other] = self.disjoint_group_ids.union(&group_a, &group_b).unwrap(); - // Copy all expressions from group a to group b + // Copy all expressions from group other to its representative let other_exprs = self.get_all_exprs_in_group(other); for expr_id in other_exprs { let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap(); - self.add_expr_to_group(expr_id, ReducedGroupId(rep.0), expr_node.as_ref().clone()); + self.add_expr_to_group(expr_id, rep, expr_node.as_ref().clone()); } - // Remove all expressions from group a (so we don't accidentally access it) - self.clear_exprs_in_group(ReducedGroupId(other.0)); + // Remove all expressions from group other (so we don't accidentally access it) + self.clear_exprs_in_group(other); group_b } - pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId { - let group_a_reduced = self.get_reduced_group_id(group_a); - let group_b_reduced = self.get_reduced_group_id(group_b); - self.merge_group_inner(group_a_reduced, group_b_reduced) - .as_group_id() - } - fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId { self.expr_id_to_group_id[&expr_id] } - fn get_reduced_group_id(&self, group_id: GroupId) -> ReducedGroupId { - let reduced = self.disjoint_groups.find(&group_id).unwrap(); - ReducedGroupId(reduced.0) + fn get_reduced_group_id(&self, group_id: GroupId) -> GroupId { + self.disjoint_group_ids.find(&group_id).unwrap() } /// Add or get an expression into the memo, returns the group id and the expr id. If `GroupId` is `None`, @@ -170,7 +140,7 @@ impl Memo { let add_to_group_id = add_to_group_id.map(|x| self.get_reduced_group_id(x)); let (group_id, expr_id) = self.add_new_group_expr_inner(rel_node, add_to_group_id); - (group_id.as_group_id(), expr_id) + (group_id, expr_id) } pub fn get_expr_info(&self, rel_node: RelNodeRef) -> (GroupId, ExprId) { @@ -225,18 +195,13 @@ impl Memo { props } - fn clear_exprs_in_group(&mut self, group_id: ReducedGroupId) { + fn clear_exprs_in_group(&mut self, group_id: GroupId) { self.groups.remove(&group_id); } /// If group_id exists, it adds expr_id to the existing group /// Otherwise, it creates a new group of that group_id and insert expr_id into the new group - fn add_expr_to_group( - &mut self, - expr_id: ExprId, - group_id: ReducedGroupId, - memo_node: RelMemoNode, - ) { + fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode) { if let Entry::Occupied(mut entry) = self.groups.entry(group_id) { let group = entry.get_mut(); group.group_exprs.insert(expr_id); @@ -248,6 +213,7 @@ impl Memo { properties: self.infer_properties(memo_node).into(), }; group.group_exprs.insert(expr_id); + self.disjoint_group_ids.add(group_id); self.groups.insert(group_id, group); } @@ -298,9 +264,8 @@ impl Memo { and make sure if it does not do any transformation, it should return an empty vec!"); } let group_id = self.get_group_id_of_expr_id(new_expr_id); - let group_id = self.get_reduced_group_id(group_id); - self.merge_group_inner(replace_group_id, group_id); + self.merge_group(replace_group_id, group_id); return false; } @@ -316,8 +281,8 @@ impl Memo { fn add_new_group_expr_inner( &mut self, rel_node: RelNodeRef, - add_to_group_id: Option, - ) -> (ReducedGroupId, ExprId) { + add_to_group_id: Option, + ) -> (GroupId, ExprId) { let children_group_ids = rel_node .children .iter() @@ -338,7 +303,7 @@ impl Memo { let group_id = self.get_group_id_of_expr_id(expr_id); let group_id = self.get_reduced_group_id(group_id); if let Some(add_to_group_id) = add_to_group_id { - self.merge_group_inner(add_to_group_id, group_id); + self.merge_group(add_to_group_id, group_id); } return (group_id, expr_id); } @@ -350,8 +315,7 @@ impl Memo { }; self.expr_id_to_expr_node .insert(expr_id, memo_node.clone().into()); - self.expr_id_to_group_id - .insert(expr_id, group_id.as_group_id()); + self.expr_id_to_group_id.insert(expr_id, group_id); self.expr_node_to_expr_id.insert(memo_node.clone(), expr_id); self.add_expr_to_group(expr_id, group_id, memo_node); (group_id, expr_id) @@ -364,7 +328,7 @@ impl Memo { .expr_id_to_group_id .get(&expr_id) .expect("expr not found in group mapping"); - self.get_reduced_group_id(*group_id).as_group_id() + self.get_reduced_group_id(*group_id) } /// Get the memoized representation of a node. @@ -465,12 +429,7 @@ impl Memo { } pub fn get_all_group_ids(&self) -> Vec { - let mut ids = self - .groups - .keys() - .copied() - .map(|x| x.as_group_id()) - .collect_vec(); + let mut ids = self.groups.keys().copied().collect_vec(); ids.sort(); ids } From fd1f49e122a39dd20dde4c91eaaa1e6ba52cadd7 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Mon, 14 Oct 2024 15:04:35 -0400 Subject: [PATCH 3/5] remove sync on set itself Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 4 +- optd-core/src/cascades/memo/disjoint_group.rs | 1 + .../set.rs} | 79 +++++++------------ 3 files changed, 33 insertions(+), 51 deletions(-) create mode 100644 optd-core/src/cascades/memo/disjoint_group.rs rename optd-core/src/cascades/memo/{disjoint_set.rs => disjoint_group/set.rs} (68%) diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 5def21ce..62ad52b7 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -1,4 +1,4 @@ -mod disjoint_set; +mod disjoint_group; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, @@ -6,7 +6,7 @@ use std::{ }; use anyhow::{bail, Result}; -use disjoint_set::{DisjointSet, UnionFind}; +use disjoint_group::set::{DisjointSet, UnionFind}; use itertools::Itertools; use std::any::Any; diff --git a/optd-core/src/cascades/memo/disjoint_group.rs b/optd-core/src/cascades/memo/disjoint_group.rs new file mode 100644 index 00000000..d7e52559 --- /dev/null +++ b/optd-core/src/cascades/memo/disjoint_group.rs @@ -0,0 +1 @@ +pub mod set; diff --git a/optd-core/src/cascades/memo/disjoint_set.rs b/optd-core/src/cascades/memo/disjoint_group/set.rs similarity index 68% rename from optd-core/src/cascades/memo/disjoint_set.rs rename to optd-core/src/cascades/memo/disjoint_group/set.rs index 7989fc3d..35cc6b81 100644 --- a/optd-core/src/cascades/memo/disjoint_set.rs +++ b/optd-core/src/cascades/memo/disjoint_group/set.rs @@ -1,10 +1,4 @@ -use std::{ - collections::HashMap, - fmt::Debug, - hash::Hash, - ops::{Deref, DerefMut}, - sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, -}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; /// A data structure for efficiently maintaining disjoint sets of `T`. pub struct DisjointSet { @@ -17,9 +11,9 @@ pub struct DisjointSet { /// /// Alternatively, we could do no path compression at `find`, /// and only do path compression when we were doing union. - node_parents: Arc>>, + node_parents: HashMap, /// Number of disjoint sets. - num_sets: AtomicUsize, + num_sets: usize, } pub trait UnionFind @@ -31,11 +25,11 @@ where /// or `None` if one the node is not present. /// /// The smaller representative is selected as the new representative. - fn union(&self, a: &T, b: &T) -> Option<[T; 2]>; + fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]>; /// Gets the representative node of the set that `node` is in. /// Path compression is performed while finding the representative. - fn find_path_compress(&self, node: &T) -> Option; + fn find_path_compress(&mut self, node: &T) -> Option; /// Gets the representative node of the set that `node` is in. fn find(&self, node: &T) -> Option; @@ -47,55 +41,48 @@ where { pub fn new() -> Self { DisjointSet { - node_parents: Arc::new(RwLock::new(HashMap::new())), - num_sets: AtomicUsize::new(0), + node_parents: HashMap::new(), + num_sets: 0, } } pub fn size(&self) -> usize { - let g = self.node_parents.read().unwrap(); - g.len() + self.node_parents.len() } pub fn num_sets(&self) -> usize { - self.num_sets.load(std::sync::atomic::Ordering::Relaxed) + self.num_sets } pub fn add(&mut self, node: T) { use std::collections::hash_map::Entry; - let mut g = self.node_parents.write().unwrap(); - if let Entry::Vacant(entry) = g.entry(node) { + if let Entry::Vacant(entry) = self.node_parents.entry(node) { entry.insert(node); - drop(g); - self.num_sets - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.num_sets += 1; } } - fn get_parent(g: &impl Deref>, node: &T) -> Option { - g.get(node).copied() + fn get_parent(&self, node: &T) -> Option { + self.node_parents.get(node).copied() } - fn set_parent(g: &mut impl DerefMut>, node: T, parent: T) { - g.insert(node, parent); + fn set_parent(&mut self, node: T, parent: T) { + self.node_parents.insert(node, parent); } /// Recursively find the parent of the `node` until reaching the representative of the set. /// A node is the representative if the its parent is the node itself. /// /// We utilize "path compression" to shorten the height of parent forest. - fn find_path_compress_inner( - g: &mut RwLockWriteGuard<'_, HashMap>, - node: &T, - ) -> Option { - let mut parent = Self::get_parent(g, node)?; + fn find_path_compress_inner(&mut self, node: &T) -> Option { + let mut parent = self.get_parent(node)?; if *node != parent { - parent = Self::find_path_compress_inner(g, &parent)?; + parent = self.find_path_compress_inner(&parent)?; // Path compression. - Self::set_parent(g, *node, parent); + self.set_parent(*node, parent); } Some(parent) @@ -103,11 +90,11 @@ where /// Recursively find the parent of the `node` until reaching the representative of the set. /// A node is the representative if the its parent is the node itself. - fn find_inner(g: &RwLockReadGuard<'_, HashMap>, node: &T) -> Option { - let mut parent = Self::get_parent(g, node)?; + fn find_inner(&self, node: &T) -> Option { + let mut parent = self.get_parent(node)?; if *node != parent { - parent = Self::find_inner(g, &parent)?; + parent = self.find_inner(&parent)?; } Some(parent) @@ -118,7 +105,7 @@ impl UnionFind for DisjointSet where T: Ord + Hash + Copy + Debug, { - fn union(&self, a: &T, b: &T) -> Option<[T; 2]> { + fn union(&mut self, a: &T, b: &T) -> Option<[T; 2]> { use std::cmp::Ordering; // Gets the represenatives for set containing `a`. @@ -127,20 +114,16 @@ where // Gets the represenatives for set containing `b`. let b_rep = self.find_path_compress(&b)?; - let mut g = self.node_parents.write().unwrap(); - // Node with smaller value becomes the representative. let res = match a_rep.cmp(&b_rep) { Ordering::Less => { - Self::set_parent(&mut g, b_rep, a_rep); - self.num_sets - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.set_parent(b_rep, a_rep); + self.num_sets -= 1; [a_rep, b_rep] } Ordering::Greater => { - Self::set_parent(&mut g, a_rep, b_rep); - self.num_sets - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.set_parent(a_rep, b_rep); + self.num_sets -= 1; [b_rep, a_rep] } Ordering::Equal => [a_rep, b_rep], @@ -149,15 +132,13 @@ where } /// See [`Self::find_inner`] for implementation detail. - fn find_path_compress(&self, node: &T) -> Option { - let mut g = self.node_parents.write().unwrap(); - Self::find_path_compress_inner(&mut g, node) + fn find_path_compress(&mut self, node: &T) -> Option { + self.find_path_compress_inner(node) } /// See [`Self::find_inner`] for implementation detail. fn find(&self, node: &T) -> Option { - let g = self.node_parents.read().unwrap(); - Self::find_inner(&g, node) + self.find_inner(node) } } From 4129dff742785c5a6bcde21451f3e103ebb32efc Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Mon, 14 Oct 2024 16:32:22 -0400 Subject: [PATCH 4/5] implement DisjointGroupMap Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 9 +- optd-core/src/cascades/memo/disjoint_group.rs | 110 ++++++++++++++++++ .../src/cascades/memo/disjoint_group/set.rs | 13 ++- 3 files changed, 129 insertions(+), 3 deletions(-) diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 62ad52b7..c4bbfb0a 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -6,7 +6,10 @@ use std::{ }; use anyhow::{bail, Result}; -use disjoint_group::set::{DisjointSet, UnionFind}; +use disjoint_group::{ + set::{DisjointSet, UnionFind}, + DisjointGroupMap, +}; use itertools::Itertools; use std::any::Any; @@ -195,6 +198,7 @@ impl Memo { props } + // TODO(yuchen): make internal to disjoint group fn clear_exprs_in_group(&mut self, group_id: GroupId) { self.groups.remove(&group_id); } @@ -202,6 +206,7 @@ impl Memo { /// If group_id exists, it adds expr_id to the existing group /// Otherwise, it creates a new group of that group_id and insert expr_id into the new group fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode) { + // TODO(yuchen): use entry API if let Entry::Occupied(mut entry) = self.groups.entry(group_id) { let group = entry.get_mut(); group.group_exprs.insert(expr_id); @@ -214,6 +219,7 @@ impl Memo { }; group.group_exprs.insert(expr_id); self.disjoint_group_ids.add(group_id); + // TODO(yuchen): use insert self.groups.insert(group_id, group); } @@ -228,6 +234,7 @@ impl Memo { ) -> bool { let replace_group_id = self.get_reduced_group_id(replace_group_id); + // TODO(yuchen): use disjoint group entry API if let Entry::Occupied(mut entry) = self.groups.entry(replace_group_id) { let group = entry.get_mut(); if !group.group_exprs.contains(&expr_id) { diff --git a/optd-core/src/cascades/memo/disjoint_group.rs b/optd-core/src/cascades/memo/disjoint_group.rs index d7e52559..57c22567 100644 --- a/optd-core/src/cascades/memo/disjoint_group.rs +++ b/optd-core/src/cascades/memo/disjoint_group.rs @@ -1 +1,111 @@ +use std::{ + collections::{hash_map, HashMap}, + ops::Index, +}; + +use set::{DisjointSet, UnionFind}; + +use crate::cascades::GroupId; + +use super::Group; + pub mod set; + +const MISMATCH_ERROR: &str = "`groups` and `id_map` report unmatched group membership"; + +pub(crate) struct DisjointGroupMap { + id_map: DisjointSet, + groups: HashMap, +} + +impl DisjointGroupMap { + /// Creates a new disjoint group instance. + pub fn new() -> Self { + DisjointGroupMap { + id_map: DisjointSet::new(), + groups: HashMap::new(), + } + } + + pub fn get(&self, id: &GroupId) -> Option<&Group> { + self.id_map + .find(id) + .map(|rep| self.groups.get(&rep).expect(MISMATCH_ERROR)) + } + + pub fn get_mut(&mut self, id: &GroupId) -> Option<&mut Group> { + self.id_map + .find(id) + .map(|rep| self.groups.get_mut(&rep).expect(MISMATCH_ERROR)) + } + + unsafe fn insert_new(&mut self, id: GroupId, group: Group) { + self.id_map.add(id); + self.groups.insert(id, group); + } + + pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> { + use hash_map::Entry::*; + let rep = self.id_map.find(&id).unwrap_or(id); + let id_entry = self.id_map.entry(rep); + let group_entry = self.groups.entry(rep); + match (id_entry, group_entry) { + (Occupied(_), Occupied(inner)) => GroupEntry::Occupied(OccupiedGroupEntry { inner }), + (Vacant(id), Vacant(inner)) => GroupEntry::Vacant(VacantGroupEntry { id, inner }), + _ => unreachable!("{MISMATCH_ERROR}"), + } + } +} + +pub enum GroupEntry<'a> { + Occupied(OccupiedGroupEntry<'a>), + Vacant(VacantGroupEntry<'a>), +} + +pub struct OccupiedGroupEntry<'a> { + inner: hash_map::OccupiedEntry<'a, GroupId, Group>, +} + +pub struct VacantGroupEntry<'a> { + id: hash_map::VacantEntry<'a, GroupId, GroupId>, + inner: hash_map::VacantEntry<'a, GroupId, Group>, +} + +impl<'a> OccupiedGroupEntry<'a> { + pub fn id(&self) -> &GroupId { + self.inner.key() + } + + pub fn get(&self) -> &Group { + self.inner.get() + } + + pub fn get_mut(&mut self) -> &mut Group { + self.inner.get_mut() + } +} + +impl<'a> VacantGroupEntry<'a> { + pub fn id(&self) -> &GroupId { + self.inner.key() + } + + pub fn insert(self, group: Group) -> &'a mut Group { + let id = *self.id(); + self.id.insert(id); + self.inner.insert(group) + } +} + +impl Index<&GroupId> for DisjointGroupMap { + type Output = Group; + + fn index(&self, index: &GroupId) -> &Self::Output { + let rep = self + .id_map + .find(index) + .expect("no group found for group id"); + + self.groups.get(&rep).expect(MISMATCH_ERROR) + } +} diff --git a/optd-core/src/cascades/memo/disjoint_group/set.rs b/optd-core/src/cascades/memo/disjoint_group/set.rs index 35cc6b81..c343e59e 100644 --- a/optd-core/src/cascades/memo/disjoint_group/set.rs +++ b/optd-core/src/cascades/memo/disjoint_group/set.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, fmt::Debug, hash::Hash}; +use std::{ + collections::{hash_map, HashMap}, + fmt::Debug, + hash::Hash, +}; /// A data structure for efficiently maintaining disjoint sets of `T`. pub struct DisjointSet { @@ -37,7 +41,7 @@ where impl DisjointSet where - T: Ord + Hash + Copy, + T: Ord + Hash + Copy + Debug, { pub fn new() -> Self { DisjointSet { @@ -99,6 +103,11 @@ where Some(parent) } + + /// Gets the given node corresponding entry in `node_parents` map for in-place manipulation. + pub(super) fn entry(&mut self, node: T) -> hash_map::Entry<'_, T, T> { + self.node_parents.entry(node) + } } impl UnionFind for DisjointSet From 0890c3a1ce08ccfc34a983e33c3479a7fec9a5d8 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Tue, 15 Oct 2024 00:32:31 -0400 Subject: [PATCH 5/5] add DisjointGroup::merge Signed-off-by: Yuchen Liang --- optd-core/src/cascades/memo.rs | 2 +- optd-core/src/cascades/memo/disjoint_group.rs | 32 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index c4bbfb0a..889486e8 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -118,7 +118,7 @@ impl Memo { // Remove all expressions from group other (so we don't accidentally access it) self.clear_exprs_in_group(other); - group_b + rep } fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId { diff --git a/optd-core/src/cascades/memo/disjoint_group.rs b/optd-core/src/cascades/memo/disjoint_group.rs index 57c22567..570ceb5b 100644 --- a/optd-core/src/cascades/memo/disjoint_group.rs +++ b/optd-core/src/cascades/memo/disjoint_group.rs @@ -3,11 +3,15 @@ use std::{ ops::Index, }; +use itertools::Itertools; use set::{DisjointSet, UnionFind}; -use crate::cascades::GroupId; +use crate::{ + cascades::{optimizer::ExprId, GroupId}, + rel_node::RelNodeTyp, +}; -use super::Group; +use super::{Group, RelMemoNodeRef}; pub mod set; @@ -44,6 +48,30 @@ impl DisjointGroupMap { self.groups.insert(id, group); } + /// Merge the group `a` and group `b`. Returns the merged representative group id. + pub fn merge(&mut self, a: GroupId, b: GroupId) -> GroupId { + if a == b { + return a; + } + + let [rep, other] = self.id_map.union(&a, &b).unwrap(); + + // Drain all expressions from group other, copy to its representative + let other_exprs = self.drain_all_exprs_in(&other); + let rep_group = self.get_mut(&rep).expect("group not found"); + for expr_id in other_exprs { + rep_group.group_exprs.insert(expr_id); + } + + rep + } + + /// Drain all expressions from the group, returns an iterator of the expressions. + fn drain_all_exprs_in(&mut self, id: &GroupId) -> impl Iterator { + let group = self.groups.remove(&id).expect("group not found"); + group.group_exprs.into_iter().sorted() + } + pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> { use hash_map::Entry::*; let rep = self.id_map.find(&id).unwrap_or(id);