From 9d9db841e81a998ed3340f72b595e41501b35eae Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sat, 30 Nov 2024 16:24:39 -0500 Subject: [PATCH] add union find and group merging support --- optd-mvp/DESIGN.md | 0 optd-mvp/{README.md => entities.md} | 0 optd-mvp/src/entities/cascades_group.rs | 2 +- optd-mvp/src/memo/mod.rs | 8 + .../src/memo/persistent/implementation.rs | 142 +++++++++++++++--- optd-mvp/src/memo/persistent/tests.rs | 72 ++++++++- .../memo/m20241127_000001_cascades_group.rs | 6 +- .../memo/m20241127_000001_fingerprint.rs | 6 +- 8 files changed, 210 insertions(+), 26 deletions(-) create mode 100644 optd-mvp/DESIGN.md rename optd-mvp/{README.md => entities.md} (100%) diff --git a/optd-mvp/DESIGN.md b/optd-mvp/DESIGN.md new file mode 100644 index 0000000..e69de29 diff --git a/optd-mvp/README.md b/optd-mvp/entities.md similarity index 100% rename from optd-mvp/README.md rename to optd-mvp/entities.md diff --git a/optd-mvp/src/entities/cascades_group.rs b/optd-mvp/src/entities/cascades_group.rs index 9c2ba83..62e1835 100644 --- a/optd-mvp/src/entities/cascades_group.rs +++ b/optd-mvp/src/entities/cascades_group.rs @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*; pub struct Model { #[sea_orm(primary_key)] pub id: i32, + pub status: i8, pub winner: Option, pub cost: Option, - pub is_optimized: bool, pub parent_id: Option, } diff --git a/optd-mvp/src/memo/mod.rs b/optd-mvp/src/memo/mod.rs index fbf23a2..83a821f 100644 --- a/optd-mvp/src/memo/mod.rs +++ b/optd-mvp/src/memo/mod.rs @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32); #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct PhysicalExpressionId(pub i32); +/// A status enum representing the different states a group can be during query optimization. +#[repr(u8)] +pub enum GroupStatus { + InProgress = 0, + Explored = 1, + Optimized = 2, +} + /// The different kinds of errors that might occur while running operations on a memo table. #[derive(Error, Debug)] pub enum MemoError { diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index 4fc7048..002893a 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -10,7 +10,7 @@ use super::PersistentMemo; use crate::{ entities::*, expression::{LogicalExpression, PhysicalExpression}, - memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId}, + memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId}, OptimizerResult, DATABASE_URL, }; use sea_orm::{ @@ -66,6 +66,40 @@ impl PersistentMemo { .ok_or(MemoError::UnknownGroup(group_id))?) } + /// Retrieves the root / canonical group ID of the given group ID. + /// + /// The groups form a union find / disjoint set parent pointer forest, where group merging + /// causes two trees to merge. + /// + /// This function uses the path compression optimization, which amortizes the cost to a single + /// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip). + pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult { + let mut curr_group = self.get_group(group_id).await?; + + // Traverse up the path and find the root group, keeping track of groups we have visited. + let mut path = vec![]; + loop { + let Some(parent_id) = curr_group.parent_id else { + break; + }; + + let next_group = self.get_group(GroupId(parent_id)).await?; + path.push(curr_group); + curr_group = next_group; + } + + let root_id = GroupId(curr_group.id); + + // Path Compression Optimization: + // For every group along the path that we walked, set their parent id pointer to the root. + // This allows for an amortized O(1) cost for `get_root_group`. + for group in path { + self.update_group_parent(GroupId(group.id), root_id).await?; + } + + Ok(root_id) + } + /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`]. /// /// If the physical expression does not exist, returns a @@ -146,6 +180,32 @@ impl PersistentMemo { Ok(children) } + /// Updates / replaces a group's status. Returns the previous group status. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_status( + &self, + group_id: GroupId, + status: GroupStatus, + ) -> OptimizerResult { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Update the group's status. + let old_status = group.status; + group.status = Set(status as u8 as i8); + group.update(&self.db).await?; + + let old_status = match old_status.unwrap() { + 0 => GroupStatus::InProgress, + 1 => GroupStatus::Explored, + 2 => GroupStatus::Optimized, + _ => panic!("encountered an invalid group status"), + }; + + Ok(old_status) + } + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous /// winner's physical expression ID. /// @@ -167,8 +227,32 @@ impl PersistentMemo { group.update(&self.db).await?; // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. - let old = old_id.unwrap().map(PhysicalExpressionId); - Ok(old) + let old_id = old_id.unwrap().map(PhysicalExpressionId); + Ok(old_id) + } + + /// Updates / replaces a group's parent group. Optionally returns the previous parent. + /// + /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_parent( + &self, + group_id: GroupId, + parent_id: GroupId, + ) -> OptimizerResult> { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Check that the parent group exists. + let _ = self.get_group(parent_id).await?; + + // Update the group to point to the new parent. + let old_parent = group.parent_id; + group.parent_id = Set(Some(parent_id.0)); + group.update(&self.db).await?; + + // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. + let old_parent = old_parent.unwrap().map(GroupId); + Ok(old_parent) } /// Adds a logical expression to an existing group via its ID. @@ -192,10 +276,10 @@ impl PersistentMemo { group_id: GroupId, logical_expression: LogicalExpression, children: &[GroupId], - ) -> OptimizerResult> { + ) -> OptimizerResult> { // Check if the expression already exists anywhere in the memo table. if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + .is_duplicate_logical_expression(&logical_expression, children) .await? { return Ok(Err(existing_id)); @@ -227,7 +311,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, @@ -285,8 +377,8 @@ impl PersistentMemo { /// In order to prevent a large amount of duplicate work, the memo table must support duplicate /// expression detection. /// - /// Returns `Some(expression_id)` if the memo table detects that the expression already exists, - /// and `None` otherwise. + /// Returns `Some((group_id, expression_id))` if the memo table detects that the expression + /// already exists, and `None` otherwise. /// /// This function assumes that the child groups of the expression are currently roots of their /// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input @@ -296,13 +388,22 @@ impl PersistentMemo { pub async fn is_duplicate_logical_expression( &self, logical_expression: &LogicalExpression, - ) -> OptimizerResult> { + children: &[GroupId], + ) -> OptimizerResult> { let model: logical_expression::Model = logical_expression.clone().into(); // Lookup all expressions that have the same fingerprint and kind. There may be false // positives, but we will check for those next. let kind = model.kind; - let fingerprint = logical_expression.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites); // Filter first by the fingerprint, and then the kind. // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the @@ -323,11 +424,11 @@ impl PersistentMemo { let mut match_id = None; for potential_match in potential_matches { let expr_id = LogicalExpressionId(potential_match.logical_expression_id); - let (_, expr) = self.get_logical_expression(expr_id).await?; + let (group_id, expr) = self.get_logical_expression(expr_id).await?; // Check for an exact match. if &expr == logical_expression { - match_id = Some(expr_id); + match_id = Some((group_id, expr_id)); // There should be at most one duplicate expression, so we can break here. break; @@ -359,18 +460,17 @@ impl PersistentMemo { ) -> OptimizerResult> { // Check if the expression already exists in the memo table. - if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + if let Some((group_id, existing_id)) = self + .is_duplicate_logical_expression(&logical_expression, children) .await? { - let (group_id, _expr) = self.get_logical_expression(existing_id).await?; return Ok(Err((group_id, existing_id))); } // The expression does not exist yet, so we need to create a new group and new expression. let group = cascades_group::ActiveModel { winner: Set(None), - is_optimized: Set(false), + status: Set(0), // `GroupStatus::InProgress` status. ..Default::default() }; @@ -401,7 +501,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs index f3afea6..ae2a375 100644 --- a/optd-mvp/src/memo/persistent/tests.rs +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -34,20 +34,22 @@ async fn test_simple_logical_duplicates() { // Test `add_logical_expression_to_group`. { // Attempting to add a duplicate expression into the same group should also fail every time. - let logical_expression_id_2a = memo + let (group_id_2a, logical_expression_id_2a) = memo .add_logical_expression_to_group(group_id, scan2a, &[]) .await .unwrap() .err() .unwrap(); + assert_eq!(group_id, group_id_2a); assert_eq!(logical_expression_id, logical_expression_id_2a); - let logical_expression_id_2b = memo + let (group_id_2b, logical_expression_id_2b) = memo .add_logical_expression_to_group(group_id, scan2b, &[]) .await .unwrap() .err() .unwrap(); + assert_eq!(group_id, group_id_2b); assert_eq!(logical_expression_id, logical_expression_id_2b); } @@ -140,3 +142,69 @@ async fn test_simple_tree() { memo.cleanup().await; } + +/// Tests basic group merging. See comments in the test itself for more information. +#[ignore] +#[tokio::test] +async fn test_simple_group_link() { + let memo = PersistentMemo::new().await; + memo.cleanup().await; + + // Create two scan groups. + let scan1 = scan("t1".to_string()); + let scan2 = scan("t2".to_string()); + let (scan_id_1, _) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap(); + let (scan_id_2, _) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap(); + + // Create two join expression that should be in the same group. + // Even though these are obviously the same expression (to humans), the fingerprints will be + // different, and so they will be put into different groups. + let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string()); + let join2 = join(scan_id_2, scan_id_1, "t2.b = t1.a".to_string()); + let join_unknown = join2.clone(); + + let (join_group_1, _) = memo + .add_group(join1, &[scan_id_1, scan_id_2]) + .await + .unwrap() + .ok() + .unwrap(); + let (join_group_2, join_expr_2) = memo + .add_group(join2, &[scan_id_2, scan_id_1]) + .await + .unwrap() + .ok() + .unwrap(); + assert_ne!(join_group_1, join_group_2); + + // Assume that some rule was applied to `join1`, and it outputs something like `join_unknown`. + // The memo table will tell us that `join_unknown == join2`. + // Take note here that `join_unknown` is a clone of `join2`, not `join1`. + let (existing_group, not_actually_new_expr_id) = memo + .add_logical_expression_to_group(join_group_1, join_unknown, &[scan_id_2, scan_id_1]) + .await + .unwrap() + .err() + .unwrap(); + assert_eq!(existing_group, join_group_2); + assert_eq!(not_actually_new_expr_id, join_expr_2); + + // The above tells the application that the expression already exists in the memo, specifically + // under `existing_group`. Thus, we should link these two groups together. + // Here, we arbitrarily choose to link group 1 into group 2. + memo.update_group_parent(join_group_1, join_group_2).await.unwrap(); + + let test_root_1 = memo.get_root_group(join_group_1).await.unwrap(); + let test_root_2 = memo.get_root_group(join_group_2).await.unwrap(); + assert_eq!(test_root_1, test_root_2); + + // TODO(Connor) + // + // We now need to find all logical expressions that had group 1 (or whatever the root group of + // the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a + // new fingerprint for each one that uses group 2 as a child instead. + // + // In order to do this, we need to iterate through every group in group 1's set. + + memo.cleanup().await; +} diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs index 3a0e7d0..abaa829 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs @@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*}; pub enum CascadesGroup { Table, Id, + Status, Winner, Cost, - IsOptimized, ParentId, } @@ -92,8 +92,9 @@ impl MigrationTrait for Migration { .table(CascadesGroup::Table) .if_not_exists() .col(pk_auto(CascadesGroup::Id)) + .col(tiny_integer(CascadesGroup::Status)) .col(integer_null(CascadesGroup::Winner)) - .col(big_unsigned_null(CascadesGroup::Cost)) + .col(big_integer_null(CascadesGroup::Cost)) .foreign_key( ForeignKey::create() .from(CascadesGroup::Table, CascadesGroup::Winner) @@ -101,7 +102,6 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::SetNull) .on_update(ForeignKeyAction::Cascade), ) - .col(boolean(CascadesGroup::IsOptimized)) .col(integer_null(CascadesGroup::ParentId)) .foreign_key( ForeignKey::create() diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs index 4a828b8..e153b9e 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs @@ -26,7 +26,7 @@ impl MigrationTrait for Migration { .table(Fingerprint::Table) .if_not_exists() .col(pk_auto(Fingerprint::Id)) - .col(unsigned(Fingerprint::LogicalExpressionId)) + .col(integer(Fingerprint::LogicalExpressionId)) .foreign_key( ForeignKey::create() .from(Fingerprint::Table, Fingerprint::LogicalExpressionId) @@ -34,8 +34,8 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) - .col(small_unsigned(Fingerprint::Kind)) - .col(big_unsigned(Fingerprint::Hash)) + .col(small_integer(Fingerprint::Kind)) + .col(big_integer(Fingerprint::Hash)) .to_owned(), ) .await