Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] core: group reduction using union find #187

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 43 additions & 75 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
mod disjoint_group;

use std::{
collections::{hash_map::Entry, HashMap, HashSet},
fmt::Display,
sync::Arc,
};

use anyhow::{bail, Result};
use disjoint_group::{
set::{DisjointSet, UnionFind},
DisjointGroupMap,
};
use itertools::Itertools;
use std::any::Any;

Expand Down Expand Up @@ -57,28 +62,15 @@ pub(crate) struct Group {
pub(crate) properties: Arc<[Box<dyn Any + Send + Sync + 'static>]>,
}

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)]
struct ReducedGroupId(usize);

impl ReducedGroupId {
pub fn as_group_id(self) -> GroupId {
GroupId(self.0)
}
}

impl Display for ReducedGroupId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

pub struct Memo<T: RelNodeTyp> {
expr_id_to_group_id: HashMap<ExprId, GroupId>,
expr_id_to_expr_node: HashMap<ExprId, RelMemoNodeRef<T>>,
expr_node_to_expr_id: HashMap<RelMemoNode<T>, ExprId>,
groups: HashMap<ReducedGroupId, Group>,
/// Stores the mapping from "representative" group id to group.
groups: HashMap<GroupId, Group>,
group_expr_counter: usize,
merged_groups: HashMap<GroupId, GroupId>,
/// Keeps track of disjoint sets of group ids.
disjoint_group_ids: DisjointSet<GroupId>,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
}

Expand All @@ -90,16 +82,16 @@ impl<T: RelNodeTyp> Memo<T> {
expr_node_to_expr_id: HashMap::new(),
groups: HashMap::new(),
group_expr_counter: 0,
merged_groups: HashMap::new(),
disjoint_group_ids: DisjointSet::new(),
property_builders,
}
}

/// Get the next group id. Group ids and expr ids share the same counter, which makes debugging easier.
fn next_group_id(&mut self) -> ReducedGroupId {
let id = self.group_expr_counter;
fn next_group_id(&mut self) -> GroupId {
let id = GroupId(self.group_expr_counter);
self.group_expr_counter += 1;
ReducedGroupId(id)
id
}

/// Get the next expr id. Group ids and expr ids share the same counter, which makes debugging easier.
Expand All @@ -109,47 +101,32 @@ impl<T: RelNodeTyp> Memo<T> {
ExprId(id)
}

fn merge_group_inner(
&mut self,
group_a: ReducedGroupId,
group_b: ReducedGroupId,
) -> ReducedGroupId {
pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId {
if group_a == group_b {
return group_a;
}

// Copy all expressions from group a to group b
let group_a_exprs = self.get_all_exprs_in_group(group_a.as_group_id());
for expr_id in group_a_exprs {
let [rep, other] = self.disjoint_group_ids.union(&group_a, &group_b).unwrap();

// Copy all expressions from group other to its representative
let other_exprs = self.get_all_exprs_in_group(other);
for expr_id in other_exprs {
let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap();
self.add_expr_to_group(expr_id, group_b, expr_node.as_ref().clone());
self.add_expr_to_group(expr_id, rep, expr_node.as_ref().clone());
}

self.merged_groups
.insert(group_a.as_group_id(), group_b.as_group_id());

// Remove all expressions from group a (so we don't accidentally access it)
self.clear_exprs_in_group(group_a);
// Remove all expressions from group other (so we don't accidentally access it)
self.clear_exprs_in_group(other);

group_b
}

pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId {
let group_a_reduced = self.get_reduced_group_id(group_a);
let group_b_reduced = self.get_reduced_group_id(group_b);
self.merge_group_inner(group_a_reduced, group_b_reduced)
.as_group_id()
rep
}

fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId {
self.expr_id_to_group_id[&expr_id]
}

fn get_reduced_group_id(&self, mut group_id: GroupId) -> ReducedGroupId {
while let Some(next_group_id) = self.merged_groups.get(&group_id) {
group_id = *next_group_id;
}
ReducedGroupId(group_id.0)
fn get_reduced_group_id(&self, group_id: GroupId) -> GroupId {
self.disjoint_group_ids.find(&group_id).unwrap()
}

/// Add or get an expression into the memo, returns the group id and the expr id. If `GroupId` is `None`,
Expand All @@ -164,11 +141,9 @@ impl<T: RelNodeTyp> Memo<T> {
self.merge_group(grp_a, grp_b);
};

let (group_id, expr_id) = self.add_new_group_expr_inner(
rel_node,
add_to_group_id.map(|x| self.get_reduced_group_id(x)),
);
(group_id.as_group_id(), expr_id)
let add_to_group_id = add_to_group_id.map(|x| self.get_reduced_group_id(x));
let (group_id, expr_id) = self.add_new_group_expr_inner(rel_node, add_to_group_id);
(group_id, expr_id)
}

pub fn get_expr_info(&self, rel_node: RelNodeRef<T>) -> (GroupId, ExprId) {
Expand Down Expand Up @@ -223,18 +198,15 @@ impl<T: RelNodeTyp> Memo<T> {
props
}

fn clear_exprs_in_group(&mut self, group_id: ReducedGroupId) {
// TODO(yuchen): make internal to disjoint group
fn clear_exprs_in_group(&mut self, group_id: GroupId) {
self.groups.remove(&group_id);
}

/// If `group_id` exists, adds `expr_id` to the existing group.
/// Otherwise, creates a new group with that `group_id` and inserts `expr_id` into the new group.
fn add_expr_to_group(
&mut self,
expr_id: ExprId,
group_id: ReducedGroupId,
memo_node: RelMemoNode<T>,
) {
fn add_expr_to_group(&mut self, expr_id: ExprId, group_id: GroupId, memo_node: RelMemoNode<T>) {
// TODO(yuchen): use entry API
if let Entry::Occupied(mut entry) = self.groups.entry(group_id) {
let group = entry.get_mut();
group.group_exprs.insert(expr_id);
Expand All @@ -246,6 +218,8 @@ impl<T: RelNodeTyp> Memo<T> {
properties: self.infer_properties(memo_node).into(),
};
group.group_exprs.insert(expr_id);
self.disjoint_group_ids.add(group_id);
// TODO(yuchen): use insert
self.groups.insert(group_id, group);
}

Expand All @@ -260,6 +234,7 @@ impl<T: RelNodeTyp> Memo<T> {
) -> bool {
let replace_group_id = self.get_reduced_group_id(replace_group_id);

// TODO(yuchen): use disjoint group entry API
if let Entry::Occupied(mut entry) = self.groups.entry(replace_group_id) {
let group = entry.get_mut();
if !group.group_exprs.contains(&expr_id) {
Expand Down Expand Up @@ -296,9 +271,8 @@ impl<T: RelNodeTyp> Memo<T> {
and make sure if it does not do any transformation, it should return an empty vec!");
}
let group_id = self.get_group_id_of_expr_id(new_expr_id);
let group_id = self.get_reduced_group_id(group_id);

self.merge_group_inner(replace_group_id, group_id);
self.merge_group(replace_group_id, group_id);
return false;
}

Expand All @@ -314,8 +288,8 @@ impl<T: RelNodeTyp> Memo<T> {
fn add_new_group_expr_inner(
&mut self,
rel_node: RelNodeRef<T>,
add_to_group_id: Option<ReducedGroupId>,
) -> (ReducedGroupId, ExprId) {
add_to_group_id: Option<GroupId>,
) -> (GroupId, ExprId) {
let children_group_ids = rel_node
.children
.iter()
Expand All @@ -336,7 +310,7 @@ impl<T: RelNodeTyp> Memo<T> {
let group_id = self.get_group_id_of_expr_id(expr_id);
let group_id = self.get_reduced_group_id(group_id);
if let Some(add_to_group_id) = add_to_group_id {
self.merge_group_inner(add_to_group_id, group_id);
self.merge_group(add_to_group_id, group_id);
}
return (group_id, expr_id);
}
Expand All @@ -348,8 +322,7 @@ impl<T: RelNodeTyp> Memo<T> {
};
self.expr_id_to_expr_node
.insert(expr_id, memo_node.clone().into());
self.expr_id_to_group_id
.insert(expr_id, group_id.as_group_id());
self.expr_id_to_group_id.insert(expr_id, group_id);
self.expr_node_to_expr_id.insert(memo_node.clone(), expr_id);
self.add_expr_to_group(expr_id, group_id, memo_node);
(group_id, expr_id)
Expand All @@ -362,7 +335,7 @@ impl<T: RelNodeTyp> Memo<T> {
.expr_id_to_group_id
.get(&expr_id)
.expect("expr not found in group mapping");
self.get_reduced_group_id(*group_id).as_group_id()
self.get_reduced_group_id(*group_id)
}

/// Get the memoized representation of a node.
Expand Down Expand Up @@ -463,12 +436,7 @@ impl<T: RelNodeTyp> Memo<T> {
}

pub fn get_all_group_ids(&self) -> Vec<GroupId> {
let mut ids = self
.groups
.keys()
.copied()
.map(|x| x.as_group_id())
.collect_vec();
let mut ids = self.groups.keys().copied().collect_vec();
ids.sort();
ids
}
Expand Down
139 changes: 139 additions & 0 deletions optd-core/src/cascades/memo/disjoint_group.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::{
collections::{hash_map, HashMap},
ops::Index,
};

use itertools::Itertools;
use set::{DisjointSet, UnionFind};

use crate::{
cascades::{optimizer::ExprId, GroupId},
rel_node::RelNodeTyp,
};

use super::{Group, RelMemoNodeRef};

pub mod set;

/// Panic message used when `id_map` resolves an id to a representative that has no
/// corresponding entry in `groups` (or vice versa), i.e. the two structures disagree.
const MISMATCH_ERROR: &str = "`groups` and `id_map` report unmatched group membership";

/// A map from group ids to [`Group`]s in which lookups are first resolved through a
/// disjoint set of group ids.
///
/// Every lookup asks `id_map` for the representative of the requested id and then uses
/// that representative as the key into `groups`; the two fields must therefore always
/// agree on which representatives exist.
pub(crate) struct DisjointGroupMap {
/// Disjoint set used to resolve a group id to its representative id.
id_map: DisjointSet<GroupId>,
/// Group storage, keyed by the representative id reported by `id_map`.
groups: HashMap<GroupId, Group>,
}

impl DisjointGroupMap {
    /// Creates a new, empty disjoint group map.
    pub fn new() -> Self {
        DisjointGroupMap {
            id_map: DisjointSet::new(),
            groups: HashMap::new(),
        }
    }

    /// Returns a shared reference to the group that `id` resolves to, or `None` when
    /// `id` is not known to the disjoint set.
    ///
    /// # Panics
    ///
    /// Panics if `id_map` resolves `id` to a representative that has no entry in `groups`.
    pub fn get(&self, id: &GroupId) -> Option<&Group> {
        self.id_map
            .find(id)
            .map(|rep| self.groups.get(&rep).expect(MISMATCH_ERROR))
    }

    /// Returns a mutable reference to the group that `id` resolves to, or `None` when
    /// `id` is not known to the disjoint set.
    ///
    /// # Panics
    ///
    /// Panics if `id_map` resolves `id` to a representative that has no entry in `groups`.
    pub fn get_mut(&mut self, id: &GroupId) -> Option<&mut Group> {
        self.id_map
            .find(id)
            .map(|rep| self.groups.get_mut(&rep).expect(MISMATCH_ERROR))
    }

    /// Registers `id` in the disjoint set and stores `group` under it.
    ///
    /// # Safety
    ///
    /// This function performs no memory-unsafe operation. NOTE(review): the `unsafe`
    /// marker appears to signal a logical precondition only -- presumably that `id` is
    /// not already registered, since an existing entry in `groups` would be silently
    /// replaced. Confirm the intended contract with the author.
    unsafe fn insert_new(&mut self, id: GroupId, group: Group) {
        self.id_map.add(id);
        self.groups.insert(id, group);
    }

    /// Merge the group `a` and group `b`. Returns the merged representative group id.
    ///
    /// When `a == b` the id is returned unchanged. When the two ids already resolve to
    /// the same representative nothing is moved and that representative is returned.
    ///
    /// # Panics
    ///
    /// Panics if `union` on the disjoint set yields `None`, or if either resulting
    /// group cannot be found in `groups`.
    pub fn merge(&mut self, a: GroupId, b: GroupId) -> GroupId {
        if a == b {
            return a;
        }

        // Two distinct ids can already belong to the same set. Draining the "other"
        // group in that situation would remove the representative's own group and make
        // the lookup below fail, so bail out before touching any state.
        if let (Some(rep_a), Some(rep_b)) = (self.id_map.find(&a), self.id_map.find(&b)) {
            if rep_a == rep_b {
                return rep_a;
            }
        }

        let [rep, other] = self.id_map.union(&a, &b).unwrap();

        // Drain all expressions from group `other` and move them into its representative.
        let other_exprs = self.drain_all_exprs_in(&other);
        let rep_group = self.get_mut(&rep).expect("group not found");
        for expr_id in other_exprs {
            rep_group.group_exprs.insert(expr_id);
        }

        rep
    }

    /// Drain all expressions from the group, returns an iterator of the expressions.
    ///
    /// The group stored under `id` is removed from `groups`; its expression ids are
    /// yielded in sorted order.
    ///
    /// # Panics
    ///
    /// Panics if no group is stored under `id`.
    fn drain_all_exprs_in(&mut self, id: &GroupId) -> impl Iterator<Item = ExprId> {
        // `id` is already a reference; passing `&id` would hand `remove` a
        // `&&GroupId`, which does not satisfy its `Borrow` bound.
        let group = self.groups.remove(id).expect("group not found");
        group.group_exprs.into_iter().sorted()
    }

    /// Returns an entry for in-place manipulation of the group that `id` resolves to.
    ///
    /// If `id` is unknown to the disjoint set, the entry is keyed by `id` itself.
    ///
    /// # Panics
    ///
    /// Panics if exactly one of `id_map` and `groups` has an entry for the resolved id.
    pub fn entry(&mut self, id: GroupId) -> GroupEntry<'_> {
        use hash_map::Entry::*;
        let rep = self.id_map.find(&id).unwrap_or(id);
        let id_entry = self.id_map.entry(rep);
        let group_entry = self.groups.entry(rep);
        match (id_entry, group_entry) {
            (Occupied(_), Occupied(inner)) => GroupEntry::Occupied(OccupiedGroupEntry { inner }),
            (Vacant(id), Vacant(inner)) => GroupEntry::Vacant(VacantGroupEntry { id, inner }),
            _ => unreachable!("{MISMATCH_ERROR}"),
        }
    }
}

/// A view into a single slot of a [`DisjointGroupMap`], which is either occupied by a
/// group or vacant.
pub enum GroupEntry<'a> {
/// Both the id map and the group map already hold an entry for the resolved id.
Occupied(OccupiedGroupEntry<'a>),
/// Neither the id map nor the group map holds an entry for the id yet.
Vacant(VacantGroupEntry<'a>),
}

/// An occupied entry of a [`DisjointGroupMap`]: wraps the occupied entry of the
/// underlying group hash map.
pub struct OccupiedGroupEntry<'a> {
/// Occupied slot in the group map for the resolved id.
inner: hash_map::OccupiedEntry<'a, GroupId, Group>,
}

/// A vacant entry of a [`DisjointGroupMap`]: holds the vacant slots of both the id map
/// and the group map so that they can be filled together.
pub struct VacantGroupEntry<'a> {
/// Vacant slot in the id map for the id.
id: hash_map::VacantEntry<'a, GroupId, GroupId>,
/// Vacant slot in the group map for the same id.
inner: hash_map::VacantEntry<'a, GroupId, Group>,
}

impl OccupiedGroupEntry<'_> {
    /// Returns the group id this entry is keyed by.
    pub fn id(&self) -> &GroupId {
        let entry = &self.inner;
        entry.key()
    }

    /// Borrows the group stored in this entry.
    pub fn get(&self) -> &Group {
        let entry = &self.inner;
        entry.get()
    }

    /// Mutably borrows the group stored in this entry.
    pub fn get_mut(&mut self) -> &mut Group {
        let entry = &mut self.inner;
        entry.get_mut()
    }
}

impl<'a> VacantGroupEntry<'a> {
pub fn id(&self) -> &GroupId {
self.inner.key()
}

pub fn insert(self, group: Group) -> &'a mut Group {
let id = *self.id();
self.id.insert(id);
self.inner.insert(group)
}
}

/// Indexing by group id: resolves the id through the disjoint set and returns the
/// group stored under its representative.
///
/// Panics when the id is unknown to the disjoint set, or when the resolved
/// representative has no group stored for it.
impl Index<&GroupId> for DisjointGroupMap {
    type Output = Group;

    fn index(&self, index: &GroupId) -> &Self::Output {
        match self.id_map.find(index) {
            Some(representative) => self.groups.get(&representative).expect(MISMATCH_ERROR),
            None => panic!("no group found for group id"),
        }
    }
}
Loading
Loading