From f0e4b45bfff7e75fa6ca46d18bb2f2bdce9354e4 Mon Sep 17 00:00:00 2001
From: Connor Tsui
Date: Wed, 27 Nov 2024 19:21:48 -0500
Subject: [PATCH] add memo trait interface and persistent memo implementation

This commit adds a first draft of a memo table trait and a persistent memo
table implementation backed by SeaORM entities.
---
 optd-mvp/src/lib.rs             |  21 +++
 optd-mvp/src/memo/interface.rs  | 139 +++++++++++++++++++
 optd-mvp/src/memo/mod.rs        |   9 ++
 optd-mvp/src/memo/persistent.rs | 234 ++++++++++++++++++++++++++++++++
 4 files changed, 403 insertions(+)
 create mode 100644 optd-mvp/src/memo/interface.rs
 create mode 100644 optd-mvp/src/memo/mod.rs
 create mode 100644 optd-mvp/src/memo/persistent.rs

diff --git a/optd-mvp/src/lib.rs b/optd-mvp/src/lib.rs
index 5abd59f..c5185cd 100644
--- a/optd-mvp/src/lib.rs
+++ b/optd-mvp/src/lib.rs
@@ -1,16 +1,37 @@
 use sea_orm::*;
 use sea_orm_migration::prelude::*;
+use thiserror::Error;
 
 mod migrator;
 use migrator::Migrator;
 
 mod entities;
 
+mod memo;
+use memo::MemoError;
+
 /// The filename of the SQLite database for migration.
 pub const DATABASE_FILENAME: &str = "sqlite.db";
 /// The URL of the SQLite database for migration.
 pub const DATABASE_URL: &str = "sqlite:./sqlite.db?mode=rwc";
 
+/// An error type wrapping all the different kinds of error the optimizer might raise.
+///
+/// TODO more docs.
+#[derive(Error, Debug)]
+pub enum OptimizerError {
+    #[error("SeaORM error")]
+    Database(#[from] sea_orm::error::DbErr),
+    #[error("Memo table logical error")]
+    Memo(#[from] MemoError),
+    #[error("unknown error")]
+    Unknown,
+}
+
+/// Shorthand for a [`Result`] with an error type [`OptimizerError`].
+pub type OptimizerResult<T> = Result<T, OptimizerError>;
+
+/// Applies all migrations.
 pub async fn migrate(db: &DatabaseConnection) -> Result<(), DbErr> {
     Migrator::refresh(db).await
 }
diff --git a/optd-mvp/src/memo/interface.rs b/optd-mvp/src/memo/interface.rs
new file mode 100644
index 0000000..5fc0e65
--- /dev/null
+++ b/optd-mvp/src/memo/interface.rs
@@ -0,0 +1,139 @@
+use crate::OptimizerResult;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+/// The different kinds of errors that might occur while running operations on a memo table.
+pub enum MemoError {
+    #[error("unknown group ID {0}")]
+    UnknownGroup(i32),
+    #[error("unknown logical expression ID {0}")]
+    UnknownLogicalExpression(i32),
+    #[error("unknown physical expression ID {0}")]
+    UnknownPhysicalExpression(i32),
+    #[error("invalid expression encountered")]
+    InvalidExpression,
+}
+
+/// A trait representing an implementation of a memoization table.
+///
+/// Note that we use [`trait_variant`] here in order to add bounds on every method.
+/// See this [blog post](
+/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits)
+/// for more information.
+#[allow(dead_code)]
+#[trait_variant::make(Send)]
+pub trait Memo {
+    /// A type representing a group in the Cascades framework.
+    type Group;
+    /// A type representing a unique identifier for a group.
+    type GroupId;
+    /// A type representing a logical expression.
+    type LogicalExpression;
+    /// A type representing a unique identifier for a logical expression.
+    type LogicalExpressionId;
+    /// A type representing a physical expression.
+    type PhysicalExpression;
+    /// A type representing a unique identifier for a physical expression.
+    type PhysicalExpressionId;
+
+    /// Retrieves a [`Self::Group`] given a [`Self::GroupId`].
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn get_group(&self, group_id: Self::GroupId) -> OptimizerResult<Self::Group>;
+
+    /// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`].
+    ///
+    /// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`]
+    /// error.
+    async fn get_logical_expression(
+        &self,
+        logical_expression_id: Self::LogicalExpressionId,
+    ) -> OptimizerResult<Self::LogicalExpression>;
+
+    /// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`].
+    ///
+    /// If the physical expression does not exist, returns a
+    /// [`MemoError::UnknownPhysicalExpression`] error.
+    async fn get_physical_expression(
+        &self,
+        physical_expression_id: Self::PhysicalExpressionId,
+    ) -> OptimizerResult<Self::PhysicalExpression>;
+
+    /// Retrieves all of the logical expression "children" IDs of a group.
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn get_logical_children(
+        &self,
+        group_id: Self::GroupId,
+    ) -> OptimizerResult<Vec<Self::LogicalExpressionId>>;
+
+    /// Retrieves all of the physical expression "children" IDs of a group.
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn get_physical_children(
+        &self,
+        group_id: Self::GroupId,
+    ) -> OptimizerResult<Vec<Self::PhysicalExpressionId>>;
+
+    /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
+    /// winner's physical expression ID.
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn update_group_winner(
+        &self,
+        group_id: Self::GroupId,
+        physical_expression_id: Self::PhysicalExpressionId,
+    ) -> OptimizerResult<Option<Self::PhysicalExpressionId>>;
+
+    /// Adds a logical expression to an existing group via its [`Self::GroupId`].
+    ///
+    /// The caller is required to pass in a vector of `GroupId` that represent the child groups of
+    /// the input expression.
+    ///
+    /// The caller is also required to set the `group_id` field of the input `logical_expression`
+    /// to be equal to `group_id`, otherwise this function will return a
+    /// [`MemoError::InvalidExpression`] error.
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn add_logical_expression_to_group(
+        &self,
+        group_id: Self::GroupId,
+        logical_expression: Self::LogicalExpression,
+        children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<()>;
+
+    /// Adds a physical expression to an existing group via its [`Self::GroupId`].
+    ///
+    /// The caller is required to pass in a vector of `GroupId` that represent the child groups of
+    /// the input expression.
+    ///
+    /// The caller is also required to set the `group_id` field of the input `physical_expression`
+    /// to be equal to `group_id`, otherwise this function will return a
+    /// [`MemoError::InvalidExpression`] error.
+    ///
+    /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
+    async fn add_physical_expression_to_group(
+        &self,
+        group_id: Self::GroupId,
+        physical_expression: Self::PhysicalExpression,
+        children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<()>;
+
+    /// Adds a new logical expression into the memo table, creating a new group if the expression
+    /// does not already exist.
+    ///
+    /// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if
+    /// the expression has been seen before, and if it has already been created, then the parent
+    /// group ID should also be retrievable.
+    ///
+    /// If the expression already exists, then this function will return the [`Self::GroupId`] of
+    /// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`].
+    ///
+    /// If the expression does not exist, this function will create a new group and a new
+    /// expression, returning brand new IDs for both.
+    async fn add_logical_expression(
+        &self,
+        expression: Self::LogicalExpression,
+        children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<(Self::GroupId, Self::LogicalExpressionId)>;
+}
diff --git a/optd-mvp/src/memo/mod.rs b/optd-mvp/src/memo/mod.rs
new file mode 100644
index 0000000..5253352
--- /dev/null
+++ b/optd-mvp/src/memo/mod.rs
@@ -0,0 +1,9 @@
+//! This module contains items related to the memo table, which is key to the Cascades query
+//! optimization framework.
+//!
+//! TODO more docs.
+
+mod persistent;
+
+mod interface;
+pub use interface::{Memo, MemoError};
diff --git a/optd-mvp/src/memo/persistent.rs b/optd-mvp/src/memo/persistent.rs
new file mode 100644
index 0000000..6f48873
--- /dev/null
+++ b/optd-mvp/src/memo/persistent.rs
@@ -0,0 +1,234 @@
+use crate::{
+    entities::{prelude::*, *},
+    memo::{Memo, MemoError},
+    OptimizerResult,
+};
+use sea_orm::*;
+
+/// A persistent memo table, backed by a database on disk.
+///
+/// TODO more docs.
+pub struct PersistentMemo {
+    /// This `PersistentMemo` is reliant on the SeaORM [`DatabaseConnection`] that stores all of the
+    /// objects needed for query optimization.
+    db: DatabaseConnection,
+}
+
+impl Memo for PersistentMemo {
+    type Group = cascades_group::Model;
+    type GroupId = i32;
+    type LogicalExpression = logical_expression::Model;
+    type LogicalExpressionId = i32;
+    type PhysicalExpression = physical_expression::Model;
+    type PhysicalExpressionId = i32;
+
+    async fn get_group(&self, group_id: Self::GroupId) -> OptimizerResult<Self::Group> {
+        Ok(CascadesGroup::find_by_id(group_id)
+            .one(&self.db)
+            .await?
+            .ok_or(MemoError::UnknownGroup(group_id))?)
+    }
+
+    async fn get_logical_expression(
+        &self,
+        logical_expression_id: Self::LogicalExpressionId,
+    ) -> OptimizerResult<Self::LogicalExpression> {
+        Ok(LogicalExpression::find_by_id(logical_expression_id)
+            .one(&self.db)
+            .await?
+            .ok_or(MemoError::UnknownLogicalExpression(logical_expression_id))?)
+    }
+
+    async fn get_physical_expression(
+        &self,
+        physical_expression_id: Self::PhysicalExpressionId,
+    ) -> OptimizerResult<Self::PhysicalExpression> {
+        Ok(PhysicalExpression::find_by_id(physical_expression_id)
+            .one(&self.db)
+            .await?
+            .ok_or(MemoError::UnknownPhysicalExpression(physical_expression_id))?)
+    }
+
+    async fn get_logical_children(
+        &self,
+        group_id: Self::GroupId,
+    ) -> OptimizerResult<Vec<Self::LogicalExpressionId>> {
+        // First retrieve the group record, and then find all related logical expressions.
+        Ok(self
+            .get_group(group_id)
+            .await?
+            .find_related(LogicalChildren)
+            .all(&self.db)
+            .await?
+            .into_iter()
+            .map(|m| m.logical_expression_id)
+            .collect())
+    }
+
+    async fn get_physical_children(
+        &self,
+        group_id: Self::GroupId,
+    ) -> OptimizerResult<Vec<Self::PhysicalExpressionId>> {
+        // First retrieve the group record, and then find all related physical expressions.
+        Ok(self
+            .get_group(group_id)
+            .await?
+            .find_related(PhysicalChildren)
+            .all(&self.db)
+            .await?
+            .into_iter()
+            .map(|m| m.physical_expression_id)
+            .collect())
+    }
+
+    /// FIXME: In the future, this should first check that we aren't overwriting a winner that was
+    /// updated from another thread.
+    async fn update_group_winner(
+        &self,
+        group_id: Self::GroupId,
+        physical_expression_id: Self::PhysicalExpressionId,
+    ) -> OptimizerResult<Option<Self::PhysicalExpressionId>> {
+        // First retrieve the group record, and then use an `ActiveModel` to update it.
+        let mut group = self.get_group(group_id).await?.into_active_model();
+        let old_id = group.winner;
+
+        group.winner = Set(Some(physical_expression_id));
+        group.update(&self.db).await?;
+
+        // The old value must be set (`None` still means it has been set).
+        let old = old_id.unwrap();
+        Ok(old)
+    }
+
+    async fn add_logical_expression_to_group(
+        &self,
+        group_id: Self::GroupId,
+        logical_expression: Self::LogicalExpression,
+        children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<()> {
+        if logical_expression.group_id != group_id {
+            Err(MemoError::InvalidExpression)?
+        }
+
+        // Check if the group actually exists.
+        let _ = self.get_group(group_id).await?;
+
+        // Insert the expression first so that the junction rows written below always refer to an
+        // expression that exists, even if a later step fails.
+        let expression_id = logical_expression
+            .into_active_model()
+            .insert(&self.db)
+            .await?
+            .id;
+
+        // An expression without child groups has nothing to link, and `insert_many` with zero rows
+        // does not yield a valid `INSERT` statement, so stop here.
+        if children.is_empty() {
+            return Ok(());
+        }
+
+        // Insert the child groups of the expression into the junction / children table.
+        LogicalChildren::insert_many(children.into_iter().map(|group_id| {
+            logical_children::ActiveModel {
+                logical_expression_id: Set(expression_id),
+                group_id: Set(group_id),
+            }
+        }))
+        .exec(&self.db)
+        .await?;
+
+        Ok(())
+    }
+
+    async fn add_physical_expression_to_group(
+        &self,
+        group_id: Self::GroupId,
+        physical_expression: Self::PhysicalExpression,
+        children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<()> {
+        if physical_expression.group_id != group_id {
+            Err(MemoError::InvalidExpression)?
+        }
+
+        // Check if the group actually exists.
+        let _ = self.get_group(group_id).await?;
+
+        // Insert the expression first so that the junction rows written below always refer to an
+        // expression that exists, even if a later step fails.
+        let expression_id = physical_expression
+            .into_active_model()
+            .insert(&self.db)
+            .await?
+            .id;
+
+        // An expression without child groups has nothing to link, and `insert_many` with zero rows
+        // does not yield a valid `INSERT` statement, so stop here.
+        if children.is_empty() {
+            return Ok(());
+        }
+
+        // Insert the child groups of the expression into the junction / children table.
+        PhysicalChildren::insert_many(children.into_iter().map(|group_id| {
+            physical_children::ActiveModel {
+                physical_expression_id: Set(expression_id),
+                group_id: Set(group_id),
+            }
+        }))
+        .exec(&self.db)
+        .await?;
+
+        Ok(())
+    }
+
+    /// Note that in this function, we ignore the group ID that the logical expression contains.
+    async fn add_logical_expression(
+        &self,
+        expression: Self::LogicalExpression,
+        _children: Vec<Self::GroupId>,
+    ) -> OptimizerResult<(Self::GroupId, Self::LogicalExpressionId)> {
+        // Lookup all expressions that have the same fingerprint. There may be false positives, but
+        // we will check for those later.
+        let fingerprint = expression.fingerprint;
+        let potential_matches = LogicalExpression::find()
+            .filter(logical_expression::Column::Fingerprint.eq(fingerprint))
+            .all(&self.db)
+            .await?;
+
+        // Of the expressions that have the same fingerprint, check if there already exists an
+        // expression that is exactly identical to the input expression.
+        let mut matches: Vec<_> = potential_matches
+            .into_iter()
+            .filter(|expr| expr == &expression)
+            .collect();
+        assert!(
+            matches.len() <= 1,
+            "there cannot be more than 1 exact logical expression match"
+        );
+
+        // The expression already exists, so return its data.
+        if !matches.is_empty() {
+            let existing_expression = matches
+                .pop()
+                .expect("we just checked that an element exists");
+
+            return Ok((existing_expression.group_id, existing_expression.id));
+        }
+
+        // The expression does not exist yet, so we need to create a new group and new expression.
+        let group = cascades_group::ActiveModel {
+            winner: Set(None),
+            is_optimized: Set(false),
+            ..Default::default()
+        };
+
+        // Insert a new group.
+        let res = cascades_group::Entity::insert(group).exec(&self.db).await?;
+
+        // Insert the input expression with the correct `group_id`.
+        let mut new_expr = expression.into_active_model();
+        new_expr.group_id = Set(res.last_insert_id);
+        let new_expr = new_expr.insert(&self.db).await?;
+
+        Ok((new_expr.group_id, new_expr.id))
+    }
+}