Skip to content

Commit

Permalink
add union find and group merging support
Browse files Browse the repository at this point in the history
  • Loading branch information
connortsui20 committed Nov 30, 2024
1 parent 0a0af6d commit 4b06b81
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 16 deletions.
2 changes: 1 addition & 1 deletion optd-mvp/src/entities/cascades_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*;
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub status: i8,
pub winner: Option<i32>,
pub cost: Option<i64>,
pub is_optimized: bool,
pub parent_id: Option<i32>,
}

Expand Down
8 changes: 8 additions & 0 deletions optd-mvp/src/memo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PhysicalExpressionId(pub i32);

/// A status enum representing the different states a group can be during query optimization.
#[repr(u8)]
pub enum GroupStatus {
InProgress = 0,
Explored = 1,
Optimized = 2,
}

/// The different kinds of errors that might occur while running operations on a memo table.
#[derive(Error, Debug)]
pub enum MemoError {
Expand Down
140 changes: 131 additions & 9 deletions optd-mvp/src/memo/persistent/implementation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::PersistentMemo;
use crate::{
entities::*,
expression::{LogicalExpression, PhysicalExpression},
memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId},
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
OptimizerResult, DATABASE_URL,
};
use sea_orm::{
Expand Down Expand Up @@ -66,6 +66,40 @@ impl PersistentMemo {
.ok_or(MemoError::UnknownGroup(group_id))?)
}

/// Retrieves the root / canonical group ID of the given group ID.
///
/// The groups form a union find / disjoint set parent pointer forest, where group merging
/// causes two trees to merge.
///
/// This function uses the path compression optimization, which amortizes the cost to a single
/// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult<GroupId> {
let mut curr_group = self.get_group(group_id).await?;

// Traverse up the path and find the root group, keeping track of groups we have visited.
let mut path = vec![];
loop {
let Some(parent_id) = curr_group.parent_id else {
break;
};

let next_group = self.get_group(GroupId(parent_id)).await?;
path.push(curr_group);
curr_group = next_group;
}

let root_id = GroupId(curr_group.id);

// Path Compression Optimization:
// For every group along the path that we walked, set their parent id pointer to the root.
// This allows for an amortized O(1) cost for `get_root_group`.
for group in path {
self.update_group_parent(GroupId(group.id), root_id).await?;
}

Ok(root_id)
}

/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
///
/// If the physical expression does not exist, returns a
Expand Down Expand Up @@ -146,6 +180,32 @@ impl PersistentMemo {
Ok(children)
}

/// Updates / replaces a group's status. Returns the previous group status.
///
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_status(
&self,
group_id: GroupId,
status: GroupStatus,
) -> OptimizerResult<GroupStatus> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Update the group's status.
let old_status = group.status;
group.status = Set(status as u8 as i8);
group.update(&self.db).await?;

let old_status = match old_status.unwrap() {
0 => GroupStatus::InProgress,
1 => GroupStatus::Explored,
2 => GroupStatus::Optimized,
_ => panic!("encountered an invalid group status"),
};

Ok(old_status)
}

/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
/// winner's physical expression ID.
///
Expand All @@ -167,8 +227,45 @@ impl PersistentMemo {
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old = old_id.unwrap().map(PhysicalExpressionId);
Ok(old)
let old_id = old_id.unwrap().map(PhysicalExpressionId);
Ok(old_id)
}

/// Updates / replaces a group's parent group. Optionally returns the previous parent.
///
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
pub async fn update_group_parent(
&self,
group_id: GroupId,
parent_id: GroupId,
) -> OptimizerResult<Option<GroupId>> {
// First retrieve the group record.
let mut group = self.get_group(group_id).await?.into_active_model();

// Check that the parent group exists.
let _ = self.get_group(parent_id).await?;

// Update the group to point to the new parent.
let old_parent = group.parent_id;
group.parent_id = Set(Some(parent_id.0));
group.update(&self.db).await?;

// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
let old_parent = old_parent.unwrap().map(GroupId);
Ok(old_parent)
}

/// Merges two groups sets together. Returns the new root group of the unioned sets.
///
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
///
/// TODO use union by rank / size as an optimization?
pub async fn merge_groups(&self, group1: GroupId, group2: GroupId) -> OptimizerResult<GroupId> {
// Without tracking the size of each of the groups, it is arbitrary which group is better to
// merge into the other. So we will arbitrarily choose `group1` to merge into `group2`.
self.update_group_parent(group1, group2).await?;

Ok(group2)
}

/// Adds a logical expression to an existing group via its ID.
Expand All @@ -195,7 +292,7 @@ impl PersistentMemo {
) -> OptimizerResult<Result<LogicalExpressionId, LogicalExpressionId>> {
// Check if the expression already exists anywhere in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
return Ok(Err(existing_id));
Expand Down Expand Up @@ -227,7 +324,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down Expand Up @@ -296,13 +401,22 @@ impl PersistentMemo {
pub async fn is_duplicate_logical_expression(
&self,
logical_expression: &LogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Option<LogicalExpressionId>> {
let model: logical_expression::Model = logical_expression.clone().into();

// Lookup all expressions that have the same fingerprint and kind. There may be false
// positives, but we will check for those next.
let kind = model.kind;
let fingerprint = logical_expression.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites);

// Filter first by the fingerprint, and then the kind.
// FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
Expand Down Expand Up @@ -360,7 +474,7 @@ impl PersistentMemo {
{
// Check if the expression already exists in the memo table.
if let Some(existing_id) = self
.is_duplicate_logical_expression(&logical_expression)
.is_duplicate_logical_expression(&logical_expression, children)
.await?
{
let (group_id, _expr) = self.get_logical_expression(existing_id).await?;
Expand All @@ -370,7 +484,7 @@ impl PersistentMemo {
// The expression does not exist yet, so we need to create a new group and new expression.
let group = cascades_group::ActiveModel {
winner: Set(None),
is_optimized: Set(false),
status: Set(0), // `GroupStatus::InProgress` status.
..Default::default()
};

Expand Down Expand Up @@ -401,7 +515,15 @@ impl PersistentMemo {
// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let kind = new_expr.kind();
let hash = new_expr.fingerprint();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
// groups of the children instead of the child ID themselves.
let mut rewrites = vec![];
for &child_id in children {
let root_id = self.get_root_group(child_id).await?;
rewrites.push((child_id, root_id));
}
let hash = new_expr.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
id: NotSet,
Expand Down
6 changes: 3 additions & 3 deletions optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*};
pub enum CascadesGroup {
Table,
Id,
Status,
Winner,
Cost,
IsOptimized,
ParentId,
}

Expand All @@ -92,16 +92,16 @@ impl MigrationTrait for Migration {
.table(CascadesGroup::Table)
.if_not_exists()
.col(pk_auto(CascadesGroup::Id))
.col(tiny_integer(CascadesGroup::Status))
.col(integer_null(CascadesGroup::Winner))
.col(big_unsigned_null(CascadesGroup::Cost))
.col(big_integer_null(CascadesGroup::Cost))
.foreign_key(
ForeignKey::create()
.from(CascadesGroup::Table, CascadesGroup::Winner)
.to(PhysicalExpression::Table, PhysicalExpression::Id)
.on_delete(ForeignKeyAction::SetNull)
.on_update(ForeignKeyAction::Cascade),
)
.col(boolean(CascadesGroup::IsOptimized))
.col(integer_null(CascadesGroup::ParentId))
.foreign_key(
ForeignKey::create()
Expand Down
6 changes: 3 additions & 3 deletions optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ impl MigrationTrait for Migration {
.table(Fingerprint::Table)
.if_not_exists()
.col(pk_auto(Fingerprint::Id))
.col(unsigned(Fingerprint::LogicalExpressionId))
.col(integer(Fingerprint::LogicalExpressionId))
.foreign_key(
ForeignKey::create()
.from(Fingerprint::Table, Fingerprint::LogicalExpressionId)
.to(LogicalExpression::Table, LogicalExpression::Id)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.col(small_unsigned(Fingerprint::Kind))
.col(big_unsigned(Fingerprint::Hash))
.col(small_integer(Fingerprint::Kind))
.col(big_integer(Fingerprint::Hash))
.to_owned(),
)
.await
Expand Down

0 comments on commit 4b06b81

Please sign in to comment.