Skip to content

Commit

Permalink
add union find and group merging support
Browse files Browse the repository at this point in the history
  • Loading branch information
connortsui20 committed Nov 30, 2024
1 parent 0a0af6d commit 9d9db84
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 26 deletions.
Empty file added optd-mvp/DESIGN.md
Empty file.
File renamed without changes.
2 changes: 1 addition & 1 deletion optd-mvp/src/entities/cascades_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*;
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub status: i8,
pub winner: Option<i32>,
pub cost: Option<i64>,
pub is_optimized: bool,
pub parent_id: Option<i32>,
}

Expand Down
8 changes: 8 additions & 0 deletions optd-mvp/src/memo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PhysicalExpressionId(pub i32);

/// A status enum representing the different states a group can be during query optimization.
#[repr(u8)]
pub enum GroupStatus {
InProgress = 0,
Explored = 1,
Optimized = 2,
}

/// The different kinds of errors that might occur while running operations on a memo table.
#[derive(Error, Debug)]
pub enum MemoError {
Expand Down
142 changes: 125 additions & 17 deletions optd-mvp/src/memo/persistent/implementation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::PersistentMemo;
use crate::{
entities::*,
expression::{LogicalExpression, PhysicalExpression},
memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId},
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
OptimizerResult, DATABASE_URL,
};
use sea_orm::{
Expand Down Expand Up @@ -66,6 +66,40 @@ impl PersistentMemo {
.ok_or(MemoError::UnknownGroup(group_id))?)
}

/// Retrieves the root / canonical group ID of the given group ID.
///
/// The groups form a union find / disjoint set parent pointer forest, where group merging
/// causes two trees to merge.
///
/// This function uses the path compression optimization, which amortizes the cost to a single
/// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult<GroupId> {
let mut curr_group = self.get_group(group_id).await?;

// Traverse up the path and find the root group, keeping track of groups we have visited.
let mut path = vec![];
loop {
let Some(parent_id) = curr_group.parent_id else {
break;
};

let next_group = self.get_group(GroupId(parent_id)).await?;
path.push(curr_group);
curr_group = next_group;
}

let root_id = GroupId(curr_group.id);

// Path Compression Optimization:
// For every group along the path that we walked, set their parent id pointer to the root.
// This allows for an amortized O(1) cost for `get_root_group`.
for group in path {
self.update_group_parent(GroupId(group.id), root_id).await?;
}

Ok(root_id)
}

/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
///
/// If the physical expression does not exist, returns a
Expand Down Expand Up @@ -146,6 +180,32 @@ impl PersistentMemo {
Ok(children)
}

/// Updates / replaces a group's status. Returns the previous group status.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_status(
&self,
group_id: GroupId,
status: GroupStatus,
) -> OptimizerResult<GroupStatus> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Update the group's status.
let old_status = group.status;
group.status = Set(status as u8 as i8);
group.update(&self.db).await?;

let old_status = match old_status.unwrap() {
0 => GroupStatus::InProgress,
1 => GroupStatus::Explored,
2 => GroupStatus::Optimized,
_ => panic!("encountered an invalid group status"),
};

Ok(old_status)
}

/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
/// winner's physical expression ID.
///
Expand All @@ -167,8 +227,32 @@ impl PersistentMemo {
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old = old_id.unwrap().map(PhysicalExpressionId);
Ok(old)
let old_id = old_id.unwrap().map(PhysicalExpressionId);
Ok(old_id)
}

/// Updates / replaces a group's parent group. Optionally returns the previous parent.
///
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_parent(
&self,
group_id: GroupId,
parent_id: GroupId,
) -> OptimizerResult<Option<GroupId>> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Check that the parent group exists.
let _ = self.get_group(parent_id).await?;

// Update the group to point to the new parent.
let old_parent = group.parent_id;
group.parent_id = Set(Some(parent_id.0));
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old_parent = old_parent.unwrap().map(GroupId);
Ok(old_parent)
}

/// Adds a logical expression to an existing group via its ID.
Expand All @@ -192,10 +276,10 @@ impl PersistentMemo {
group_id: GroupId,
logical_expression: LogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Result<LogicalExpressionId, LogicalExpressionId>> {
) -> OptimizerResult<Result<LogicalExpressionId, (GroupId, LogicalExpressionId)>> {
// Check if the expression already exists anywhere in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
return Ok(Err(existing_id));
Expand Down Expand Up @@ -227,7 +311,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down Expand Up @@ -285,8 +377,8 @@ impl PersistentMemo {
/// In order to prevent a large amount of duplicate work, the memo table must support duplicate
/// expression detection.
///
/// Returns `Some(expression_id)` if the memo table detects that the expression already exists,
/// and `None` otherwise.
/// Returns `Some((group_id, expression_id))` if the memo table detects that the expression
/// already exists, and `None` otherwise.
///
/// This function assumes that the child groups of the expression are currently roots of their
/// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input
Expand All @@ -296,13 +388,22 @@ impl PersistentMemo {
pub async fn is_duplicate_logical_expression(
&self,
logical_expression: &LogicalExpression,
) -> OptimizerResult<Option<LogicalExpressionId>> {
children: &[GroupId],
) -> OptimizerResult<Option<(GroupId, LogicalExpressionId)>> {
let model: logical_expression::Model = logical_expression.clone().into();

// Lookup all expressions that have the same fingerprint and kind. There may be false
// positives, but we will check for those next.
let kind = model.kind;
let fingerprint = logical_expression.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites);

// Filter first by the fingerprint, and then the kind.
// FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
Expand All @@ -323,11 +424,11 @@ impl PersistentMemo {
let mut match_id = None;
for potential_match in potential_matches {
let expr_id = LogicalExpressionId(potential_match.logical_expression_id);
let (_, expr) = self.get_logical_expression(expr_id).await?;
let (group_id, expr) = self.get_logical_expression(expr_id).await?;

// Check for an exact match.
if &expr == logical_expression {
match_id = Some(expr_id);
match_id = Some((group_id, expr_id));

// There should be at most one duplicate expression, so we can break here.
break;
Expand Down Expand Up @@ -359,18 +460,17 @@ impl PersistentMemo {
) -> OptimizerResult<Result<(GroupId, LogicalExpressionId), (GroupId, LogicalExpressionId)>>
{
// Check if the expression already exists in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
if let Some((group_id, existing_id)) = self
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
let (group_id, _expr) = self.get_logical_expression(existing_id).await?;
return Ok(Err((group_id, existing_id)));
}

// The expression does not exist yet, so we need to create a new group and new expression.
let group = cascades_group::ActiveModel {
winner: Set(None),
is_optimized: Set(false),
status: Set(0), // `GroupStatus::InProgress` status.
..Default::default()
};

Expand Down Expand Up @@ -401,7 +501,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down
72 changes: 70 additions & 2 deletions optd-mvp/src/memo/persistent/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ async fn test_simple_logical_duplicates() {
// Test `add_logical_expression_to_group`.
{
// Attempting to add a duplicate expression into the same group should also fail every time.
let logical_expression_id_2a = memo
let (group_id_2a, logical_expression_id_2a) = memo
.add_logical_expression_to_group(group_id, scan2a, &[])
.await
.unwrap()
.err()
.unwrap();
assert_eq!(group_id, group_id_2a);
assert_eq!(logical_expression_id, logical_expression_id_2a);

let logical_expression_id_2b = memo
let (group_id_2b, logical_expression_id_2b) = memo
.add_logical_expression_to_group(group_id, scan2b, &[])
.await
.unwrap()
.err()
.unwrap();
assert_eq!(group_id, group_id_2b);
assert_eq!(logical_expression_id, logical_expression_id_2b);
}

Expand Down Expand Up @@ -140,3 +142,69 @@ async fn test_simple_tree() {

memo.cleanup().await;
}

/// Tests basic group merging. See comments in the test itself for more information.
#[ignore]
#[tokio::test]
async fn test_simple_group_link() {
let memo = PersistentMemo::new().await;
memo.cleanup().await;

// Create two scan groups.
let scan1 = scan("t1".to_string());
let scan2 = scan("t2".to_string());
let (scan_id_1, _) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap();
let (scan_id_2, _) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap();

// Create two join expression that should be in the same group.
// Even though these are obviously the same expression (to humans), the fingerprints will be
// different, and so they will be put into different groups.
let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string());
let join2 = join(scan_id_2, scan_id_1, "t2.b = t1.a".to_string());
let join_unknown = join2.clone();

let (join_group_1, _) = memo
.add_group(join1, &[scan_id_1, scan_id_2])
.await
.unwrap()
.ok()
.unwrap();
let (join_group_2, join_expr_2) = memo
.add_group(join2, &[scan_id_2, scan_id_1])
.await
.unwrap()
.ok()
.unwrap();
assert_ne!(join_group_1, join_group_2);

// Assume that some rule was applied to `join1`, and it outputs something like `join_unknown`.
// The memo table will tell us that `join_unknown == join2`.
// Take note here that `join_unknown` is a clone of `join2`, not `join1`.
let (existing_group, not_actually_new_expr_id) = memo
.add_logical_expression_to_group(join_group_1, join_unknown, &[scan_id_2, scan_id_1])
.await
.unwrap()
.err()
.unwrap();
assert_eq!(existing_group, join_group_2);
assert_eq!(not_actually_new_expr_id, join_expr_2);

// The above tells the application that the expression already exists in the memo, specifically
// under `existing_group`. Thus, we should link these two groups together.
// Here, we arbitrarily choose to link group 1 into group 2.
memo.update_group_parent(join_group_1, join_group_2).await.unwrap();

let test_root_1 = memo.get_root_group(join_group_1).await.unwrap();
let test_root_2 = memo.get_root_group(join_group_2).await.unwrap();
assert_eq!(test_root_1, test_root_2);

// TODO(Connor)
//
// We now need to find all logical expressions that had group 1 (or whatever the root group of
// the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a
// new fingerprint for each one that uses group 2 as a child instead.
//
// In order to do this, we need to iterate through every group in group 1's set.

memo.cleanup().await;
}
6 changes: 3 additions & 3 deletions optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*};
pub enum CascadesGroup {
Table,
Id,
Status,
Winner,
Cost,
IsOptimized,
ParentId,
}

Expand All @@ -92,16 +92,16 @@ impl MigrationTrait for Migration {
.table(CascadesGroup::Table)
.if_not_exists()
.col(pk_auto(CascadesGroup::Id))
.col(tiny_integer(CascadesGroup::Status))
.col(integer_null(CascadesGroup::Winner))
.col(big_unsigned_null(CascadesGroup::Cost))
.col(big_integer_null(CascadesGroup::Cost))
.foreign_key(
ForeignKey::create()
.from(CascadesGroup::Table, CascadesGroup::Winner)
.to(PhysicalExpression::Table, PhysicalExpression::Id)
.on_delete(ForeignKeyAction::SetNull)
.on_update(ForeignKeyAction::Cascade),
)
.col(boolean(CascadesGroup::IsOptimized))
.col(integer_null(CascadesGroup::ParentId))
.foreign_key(
ForeignKey::create()
Expand Down
Loading

0 comments on commit 9d9db84

Please sign in to comment.