diff --git a/Cargo.lock b/Cargo.lock index 3059383..8acb13b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -710,6 +710,15 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + [[package]] name = "generic-array" version = "0.14.4" @@ -1146,6 +1155,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "fxhash", "sea-orm", "sea-orm-migration", "serde", diff --git a/optd-mvp/Cargo.toml b/optd-mvp/Cargo.toml index 3b72407..f4a3a62 100644 --- a/optd-mvp/Cargo.toml +++ b/optd-mvp/Cargo.toml @@ -20,8 +20,10 @@ serde_json = "1.0.118" # Support `Hash` on `serde_json::Value` in "1.0.118". tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } trait-variant = "0.1.2" # Support `make(Send)` syntax in "0.1.2". thiserror = "2.0" +fxhash = "0.2" # Pin more recent versions for `-Zminimal-versions`. async-trait = "0.1.43" # Remove lifetime parameter from "0.1.42". async-stream = "0.3.1" # Fix unsatisfied trait bound from "0.3.0". strum = "0.26.0" # Fix `std::marker::Sized` from "0.25.0". + diff --git a/optd-mvp/src/entities/logical_children.rs b/optd-mvp/src/entities/logical_children.rs index 120641f..067eaa7 100644 --- a/optd-mvp/src/entities/logical_children.rs +++ b/optd-mvp/src/entities/logical_children.rs @@ -23,7 +23,7 @@ pub enum Relation { CascadesGroup, #[sea_orm( belongs_to = "super::logical_expression::Entity", - from = "Column::GroupId", + from = "Column::LogicalExpressionId", to = "super::logical_expression::Column::Id", on_update = "Cascade", on_delete = "Cascade" diff --git a/optd-mvp/src/entities/prelude.rs b/optd-mvp/src/entities/prelude.rs index 5619363..bf6879b 100644 --- a/optd-mvp/src/entities/prelude.rs +++ b/optd-mvp/src/entities/prelude.rs @@ -1,5 +1,7 @@ //! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 +#![allow(unused_imports)] + pub use super::cascades_group::Entity as CascadesGroup; pub use super::fingerprint::Entity as Fingerprint; pub use super::logical_children::Entity as LogicalChildren; diff --git a/optd-mvp/src/expression/logical_expression.rs b/optd-mvp/src/expression/logical_expression.rs index c87b055..7c3362d 100644 --- a/optd-mvp/src/expression/logical_expression.rs +++ b/optd-mvp/src/expression/logical_expression.rs @@ -1,37 +1,91 @@ //! Definition of logical expressions / relations in the Cascades query optimization framework. //! -//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now. -//! FIXME: Representation needs to know how to "rewrite" child group IDs to whatever a fingerprint -//! will need. +//! FIXME: All fields are placeholders. //! -//! TODO figure out if each relation should be in a different submodule. +//! TODO Remove dead code. +//! TODO Figure out if each relation should be in a different submodule. //! TODO This entire file is a WIP. -use crate::entities::*; +#![allow(dead_code)] + +use crate::{entities::*, memo::GroupId}; +use fxhash::hash; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum LogicalExpression { Scan(Scan), Filter(Filter), Join(Join), } -#[derive(Serialize, Deserialize, Clone, Debug)] +/// FIXME: Figure out how to make everything unsigned instead of signed. +impl LogicalExpression { + pub fn kind(&self) -> i16 { + match self { + LogicalExpression::Scan(_) => 0, + LogicalExpression::Filter(_) => 1, + LogicalExpression::Join(_) => 2, + } + } + + /// Definitions of custom fingerprinting strategies for each kind of logical expression. + pub fn fingerprint(&self) -> i64 { + self.fingerprint_with_rewrite(&[]) + } + + /// Calculates the fingerprint of a given expression, but replaces all of the children group IDs + /// with a new group ID if it is listed in the input `rewrites` list. + /// + /// TODO Allow each expression to implement a trait that does this. + pub fn fingerprint_with_rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> i64 { + // Closure that rewrites a group ID if needed. + let rewrite = |x: GroupId| { + if rewrites.is_empty() { + return x; + } + + if let Some(i) = rewrites.iter().position(|(curr, _new)| &x == curr) { + assert_eq!(rewrites[i].0, x); + rewrites[i].1 + } else { + x + } + }; + + let kind = self.kind() as u16 as usize; + let hash = match self { + LogicalExpression::Scan(scan) => hash(scan.table_schema.as_str()), + LogicalExpression::Filter(filter) => { + hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str()) + } + LogicalExpression::Join(join) => { + hash(&rewrite(join.left).0) + ^ hash(&rewrite(join.right).0) + ^ hash(join.expression.as_str()) + } + }; + + // Mask out the bottom 16 bits of `hash` and replace them with `kind`. + ((hash & !0xFFFF) | kind) as i64 + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct Scan { table_schema: String, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct Filter { - child: i32, + child: GroupId, expression: String, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct Join { - left: i32, - right: i32, + left: GroupId, + right: GroupId, expression: String, } @@ -71,17 +125,18 @@ impl From for logical_expression::Model { } } + let kind = value.kind(); match value { LogicalExpression::Scan(scan) => create_logical_expression( - 0, + kind, serde_json::to_value(scan).expect("unable to serialize logical `Scan`"), ), LogicalExpression::Filter(filter) => create_logical_expression( - 1, + kind, serde_json::to_value(filter).expect("unable to serialize logical `Filter`"), ), LogicalExpression::Join(join) => create_logical_expression( - 2, + kind, serde_json::to_value(join).expect("unable to serialize logical `Join`"), ), } @@ -94,24 +149,28 @@ pub use build::*; #[cfg(test)] mod build { use super::*; - use crate::expression::Expression; + use crate::expression::LogicalExpression; - pub fn scan(table_schema: String) -> Expression { - Expression::Logical(LogicalExpression::Scan(Scan { table_schema })) + pub fn scan(table_schema: String) -> LogicalExpression { + LogicalExpression::Scan(Scan { table_schema }) } - pub fn filter(child_group: i32, expression: String) -> Expression { - Expression::Logical(LogicalExpression::Filter(Filter { + pub fn filter(child_group: GroupId, expression: String) -> LogicalExpression { + LogicalExpression::Filter(Filter { child: child_group, expression, - })) + }) } - pub fn join(left_group: i32, right_group: i32, expression: String) -> Expression { - Expression::Logical(LogicalExpression::Join(Join { + pub fn join( + left_group: GroupId, + right_group: GroupId, + expression: String, + ) -> LogicalExpression { + LogicalExpression::Join(Join { left: left_group, right: right_group, expression, - })) + }) } } diff --git a/optd-mvp/src/expression/physical_expression.rs b/optd-mvp/src/expression/physical_expression.rs index 6552a96..5719752 100644 --- a/optd-mvp/src/expression/physical_expression.rs +++ b/optd-mvp/src/expression/physical_expression.rs @@ -1,35 +1,38 @@ //! Definition of physical expressions / operators in the Cascades query optimization framework. //! -//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now. +//! FIXME: All fields are placeholders. //! -//! TODO figure out if each operator should be in a different submodule. +//! TODO Remove dead code. +//! TODO Figure out if each operator should be in a different submodule. //! TODO This entire file is a WIP. -use crate::entities::*; +#![allow(dead_code)] + +use crate::{entities::*, memo::GroupId}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum PhysicalExpression { TableScan(TableScan), Filter(PhysicalFilter), HashJoin(HashJoin), } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct TableScan { table_schema: String, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct PhysicalFilter { - child: i32, + child: GroupId, expression: String, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct HashJoin { - left: i32, - right: i32, + left: GroupId, + right: GroupId, expression: String, } @@ -92,24 +95,28 @@ pub use build::*; #[cfg(test)] mod build { use super::*; - use crate::expression::Expression; + use crate::expression::PhysicalExpression; - pub fn table_scan(table_schema: String) -> Expression { - Expression::Physical(PhysicalExpression::TableScan(TableScan { table_schema })) + pub fn table_scan(table_schema: String) -> PhysicalExpression { + PhysicalExpression::TableScan(TableScan { table_schema }) } - pub fn filter(child_group: i32, expression: String) -> Expression { - Expression::Physical(PhysicalExpression::Filter(PhysicalFilter { + pub fn filter(child_group: GroupId, expression: String) -> PhysicalExpression { + PhysicalExpression::Filter(PhysicalFilter { child: child_group, expression, - })) + }) } - pub fn hash_join(left_group: i32, right_group: i32, expression: String) -> Expression { - Expression::Physical(PhysicalExpression::HashJoin(HashJoin { + pub fn hash_join( + left_group: GroupId, + right_group: GroupId, + expression: String, + ) -> PhysicalExpression { + PhysicalExpression::HashJoin(HashJoin { left: left_group, right: right_group, expression, - })) + }) } } diff --git a/optd-mvp/src/lib.rs b/optd-mvp/src/lib.rs index 98c5f11..506eee4 100644 --- a/optd-mvp/src/lib.rs +++ b/optd-mvp/src/lib.rs @@ -37,15 +37,3 @@ pub type OptimizerResult = Result; pub async fn migrate(db: &DatabaseConnection) -> Result<(), DbErr> { Migrator::refresh(db).await } - -/// Helper function for hashing expression data. -/// -/// TODO remove this. -fn hash_expression(kind: i16, data: &serde_json::Value) -> i64 { - use std::hash::{DefaultHasher, Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - kind.hash(&mut hasher); - data.hash(&mut hasher); - hasher.finish() as i64 -} diff --git a/optd-mvp/src/memo/interface.rs b/optd-mvp/src/memo/interface.rs deleted file mode 100644 index cb6c76d..0000000 --- a/optd-mvp/src/memo/interface.rs +++ /dev/null @@ -1,176 +0,0 @@ -//! This module defines the [`Memo`] trait, which defines shared behavior of all memo table that can -//! be used for query optimization in the Cascades framework. - -use crate::OptimizerResult; -use thiserror::Error; - -/// The different kinds of errors that might occur while running operations on a memo table. -#[derive(Error, Debug)] -pub enum MemoError { - #[error("unknown group ID {0}")] - UnknownGroup(i32), - #[error("unknown logical expression ID {0}")] - UnknownLogicalExpression(i32), - #[error("unknown physical expression ID {0}")] - UnknownPhysicalExpression(i32), - #[error("invalid expression encountered")] - InvalidExpression, -} - -/// A trait representing an implementation of a memoization table. -/// -/// Note that we use [`trait_variant`] here in order to add bounds on every method. -/// See this [blog post]( -/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits) -/// for more information. -/// -/// TODO remove dead code. -#[allow(dead_code)] -#[trait_variant::make(Send)] -pub trait Memo { - /// A type representing a group in the Cascades framework. - type Group; - /// A type representing a unique identifier for a group. - type GroupId; - /// A type representing a logical expression. - type LogicalExpression; - /// A type representing a unique identifier for a logical expression. - type LogicalExpressionId; - /// A type representing a physical expression. - type PhysicalExpression; - /// A type representing a unique identifier for a physical expression. - type PhysicalExpressionId; - - /// Retrieves a [`Self::Group`] given a [`Self::GroupId`]. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - async fn get_group(&self, group_id: Self::GroupId) -> OptimizerResult; - - /// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`]. - /// - /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] - /// error. - async fn get_logical_expression( - &self, - logical_expression_id: Self::LogicalExpressionId, - ) -> OptimizerResult; - - /// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`]. - /// - /// If the physical expression does not exist, returns a - /// [`MemoError::UnknownPhysicalExpression`] error. - async fn get_physical_expression( - &self, - physical_expression_id: Self::PhysicalExpressionId, - ) -> OptimizerResult; - - /// Retrieves all of the logical expression "children" IDs of a group. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - async fn get_logical_children( - &self, - group_id: Self::GroupId, - ) -> OptimizerResult>; - - /// Retrieves all of the physical expression "children" IDs of a group. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - async fn get_physical_children( - &self, - group_id: Self::GroupId, - ) -> OptimizerResult>; - - /// Checks if a given logical expression is a duplicate / already exists in the memo table. - /// - /// In order to prevent a large amount of duplicate work, the memo table must support duplicate - /// expression detection. - /// - /// Returns `Some(expression_id)` if the memo table detects that the expression already exists, - /// and `None` otherwise. - async fn is_duplicate_logical_expression( - &self, - logical_expression: &Self::LogicalExpression, - ) -> OptimizerResult>; - - /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous - /// winner's physical expression ID. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - async fn update_group_winner( - &self, - group_id: Self::GroupId, - physical_expression_id: Self::PhysicalExpressionId, - ) -> OptimizerResult>; - - /// Adds a physical expression to an existing group via its [`Self::GroupId`]. - /// - /// The caller is required to pass in a slice of `GroupId` that represent the child groups of - /// the input expression. - /// - /// The caller is also required to set the `group_id` field of the input `physical_expression` - /// to be equal to `group_id`, otherwise this function will return a - /// [`MemoError::InvalidExpression`] error. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - /// - /// On successful insertion, returns the ID of the physical expression. - async fn add_physical_expression_to_group( - &self, - group_id: Self::GroupId, - physical_expression: Self::PhysicalExpression, - children: &[Self::GroupId], - ) -> OptimizerResult; - - /// Adds a logical expression to an existing group via its [`Self::GroupId`]. - /// - /// The caller is required to pass in a slice of `GroupId` that represent the child groups of - /// the input expression. - /// - /// The caller is also required to set the `group_id` field of the input `logical_expression` - /// to be equal to `group_id`, otherwise this function will return a - /// [`MemoError::InvalidExpression`] error. - /// - /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. - /// - /// If the memo table detects that the input logical expression is a duplicate expression, it - /// will **not** insert the expression into the memo table. Instead, it will return an - /// `Ok(Err(expression_id))`, which is a unique identifier of the expression that the input is a - /// duplicate of. The caller can use this ID to retrieve the group the original belongs to. - /// - /// If the memo table detects that the input is unique, it will insert the expression into the - /// input group and return an `Ok(Ok(expression_id))`. - async fn add_logical_expression_to_group( - &self, - group_id: Self::GroupId, - logical_expression: Self::LogicalExpression, - children: &[Self::GroupId], - ) -> OptimizerResult>; - - /// Adds a new logical expression into the memo table, creating a new group if the expression - /// does not already exist. - /// - /// The caller is required to pass in a slice of `GroupId` that represent the child groups of - /// the input expression. - /// - /// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if - /// the expression has been seen before, and if it has already been created, then the parent - /// group ID should also be retrievable. - /// - /// If the expression already exists, then this function will return the [`Self::GroupId`] of - /// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`]. It - /// will also completely ignore the group ID field of the input expression as well as ignore the - /// input slice of child groups. - /// - /// If the expression does not exist, this function will create a new group and a new - /// expression, returning brand new IDs for both. - async fn add_logical_expression( - &self, - expression: Self::LogicalExpression, - children: &[Self::LogicalExpressionId], - ) -> OptimizerResult< - Result< - (Self::GroupId, Self::LogicalExpressionId), - (Self::GroupId, Self::LogicalExpressionId), - >, - >; -} diff --git a/optd-mvp/src/memo/mod.rs b/optd-mvp/src/memo/mod.rs index 5253352..fbf23a2 100644 --- a/optd-mvp/src/memo/mod.rs +++ b/optd-mvp/src/memo/mod.rs @@ -3,7 +3,33 @@ //! //! TODO more docs. -mod persistent; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// A new type of an integer identifying a unique group. +#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)] +#[serde(transparent)] +pub struct GroupId(pub i32); + +/// A new type of an integer identifying a unique logical expression. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct LogicalExpressionId(pub i32); -mod interface; -pub use interface::{Memo, MemoError}; +/// A new type of an integer identifying a unique physical expression. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct PhysicalExpressionId(pub i32); + +/// The different kinds of errors that might occur while running operations on a memo table. +#[derive(Error, Debug)] +pub enum MemoError { + #[error("unknown group ID {0:?}")] + UnknownGroup(GroupId), + #[error("unknown logical expression ID {0:?}")] + UnknownLogicalExpression(LogicalExpressionId), + #[error("unknown physical expression ID {0:?}")] + UnknownPhysicalExpression(PhysicalExpressionId), + #[error("invalid expression encountered")] + InvalidExpression, +} + +mod persistent; diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index 4c06c4e..4fc7048 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -1,228 +1,370 @@ -//! This module contains the implementation of the [`Memo`] trait for [`PersistentMemo`]. +//! This module contains the implementation of [`PersistentMemo`]. +//! +//! TODO For parallelism, almost all of these methods need to be under transactions. +//! TODO Write more docs. +//! TODO Remove dead code. -use super::*; +#![allow(dead_code)] + +use super::PersistentMemo; use crate::{ - hash_expression, - memo::{Memo, MemoError}, - OptimizerResult, + entities::*, + expression::{LogicalExpression, PhysicalExpression}, + memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId}, + OptimizerResult, DATABASE_URL, +}; +use sea_orm::{ + entity::prelude::*, + entity::{IntoActiveModel, NotSet, Set}, + Database, }; -impl Memo for PersistentMemo { - type Group = cascades_group::Model; - type GroupId = i32; - type LogicalExpression = logical_expression::Model; - type LogicalExpressionId = i32; - type PhysicalExpression = physical_expression::Model; - type PhysicalExpressionId = i32; +impl PersistentMemo { + /// Creates a new `PersistentMemo` struct by connecting to a database defined at + /// [`DATABASE_URL`]. + pub async fn new() -> Self { + Self { + db: Database::connect(DATABASE_URL).await.unwrap(), + } + } + + /// Deletes all objects in the backing database. + /// + /// Since there is no asynchronous drop yet in Rust, in order to drop all objects in the + /// database, the user must call this manually. + pub async fn cleanup(&self) { + macro_rules! delete_all { + ($($module: ident),+ $(,)?) => { + $( + $module::Entity::delete_many() + .exec(&self.db) + .await + .unwrap(); + )+ + }; + } + + delete_all! { + cascades_group, + fingerprint, + logical_expression, + logical_children, + physical_expression, + physical_children + }; + } - async fn get_group(&self, group_id: Self::GroupId) -> OptimizerResult { - Ok(CascadesGroup::find_by_id(group_id) + /// Retrieves a [`cascades_group::Model`] given its ID. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// FIXME: use an in-memory representation of a group instead. + pub async fn get_group(&self, group_id: GroupId) -> OptimizerResult { + Ok(cascades_group::Entity::find_by_id(group_id.0) .one(&self.db) .await? .ok_or(MemoError::UnknownGroup(group_id))?) } - async fn get_logical_expression( + /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`]. + /// + /// If the physical expression does not exist, returns a + /// [`MemoError::UnknownPhysicalExpression`] error. + pub async fn get_physical_expression( &self, - logical_expression_id: Self::LogicalExpressionId, - ) -> OptimizerResult { - Ok(LogicalExpression::find_by_id(logical_expression_id) + physical_expression_id: PhysicalExpressionId, + ) -> OptimizerResult<(GroupId, PhysicalExpression)> { + // Lookup the entity in the database via the unique expression ID. + let model = physical_expression::Entity::find_by_id(physical_expression_id.0) .one(&self.db) .await? - .ok_or(MemoError::UnknownLogicalExpression(logical_expression_id))?) + .ok_or(MemoError::UnknownPhysicalExpression(physical_expression_id))?; + + let group_id = GroupId(model.group_id); + let expr = model.into(); + + Ok((group_id, expr)) } - async fn get_physical_expression( + /// Retrieves a [`logical_expression::Model`] given its [`LogicalExpressionId`]. + /// + /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`] + /// error. + pub async fn get_logical_expression( &self, - physical_expression_id: Self::PhysicalExpressionId, - ) -> OptimizerResult { - Ok(PhysicalExpression::find_by_id(physical_expression_id) + logical_expression_id: LogicalExpressionId, + ) -> OptimizerResult<(GroupId, LogicalExpression)> { + // Lookup the entity in the database via the unique expression ID. + let model = logical_expression::Entity::find_by_id(logical_expression_id.0) .one(&self.db) .await? - .ok_or(MemoError::UnknownPhysicalExpression(physical_expression_id))?) + .ok_or(MemoError::UnknownLogicalExpression(logical_expression_id))?; + + let group_id = GroupId(model.group_id); + let expr = model.into(); + + Ok((group_id, expr)) } - async fn get_logical_children( + /// Retrieves all of the logical expression "children" IDs of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// FIXME: `find_related` does not work for some reason, have to use manual `filter`. + pub async fn get_logical_children( &self, - group_id: Self::GroupId, - ) -> OptimizerResult> { - // First retrieve the group record, and then find all related logical expressions. - Ok(self - .get_group(group_id) - .await? - .find_related(LogicalChildren) + group_id: GroupId, + ) -> OptimizerResult> { + // Search for expressions that have the given parent group ID. + let children = logical_expression::Entity::find() + .filter(logical_expression::Column::GroupId.eq(group_id.0)) .all(&self.db) .await? .into_iter() - .map(|m| m.logical_expression_id) - .collect()) + .map(|m| LogicalExpressionId(m.id)) + .collect(); + + Ok(children) } - async fn get_physical_children( + /// Retrieves all of the physical expression "children" IDs of a group. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn get_physical_children( &self, - group_id: Self::GroupId, - ) -> OptimizerResult> { - // First retrieve the group record, and then find all related physical expressions. - Ok(self - .get_group(group_id) - .await? - .find_related(PhysicalChildren) + group_id: GroupId, + ) -> OptimizerResult> { + // Search for expressions that have the given parent group ID. + let children = physical_expression::Entity::find() + .filter(physical_expression::Column::GroupId.eq(group_id.0)) .all(&self.db) .await? .into_iter() - .map(|m| m.physical_expression_id) - .collect()) - } - - /// FIXME Check that all of the children are root groups? - async fn is_duplicate_logical_expression( - &self, - logical_expression: &Self::LogicalExpression, - ) -> OptimizerResult> { - // Lookup all expressions that have the same fingerprint and kind. There may be false - // positives, but we will check for those next. - let kind = logical_expression.kind; - let fingerprint = hash_expression(kind, &logical_expression.data); - - let potential_matches = Fingerprint::find() - .filter(fingerprint::Column::Hash.eq(fingerprint)) - .filter(fingerprint::Column::Kind.eq(kind)) - .all(&self.db) - .await?; - - if potential_matches.is_empty() { - return Ok(None); - } - - let mut match_id = None; - for potential_match in potential_matches { - let expr_id = potential_match.logical_expression_id; - let expr = self.get_logical_expression(expr_id).await?; - - if expr.data == logical_expression.data { - // There should be at most one duplicate expression. - match_id = Some(expr_id); - break; - } - } + .map(|m| PhysicalExpressionId(m.id)) + .collect(); - Ok(match_id) + Ok(children) } + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous + /// winner's physical expression ID. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// /// FIXME: In the future, this should first check that we aren't overwriting a winner that was - /// updated from another thread. - async fn update_group_winner( + /// updated from another thread by comparing against the cost of the plan. + pub async fn update_group_winner( &self, - group_id: Self::GroupId, - physical_expression_id: Self::PhysicalExpressionId, - ) -> OptimizerResult> { - // First retrieve the group record, and then use an `ActiveModel` to update it. + group_id: GroupId, + physical_expression_id: PhysicalExpressionId, + ) -> OptimizerResult> { + // First retrieve the group record. let mut group = self.get_group(group_id).await?.into_active_model(); - let old_id = group.winner; - group.winner = Set(Some(physical_expression_id)); + // Update the group to point to the new winner. + let old_id = group.winner; + group.winner = Set(Some(physical_expression_id.0)); group.update(&self.db).await?; - // The old value must be set (`None` still means it has been set). - let old = old_id.unwrap(); + // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. + let old = old_id.unwrap().map(PhysicalExpressionId); Ok(old) } - async fn add_physical_expression_to_group( + /// Adds a logical expression to an existing group via its ID. + /// + /// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of + /// the input expression. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// If the memo table detects that the input logical expression is a duplicate expression, this + /// function will **not** insert the expression into the memo table. Instead, it will return an + /// `Ok(Err(expression_id))`, which is a unique identifier of the expression that the input is a + /// duplicate of. The caller can use this ID to retrieve the group the original belongs to. + /// + /// If the memo table detects that the input is unique, it will insert the expression into the + /// input group and return an `Ok(Ok(expression_id))`. + /// + /// FIXME Check that all of the children are reduced groups? + pub async fn add_logical_expression_to_group( &self, - group_id: Self::GroupId, - physical_expression: Self::PhysicalExpression, - children: &[Self::GroupId], - ) -> OptimizerResult { - if physical_expression.group_id != group_id { - Err(MemoError::InvalidExpression)? + group_id: GroupId, + logical_expression: LogicalExpression, + children: &[GroupId], + ) -> OptimizerResult> { + // Check if the expression already exists anywhere in the memo table. + if let Some(existing_id) = self + .is_duplicate_logical_expression(&logical_expression) + .await? + { + return Ok(Err(existing_id)); } // Check if the group actually exists. let _ = self.get_group(group_id).await?; + // Insert the expression. + let model: logical_expression::Model = logical_expression.into(); + let mut active_model = model.into_active_model(); + active_model.group_id = Set(group_id.0); + active_model.id = NotSet; + let new_model = active_model.insert(&self.db).await?; + + let expr_id = new_model.id; + // Insert the child groups of the expression into the junction / children table. - if !children.is_empty() { - PhysicalChildren::insert_many(children.iter().copied().map(|group_id| { - physical_children::ActiveModel { - physical_expression_id: Set(physical_expression.id), - group_id: Set(group_id), - } - })) - .exec(&self.db) - .await?; - } + logical_children::Entity::insert_many(children.iter().copied().map(|child_id| { + logical_children::ActiveModel { + logical_expression_id: Set(expr_id), + group_id: Set(child_id.0), + } + })) + .on_empty_do_nothing() + .exec(&self.db) + .await?; - // Insert the expression. - let res = physical_expression - .into_active_model() - .insert(&self.db) + // Finally, insert the fingerprint of the logical expression as well. + let new_expr: LogicalExpression = new_model.into(); + let kind = new_expr.kind(); + let hash = new_expr.fingerprint(); + + let fingerprint = fingerprint::ActiveModel { + id: NotSet, + logical_expression_id: Set(expr_id), + kind: Set(kind), + hash: Set(hash), + }; + let _ = fingerprint::Entity::insert(fingerprint) + .exec(&self.db) .await?; - Ok(res.id) + Ok(Ok(LogicalExpressionId(expr_id))) } - /// FIXME Check that all of the children are reduced groups? - async fn add_logical_expression_to_group( + /// Adds a physical expression to an existing group via its ID. + /// + /// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of + /// the input expression. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// On successful insertion, returns the ID of the physical expression. + pub async fn add_physical_expression_to_group( &self, - group_id: Self::GroupId, - logical_expression: Self::LogicalExpression, - children: &[Self::GroupId], - ) -> OptimizerResult> { - if logical_expression.group_id != group_id { - Err(MemoError::InvalidExpression)? - } - - // Check if the expression already exists in the memo table. - if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) - .await? - { - return Ok(Err(existing_id)); - } - + group_id: GroupId, + physical_expression: PhysicalExpression, + children: &[GroupId], + ) -> OptimizerResult { // Check if the group actually exists. let _ = self.get_group(group_id).await?; + // Insert the expression. + let model: physical_expression::Model = physical_expression.into(); + let mut active_model = model.into_active_model(); + active_model.group_id = Set(group_id.0); + active_model.id = NotSet; + let new_model = active_model.insert(&self.db).await?; + // Insert the child groups of the expression into the junction / children table. - if !children.is_empty() { - LogicalChildren::insert_many(children.iter().copied().map(|group_id| { - logical_children::ActiveModel { - logical_expression_id: Set(logical_expression.id), - group_id: Set(group_id), - } - })) - .exec(&self.db) + physical_children::Entity::insert_many(children.iter().copied().map(|child_id| { + physical_children::ActiveModel { + physical_expression_id: Set(new_model.id), + group_id: Set(child_id.0), + } + })) + .on_empty_do_nothing() + .exec(&self.db) + .await?; + + Ok(PhysicalExpressionId(new_model.id)) + } + + /// Checks if the given logical expression is a duplicate / already exists in the memo table. + /// + /// In order to prevent a large amount of duplicate work, the memo table must support duplicate + /// expression detection. + /// + /// Returns `Some(expression_id)` if the memo table detects that the expression already exists, + /// and `None` otherwise. + /// + /// This function assumes that the child groups of the expression are currently roots of their + /// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input + /// expression should _not_ have G2 as a child, and should be replaced with G1. + /// + /// TODO Check that all of the children are root groups? How to do this? + pub async fn is_duplicate_logical_expression( + &self, + logical_expression: &LogicalExpression, + ) -> OptimizerResult> { + let model: logical_expression::Model = logical_expression.clone().into(); + + // Lookup all expressions that have the same fingerprint and kind. There may be false + // positives, but we will check for those next. + let kind = model.kind; + let fingerprint = logical_expression.fingerprint(); + + // Filter first by the fingerprint, and then the kind. + // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the + // second filter? + let potential_matches = fingerprint::Entity::find() + .filter(fingerprint::Column::Hash.eq(fingerprint)) + .filter(fingerprint::Column::Kind.eq(kind)) + .all(&self.db) .await?; + + if potential_matches.is_empty() { + return Ok(None); } - // Insert the expression. - let res = logical_expression - .into_active_model() - .insert(&self.db) - .await?; + // Now that we have all of the expressions that match the given fingerprint, we need to + // filter out all of the expressions that might have had the same fingerprint but are not + // actually equivalent (hash collisions). + let mut match_id = None; + for potential_match in potential_matches { + let expr_id = LogicalExpressionId(potential_match.logical_expression_id); + let (_, expr) = self.get_logical_expression(expr_id).await?; - Ok(Ok(res.id)) + // Check for an exact match. + if &expr == logical_expression { + match_id = Some(expr_id); + + // There should be at most one duplicate expression, so we can break here. + break; + } + } + + Ok(match_id) } + /// Adds a new group into the memo table via a logical expression, creating a new group if the + /// logical expression does not already exist. + /// + /// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of + /// the input expression. + /// + /// If the expression already exists, then this function will return the [`GroupId`] of the + /// parent group and the corresponding (already existing) [`LogicalExpressionId`]. It will also + /// completely ignore the group ID field of the input expression as well as ignore the input + /// slice of child groups. + /// + /// If the expression does not exist, this function will create a new group and a new + /// expression, returning brand new IDs for both. + /// /// FIXME Check that all of the children are reduced groups? - async fn add_logical_expression( + pub async fn add_group( &self, - logical_expression: Self::LogicalExpression, - children: &[Self::GroupId], - ) -> OptimizerResult< - Result< - (Self::GroupId, Self::LogicalExpressionId), - (Self::GroupId, Self::LogicalExpressionId), - >, - > { + logical_expression: LogicalExpression, + children: &[GroupId], + ) -> OptimizerResult> + { // Check if the expression already exists in the memo table. if let Some(existing_id) = self .is_duplicate_logical_expression(&logical_expression) .await? { - let expr = self.get_logical_expression(existing_id).await?; - return Ok(Err((expr.group_id, expr.id))); + let (group_id, _expr) = self.get_logical_expression(existing_id).await?; + return Ok(Err((group_id, existing_id))); } // The expression does not exist yet, so we need to create a new group and new expression. @@ -232,39 +374,45 @@ impl Memo for PersistentMemo { ..Default::default() }; - // Create a new group. + // Create the new group. let res = cascades_group::Entity::insert(group).exec(&self.db).await?; - // Insert the input expression with the correct `group_id`. - let mut new_expr = logical_expression.into_active_model(); - new_expr.group_id = Set(res.last_insert_id); - new_expr.id = NotSet; - let new_expr = new_expr.insert(&self.db).await?; + // Insert the input expression into the newly created group. + let model: logical_expression::Model = logical_expression.clone().into(); + let mut active_model = model.into_active_model(); + active_model.group_id = Set(res.last_insert_id); + active_model.id = NotSet; + let new_model = active_model.insert(&self.db).await?; + + let group_id = new_model.group_id; + let expr_id = new_model.id; // Insert the child groups of the expression into the junction / children table. - if !children.is_empty() { - LogicalChildren::insert_many(children.iter().copied().map(|group_id| { - logical_children::ActiveModel { - logical_expression_id: Set(new_expr.id), - group_id: Set(group_id), - } - })) - .exec(&self.db) - .await?; - } + logical_children::Entity::insert_many(children.iter().copied().map(|child_id| { + logical_children::ActiveModel { + logical_expression_id: Set(new_model.id), + group_id: Set(child_id.0), + } + })) + .on_empty_do_nothing() + .exec(&self.db) + .await?; + + // Finally, insert the fingerprint of the logical expression as well. + let new_expr: LogicalExpression = new_model.into(); + let kind = new_expr.kind(); + let hash = new_expr.fingerprint(); - // Insert the fingerprint of the logical expression. - let hash = hash_expression(new_expr.kind, &new_expr.data); let fingerprint = fingerprint::ActiveModel { id: NotSet, - logical_expression_id: Set(new_expr.id), - kind: Set(new_expr.kind), + logical_expression_id: Set(expr_id), + kind: Set(kind), hash: Set(hash), }; let _ = fingerprint::Entity::insert(fingerprint) .exec(&self.db) .await?; - Ok(Ok((new_expr.group_id, new_expr.id))) + Ok(Ok((GroupId(group_id), LogicalExpressionId(expr_id)))) } } diff --git a/optd-mvp/src/memo/persistent/mod.rs b/optd-mvp/src/memo/persistent/mod.rs index ae2577a..ed64fc5 100644 --- a/optd-mvp/src/memo/persistent/mod.rs +++ b/optd-mvp/src/memo/persistent/mod.rs @@ -1,11 +1,7 @@ //! This module contains the definition and implementation of the [`PersistentMemo`] type, which //! implements the `Memo` trait and supports memo table operations necessary for query optimization. -use crate::{ - entities::{prelude::*, *}, - DATABASE_URL, -}; -use sea_orm::*; +use sea_orm::DatabaseConnection; #[cfg(test)] mod tests; @@ -19,48 +15,4 @@ pub struct PersistentMemo { db: DatabaseConnection, } -impl PersistentMemo { - /// Creates a new `PersistentMemo` struct by connecting to a database defined at - /// [`DATABASE_URL`]. - /// - /// TODO remove dead code and write docs. - #[allow(dead_code)] - pub async fn new() -> Self { - Self { - db: Database::connect(DATABASE_URL).await.unwrap(), - } - } - - /// Since there is no asynchronous drop yet in Rust, we must do this manually. - /// - /// TODO remove dead code and write docs. - #[allow(dead_code)] - pub async fn cleanup(&self) { - cascades_group::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - fingerprint::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - logical_expression::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - logical_children::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - physical_expression::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - physical_children::Entity::delete_many() - .exec(&self.db) - .await - .unwrap(); - } -} - mod implementation; diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs index 7158b30..f3afea6 100644 --- a/optd-mvp/src/memo/persistent/tests.rs +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -1,36 +1,142 @@ -use super::*; -use crate::{expression::*, memo::Memo}; +use crate::{expression::*, memo::persistent::PersistentMemo}; -/// Tests is exact expression matches are detected and handled by the memo table. +/// Tests that exact expression matches are detected and handled by the memo table. #[ignore] #[tokio::test] -async fn test_simple_duplicates() { +async fn test_simple_logical_duplicates() { let memo = PersistentMemo::new().await; memo.cleanup().await; - let scan = scan("(a int, b int)".to_string()); - let scan1 = scan.clone(); - let scan2 = scan.clone(); + let scan = scan("t1".to_string()); + let scan1a = scan.clone(); + let scan1b = scan.clone(); + let scan2a = scan.clone(); + let scan2b = scan.clone(); - let res0 = memo - .add_logical_expression(scan.into(), &[]) + // Insert a new group and its corresponding expression. + let (group_id, logical_expression_id) = memo.add_group(scan, &[]).await.unwrap().ok().unwrap(); + + // Test `add_logical_expression`. + { + // Attempting to create a new group with a duplicate expression should fail every time. + let (group_id_1a, logical_expression_id_1a) = + memo.add_group(scan1a, &[]).await.unwrap().err().unwrap(); + assert_eq!(group_id, group_id_1a); + assert_eq!(logical_expression_id, logical_expression_id_1a); + + // Try again just in case... + let (group_id_1b, logical_expression_id_1b) = + memo.add_group(scan1b, &[]).await.unwrap().err().unwrap(); + assert_eq!(group_id, group_id_1b); + assert_eq!(logical_expression_id, logical_expression_id_1b); + } + + // Test `add_logical_expression_to_group`. + { + // Attempting to add a duplicate expression into the same group should also fail every time. + let logical_expression_id_2a = memo + .add_logical_expression_to_group(group_id, scan2a, &[]) + .await + .unwrap() + .err() + .unwrap(); + assert_eq!(logical_expression_id, logical_expression_id_2a); + + let logical_expression_id_2b = memo + .add_logical_expression_to_group(group_id, scan2b, &[]) + .await + .unwrap() + .err() + .unwrap(); + assert_eq!(logical_expression_id, logical_expression_id_2b); + } + + memo.cleanup().await; +} + +/// Tests that physical expression are _not_ subject to duplicate detection and elimination. +/// +/// !!! Important !!! Note that this behavior should not actually be seen during query optimization, +/// since if logical expression have been deduplicated, there should not be any duplicate physical +/// expressions as they are derivative of the deduplicated logical expressions. +#[ignore] +#[tokio::test] +async fn test_simple_add_physical_expression() { + let memo = PersistentMemo::new().await; + memo.cleanup().await; + + // Insert a new group and its corresponding expression. + let scan = scan("t1".to_string()); + let (group_id, _) = memo.add_group(scan, &[]).await.unwrap().ok().unwrap(); + + // Insert two identical physical expressions into the _same_ group. + let table_scan_1 = table_scan("t1".to_string()); + let table_scan_2 = table_scan_1.clone(); + + let physical_expression_id_1 = memo + .add_physical_expression_to_group(group_id, table_scan_1, &[]) .await - .unwrap() - .ok(); - let res1 = memo - .add_logical_expression(scan1.into(), &[]) + .unwrap(); + + let physical_expression_id_2 = memo + .add_physical_expression_to_group(group_id, table_scan_2, &[]) + .await + .unwrap(); + + // Since physical expressions do not need duplicate detection, + assert_ne!(physical_expression_id_1, physical_expression_id_2); + + memo.cleanup().await; +} + +/// Tests if the memo tables able to correctly retrieve a group's expressions. +#[ignore] +#[tokio::test] +async fn test_simple_tree() { + let memo = PersistentMemo::new().await; + memo.cleanup().await; + + // Create two scan groups. + let scan1: LogicalExpression = scan("t1".to_string()); + let scan2 = scan("t2".to_string()); + let (scan_id_1, scan_expr_id_1) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap(); + let (scan_id_2, scan_expr_id_2) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap(); + + assert_eq!( + memo.get_logical_children(scan_id_1).await.unwrap(), + &[scan_expr_id_1] + ); + assert_eq!( + memo.get_logical_children(scan_id_2).await.unwrap(), + &[scan_expr_id_2] + ); + + // Create two join expression that should be in the same group. + // TODO: Eventually, the predicates will be in their own table, and the predicate representation + // will be a foreign key. For now, we represent them as strings. + let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string()); + let join2 = join(scan_id_2, scan_id_1, "t1.a = t2.b".to_string()); + + // Create the group, adding the first expression. + let (join_id, join_expr_id_1) = memo + .add_group(join1, &[scan_id_1, scan_id_2]) .await .unwrap() - .err(); - let res2 = memo - .add_logical_expression(scan2.into(), &[]) + .ok() + .unwrap(); + // Add the second expression. + let join_expr_id_2 = memo + .add_logical_expression_to_group(join_id, join2, &[scan_id_2, scan_id_1]) .await .unwrap() - .err(); + .ok() + .unwrap(); - assert_eq!(res0, res1); - assert_eq!(res0, res2); - assert_eq!(res1, res2); + assert_ne!(join_expr_id_1, join_expr_id_2); + assert_eq!( + memo.get_logical_children(join_id).await.unwrap(), + &[join_expr_id_1, join_expr_id_2] + ); memo.cleanup().await; } diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_logical_children.rs b/optd-mvp/src/migrator/memo/m20241127_000001_logical_children.rs index d0835f4..037a637 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_logical_children.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_logical_children.rs @@ -40,7 +40,7 @@ impl MigrationTrait for Migration { ) .foreign_key( ForeignKey::create() - .from(LogicalChildren::Table, LogicalChildren::GroupId) + .from(LogicalChildren::Table, LogicalChildren::LogicalExpressionId) .to(LogicalExpression::Table, LogicalExpression::Id) .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade),