Skip to content

Commit

Permalink
add DisjointGroup::merge
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Oct 15, 2024
1 parent 4129dff commit 0890c3a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl<T: RelNodeTyp> Memo<T> {
// Remove all expressions from group other (so we don't accidentally access it)
self.clear_exprs_in_group(other);

group_b
rep
}

fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId {
Expand Down
32 changes: 30 additions & 2 deletions optd-core/src/cascades/memo/disjoint_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ use std::{
ops::Index,
};

use itertools::Itertools;
use set::{DisjointSet, UnionFind};

use crate::cascades::GroupId;
use crate::{
cascades::{optimizer::ExprId, GroupId},
rel_node::RelNodeTyp,
};

use super::Group;
use super::{Group, RelMemoNodeRef};

pub mod set;

Expand Down Expand Up @@ -44,6 +48,30 @@ impl DisjointGroupMap {
self.groups.insert(id, group);
}

/// Merge the group `a` and group `b`. Returns the merged representative group id.
pub fn merge(&mut self, a: GroupId, b: GroupId) -> GroupId {
if a == b {
return a;
}

let [rep, other] = self.id_map.union(&a, &b).unwrap();

// Drain all expressions from group other, copy to its representative
let other_exprs = self.drain_all_exprs_in(&other);
let rep_group = self.get_mut(&rep).expect("group not found");
for expr_id in other_exprs {
rep_group.group_exprs.insert(expr_id);
}

rep
}

/// Drain all expressions from the group, returns an iterator of the expressions.
fn drain_all_exprs_in(&mut self, id: &GroupId) -> impl Iterator<Item = ExprId> {
let group = self.groups.remove(&id).expect("group not found");
group.group_exprs.into_iter().sorted()
}

pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> {
use hash_map::Entry::*;
let rep = self.id_map.find(&id).unwrap_or(id);
Expand Down

0 comments on commit 0890c3a

Please sign in to comment.