From 4b06b81bd7d64fdb72d177a677d7ebf4ac7bbbba Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sat, 30 Nov 2024 16:24:39 -0500 Subject: [PATCH] add union find and group merging support --- optd-mvp/src/entities/cascades_group.rs | 2 +- optd-mvp/src/memo/mod.rs | 8 + .../src/memo/persistent/implementation.rs | 140 ++++++++++++++++-- .../memo/m20241127_000001_cascades_group.rs | 6 +- .../memo/m20241127_000001_fingerprint.rs | 6 +- 5 files changed, 146 insertions(+), 16 deletions(-) diff --git a/optd-mvp/src/entities/cascades_group.rs b/optd-mvp/src/entities/cascades_group.rs index 9c2ba83..62e1835 100644 --- a/optd-mvp/src/entities/cascades_group.rs +++ b/optd-mvp/src/entities/cascades_group.rs @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*; pub struct Model { #[sea_orm(primary_key)] pub id: i32, + pub status: i8, pub winner: Option, pub cost: Option, - pub is_optimized: bool, pub parent_id: Option, } diff --git a/optd-mvp/src/memo/mod.rs b/optd-mvp/src/memo/mod.rs index fbf23a2..83a821f 100644 --- a/optd-mvp/src/memo/mod.rs +++ b/optd-mvp/src/memo/mod.rs @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32); #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct PhysicalExpressionId(pub i32); +/// A status enum representing the different states a group can be in during query optimization. +#[repr(u8)] +pub enum GroupStatus { + InProgress = 0, + Explored = 1, + Optimized = 2, +} + /// The different kinds of errors that might occur while running operations on a memo table. 
#[derive(Error, Debug)] pub enum MemoError { diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index 4fc7048..e75bc39 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -10,7 +10,7 @@ use super::PersistentMemo; use crate::{ entities::*, expression::{LogicalExpression, PhysicalExpression}, - memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId}, + memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId}, OptimizerResult, DATABASE_URL, }; use sea_orm::{ @@ -66,6 +66,40 @@ impl PersistentMemo { .ok_or(MemoError::UnknownGroup(group_id))?) } + /// Retrieves the root / canonical group ID of the given group ID. + /// + /// The groups form a union find / disjoint set parent pointer forest, where group merging + /// causes two trees to merge. + /// + /// This function uses the path compression optimization, which amortizes the cost to a single + /// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip). + pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult { + let mut curr_group = self.get_group(group_id).await?; + + // Traverse up the path and find the root group, keeping track of groups we have visited. + let mut path = vec![]; + loop { + let Some(parent_id) = curr_group.parent_id else { + break; + }; + + let next_group = self.get_group(GroupId(parent_id)).await?; + path.push(curr_group); + curr_group = next_group; + } + + let root_id = GroupId(curr_group.id); + + // Path Compression Optimization: + // For every group along the path that we walked, set their parent id pointer to the root. + // This allows for an amortized O(1) cost for `get_root_group`. + for group in path { + self.update_group_parent(GroupId(group.id), root_id).await?; + } + + Ok(root_id) + } + /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`]. 
/// /// If the physical expression does not exist, returns a @@ -146,6 +180,32 @@ impl PersistentMemo { Ok(children) } + /// Updates / replaces a group's status. Returns the previous group status. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_status( + &self, + group_id: GroupId, + status: GroupStatus, + ) -> OptimizerResult { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Update the group's status. + let old_status = group.status; + group.status = Set(status as u8 as i8); + group.update(&self.db).await?; + + let old_status = match old_status.unwrap() { + 0 => GroupStatus::InProgress, + 1 => GroupStatus::Explored, + 2 => GroupStatus::Optimized, + _ => panic!("encountered an invalid group status"), + }; + + Ok(old_status) + } + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous /// winner's physical expression ID. /// @@ -167,8 +227,45 @@ impl PersistentMemo { group.update(&self.db).await?; // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. - let old = old_id.unwrap().map(PhysicalExpressionId); - Ok(old) + let old_id = old_id.unwrap().map(PhysicalExpressionId); + Ok(old_id) + } + + /// Updates / replaces a group's parent group. Optionally returns the previous parent. + /// + /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_parent( + &self, + group_id: GroupId, + parent_id: GroupId, + ) -> OptimizerResult> { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Check that the parent group exists. + let _ = self.get_group(parent_id).await?; + + // Update the group to point to the new parent. 
+ let old_parent = group.parent_id; + group.parent_id = Set(Some(parent_id.0)); + group.update(&self.db).await?; + + // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. + let old_parent = old_parent.unwrap().map(GroupId); + Ok(old_parent) + } + + /// Merges two group sets together. Returns the new root group of the unioned sets. + /// + /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// TODO use union by rank / size as an optimization? + pub async fn merge_groups(&self, group1: GroupId, group2: GroupId) -> OptimizerResult { + // Without tracking the size of each of the groups, it is arbitrary which group is better to + // merge into the other. So we will arbitrarily choose `group1` to merge into `group2`. + self.update_group_parent(group1, group2).await?; + + Ok(group2) } /// Adds a logical expression to an existing group via its ID. @@ -195,7 +292,7 @@ impl PersistentMemo { ) -> OptimizerResult> { // Check if the expression already exists anywhere in the memo table. if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + .is_duplicate_logical_expression(&logical_expression, children) .await? { return Ok(Err(existing_id)); @@ -227,7 +324,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child IDs themselves. 
+ let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, @@ -296,13 +401,22 @@ impl PersistentMemo { pub async fn is_duplicate_logical_expression( &self, logical_expression: &LogicalExpression, + children: &[GroupId], ) -> OptimizerResult> { let model: logical_expression::Model = logical_expression.clone().into(); // Lookup all expressions that have the same fingerprint and kind. There may be false // positives, but we will check for those next. let kind = model.kind; - let fingerprint = logical_expression.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child IDs themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites); // Filter first by the fingerprint, and then the kind. // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the @@ -360,7 +474,7 @@ impl PersistentMemo { { // Check if the expression already exists in the memo table. if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + .is_duplicate_logical_expression(&logical_expression, children) .await? { let (group_id, _expr) = self.get_logical_expression(existing_id).await?; @@ -370,7 +484,7 @@ impl PersistentMemo { // The expression does not exist yet, so we need to create a new group and new expression. let group = cascades_group::ActiveModel { winner: Set(None), - is_optimized: Set(false), + status: Set(0), // `GroupStatus::InProgress` status. 
..Default::default() }; @@ -401,7 +515,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child IDs themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs index 3a0e7d0..abaa829 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs @@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*}; pub enum CascadesGroup { Table, Id, + Status, Winner, Cost, - IsOptimized, ParentId, } @@ -92,8 +92,9 @@ impl MigrationTrait for Migration { .table(CascadesGroup::Table) .if_not_exists() .col(pk_auto(CascadesGroup::Id)) + .col(tiny_integer(CascadesGroup::Status)) .col(integer_null(CascadesGroup::Winner)) - .col(big_unsigned_null(CascadesGroup::Cost)) + .col(big_integer_null(CascadesGroup::Cost)) .foreign_key( ForeignKey::create() .from(CascadesGroup::Table, CascadesGroup::Winner) @@ -101,7 +102,6 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::SetNull) .on_update(ForeignKeyAction::Cascade), ) - .col(boolean(CascadesGroup::IsOptimized)) .col(integer_null(CascadesGroup::ParentId)) .foreign_key( ForeignKey::create() diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs index 4a828b8..e153b9e 100644 --- 
a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs @@ -26,7 +26,7 @@ impl MigrationTrait for Migration { .table(Fingerprint::Table) .if_not_exists() .col(pk_auto(Fingerprint::Id)) - .col(unsigned(Fingerprint::LogicalExpressionId)) + .col(integer(Fingerprint::LogicalExpressionId)) .foreign_key( ForeignKey::create() .from(Fingerprint::Table, Fingerprint::LogicalExpressionId) @@ -34,8 +34,8 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) - .col(small_unsigned(Fingerprint::Kind)) - .col(big_unsigned(Fingerprint::Hash)) + .col(small_integer(Fingerprint::Kind)) + .col(big_integer(Fingerprint::Hash)) .to_owned(), ) .await