Skip to content

Commit

Permalink
add union find and group merging support
Browse files Browse the repository at this point in the history
  • Loading branch information
connortsui20 committed Dec 1, 2024
1 parent 0a0af6d commit ce1b93d
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 26 deletions.
67 changes: 67 additions & 0 deletions optd-mvp/DESIGN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Duplicate Elimination Memo Table

Note that most of the details are in `src/memo/persistent/implementation.rs`.

For this document, we are assuming that the memo table is backed by a database / ORM. A lot of these
problems would likely not be an issue if everything was in memory.

## Group Merging

During logical exploration, there will be rules that create cycles between groups. The easy solution
for this is to immediately merge two groups together when the engine determines that adding an
expression would result in a duplicate expression from another group.

However, if we want to support parallel exploration, this could be prone to high contention. By
definition, merging group G1 into group G2 would mean that _every expression_ that has a child of
group G1 with would need to be rewritten to point to group G2 instead.

This is unacceptable in a parallel setting, as that would mean every single task that gets affected
would need to either wait for the rewrites to happen before resuming work, or need to abort their
work because data has changed underneath them.

So immediate / eager group merging is not a great idea for parallel exploration. However, if we do
not ever merge two groups that are identical, we are subject to doing duplicate work for every
duplicate expression in the memo table during physical optimization.

Instead of merging groups together immediately, we can instead maintain an auxiliary data structure
that records the groups that _eventually_ need to get merged, and "lazily" merge those groups
together once every group has finished exploration.

## Union-Find Group Sets

We use the well-known Union-Find algorithm and corresponding data structure as the auxiliary data
structure that tracks the to-be-merged groups.

Union-Find supports `Union` and `Find` operations, where `Union` merges sets and `Find` searches for
a "canonical" or "root" element that is shared between all elements in a given set.

For more information about Union-Find, see these
[15-451 lecture notes](https://www.cs.cmu.edu/~15451-f24/lectures/lecture08-union-find.pdf).

Here, we make the elements the groups themselves (really the Group IDs), which allows us to merge
group sets together and also determine a "root group" that all groups in a set can agree on.

When every group in a group set has finished exploration, we can safely begin to merge them
together by moving all expressions from every group in the group set into a single large group.
Other than making sure that any reference to an old group in the group set points to this new large
group, exploration of all groups are done and physical optimization can start.

RFC: Do we need to support incremental search?

Note that since we are now waiting for exploration of all groups to finish, this algorithm is much
closer to the Volcano framework than the Cascades' incremental search. However, since we eventually
will want to store trails / breadcrumbs of decisions made to skip work in the future, and since we
essentially have unlimited space due to the memo table being backed by a DBMS, this is not as much
of a problem.

## Duplicate Detection

TODO explain the fingerprinting algorithm and how it relates to group merging

When taking the fingerprint of an expression, the child groups of an expression need to be root groups. If they are not, we need to try again.
Assuming that all children are root groups, the fingerprint we make for any expression that fulfills that is valid and can be looked up for duplicates.
In order to maintain that correctness, on a merge of two sets, the smaller one requires that a new fingerprint be generated for every expression that has a group in that smaller set.
For example, let's say we need to merge { 1, 2 } (root group 1) with { 3, 4, 5, 6, 7, 8 } (root group 3). We need to find every single expression that has a child group of 1 or 2 and we need to generate a new fingerprint for each where the child groups have been "rewritten" to 3

TODO this is incredibly expensive, but is potentially easily parallelizable?

File renamed without changes.
2 changes: 1 addition & 1 deletion optd-mvp/src/entities/cascades_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*;
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub status: i8,
pub winner: Option<i32>,
pub cost: Option<i64>,
pub is_optimized: bool,
pub parent_id: Option<i32>,
}

Expand Down
8 changes: 8 additions & 0 deletions optd-mvp/src/memo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PhysicalExpressionId(pub i32);

/// A status enum representing the different states a group can be during query optimization.
#[repr(u8)]
pub enum GroupStatus {
InProgress = 0,
Explored = 1,
Optimized = 2,
}

/// The different kinds of errors that might occur while running operations on a memo table.
#[derive(Error, Debug)]
pub enum MemoError {
Expand Down
142 changes: 125 additions & 17 deletions optd-mvp/src/memo/persistent/implementation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::PersistentMemo;
use crate::{
entities::*,
expression::{LogicalExpression, PhysicalExpression},
memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId},
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
OptimizerResult, DATABASE_URL,
};
use sea_orm::{
Expand Down Expand Up @@ -66,6 +66,40 @@ impl PersistentMemo {
.ok_or(MemoError::UnknownGroup(group_id))?)
}

/// Retrieves the root / canonical group ID of the given group ID.
///
/// The groups form a union find / disjoint set parent pointer forest, where group merging
/// causes two trees to merge.
///
/// This function uses the path compression optimization, which amortizes the cost to a single
/// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult<GroupId> {
let mut curr_group = self.get_group(group_id).await?;

// Traverse up the path and find the root group, keeping track of groups we have visited.
let mut path = vec![];
loop {
let Some(parent_id) = curr_group.parent_id else {
break;
};

let next_group = self.get_group(GroupId(parent_id)).await?;
path.push(curr_group);
curr_group = next_group;
}

let root_id = GroupId(curr_group.id);

// Path Compression Optimization:
// For every group along the path that we walked, set their parent id pointer to the root.
// This allows for an amortized O(1) cost for `get_root_group`.
for group in path {
self.update_group_parent(GroupId(group.id), root_id).await?;
}

Ok(root_id)
}

/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
///
/// If the physical expression does not exist, returns a
Expand Down Expand Up @@ -146,6 +180,32 @@ impl PersistentMemo {
Ok(children)
}

/// Updates / replaces a group's status. Returns the previous group status.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_status(
&self,
group_id: GroupId,
status: GroupStatus,
) -> OptimizerResult<GroupStatus> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Update the group's status.
let old_status = group.status;
group.status = Set(status as u8 as i8);
group.update(&self.db).await?;

let old_status = match old_status.unwrap() {
0 => GroupStatus::InProgress,
1 => GroupStatus::Explored,
2 => GroupStatus::Optimized,
_ => panic!("encountered an invalid group status"),
};

Ok(old_status)
}

/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
/// winner's physical expression ID.
///
Expand All @@ -167,8 +227,32 @@ impl PersistentMemo {
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old = old_id.unwrap().map(PhysicalExpressionId);
Ok(old)
let old_id = old_id.unwrap().map(PhysicalExpressionId);
Ok(old_id)
}

/// Updates / replaces a group's parent group. Optionally returns the previous parent.
///
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_parent(
&self,
group_id: GroupId,
parent_id: GroupId,
) -> OptimizerResult<Option<GroupId>> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Check that the parent group exists.
let _ = self.get_group(parent_id).await?;

// Update the group to point to the new parent.
let old_parent = group.parent_id;
group.parent_id = Set(Some(parent_id.0));
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old_parent = old_parent.unwrap().map(GroupId);
Ok(old_parent)
}

/// Adds a logical expression to an existing group via its ID.
Expand All @@ -192,10 +276,10 @@ impl PersistentMemo {
group_id: GroupId,
logical_expression: LogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Result<LogicalExpressionId, LogicalExpressionId>> {
) -> OptimizerResult<Result<LogicalExpressionId, (GroupId, LogicalExpressionId)>> {
// Check if the expression already exists anywhere in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
return Ok(Err(existing_id));
Expand Down Expand Up @@ -227,7 +311,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down Expand Up @@ -285,8 +377,8 @@ impl PersistentMemo {
/// In order to prevent a large amount of duplicate work, the memo table must support duplicate
/// expression detection.
///
/// Returns `Some(expression_id)` if the memo table detects that the expression already exists,
/// and `None` otherwise.
/// Returns `Some((group_id, expression_id))` if the memo table detects that the expression
/// already exists, and `None` otherwise.
///
/// This function assumes that the child groups of the expression are currently roots of their
/// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input
Expand All @@ -296,13 +388,22 @@ impl PersistentMemo {
pub async fn is_duplicate_logical_expression(
&self,
logical_expression: &LogicalExpression,
) -> OptimizerResult<Option<LogicalExpressionId>> {
children: &[GroupId],
) -> OptimizerResult<Option<(GroupId, LogicalExpressionId)>> {
let model: logical_expression::Model = logical_expression.clone().into();

// Lookup all expressions that have the same fingerprint and kind. There may be false
// positives, but we will check for those next.
let kind = model.kind;
let fingerprint = logical_expression.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites);

// Filter first by the fingerprint, and then the kind.
// FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
Expand All @@ -323,11 +424,11 @@ impl PersistentMemo {
let mut match_id = None;
for potential_match in potential_matches {
let expr_id = LogicalExpressionId(potential_match.logical_expression_id);
let (_, expr) = self.get_logical_expression(expr_id).await?;
let (group_id, expr) = self.get_logical_expression(expr_id).await?;

// Check for an exact match.
if &expr == logical_expression {
match_id = Some(expr_id);
match_id = Some((group_id, expr_id));

// There should be at most one duplicate expression, so we can break here.
break;
Expand Down Expand Up @@ -359,18 +460,17 @@ impl PersistentMemo {
) -> OptimizerResult<Result<(GroupId, LogicalExpressionId), (GroupId, LogicalExpressionId)>>
{
// Check if the expression already exists in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
if let Some((group_id, existing_id)) = self
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
let (group_id, _expr) = self.get_logical_expression(existing_id).await?;
return Ok(Err((group_id, existing_id)));
}

// The expression does not exist yet, so we need to create a new group and new expression.
let group = cascades_group::ActiveModel {
winner: Set(None),
is_optimized: Set(false),
status: Set(0), // `GroupStatus::InProgress` status.
..Default::default()
};

Expand Down Expand Up @@ -401,7 +501,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down
Loading

0 comments on commit ce1b93d

Please sign in to comment.