Skip to content

Commit

Permalink
huge refactor of persistent memo implementation
Browse files Browse the repository at this point in the history
This commit completely refactors the memo table, removing the `Memo`
trait and instead placing all methods directly on the `PersistentMemo`
structure itself.

This also cleans up some code in other places.
  • Loading branch information
connortsui20 committed Nov 30, 2024
1 parent 2856496 commit aaba197
Show file tree
Hide file tree
Showing 13 changed files with 601 additions and 479 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions optd-mvp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ serde_json = "1.0.118" # Support `Hash` on `serde_json::Value` in "1.0.118".
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
trait-variant = "0.1.2" # Support `make(Send)` syntax in "0.1.2".
thiserror = "2.0"
fxhash = "0.2"

# Pin more recent versions for `-Zminimal-versions`.
async-trait = "0.1.43" # Remove lifetime parameter from "0.1.42".
async-stream = "0.3.1" # Fix unsatisfied trait bound from "0.3.0".
strum = "0.26.0" # Fix `std::marker::Sized` from "0.25.0".

2 changes: 1 addition & 1 deletion optd-mvp/src/entities/logical_children.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub enum Relation {
CascadesGroup,
#[sea_orm(
belongs_to = "super::logical_expression::Entity",
from = "Column::GroupId",
from = "Column::LogicalExpressionId",
to = "super::logical_expression::Column::Id",
on_update = "Cascade",
on_delete = "Cascade"
Expand Down
2 changes: 2 additions & 0 deletions optd-mvp/src/entities/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
#![allow(unused_imports)]

pub use super::cascades_group::Entity as CascadesGroup;
pub use super::fingerprint::Entity as Fingerprint;
pub use super::logical_children::Entity as LogicalChildren;
Expand Down
107 changes: 83 additions & 24 deletions optd-mvp/src/expression/logical_expression.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,91 @@
//! Definition of logical expressions / relations in the Cascades query optimization framework.
//!
//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now.
//! FIXME: Representation needs to know how to "rewrite" child group IDs to whatever a fingerprint
//! will need.
//! FIXME: All fields are placeholders.
//!
//! TODO figure out if each relation should be in a different submodule.
//! TODO Remove dead code.
//! TODO Figure out if each relation should be in a different submodule.
//! TODO This entire file is a WIP.
use crate::entities::*;
#![allow(dead_code)]

use crate::{entities::*, memo::GroupId};
use fxhash::hash;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LogicalExpression {
Scan(Scan),
Filter(Filter),
Join(Join),
}

#[derive(Serialize, Deserialize, Clone, Debug)]
/// FIXME: Figure out how to make everything unsigned instead of signed.
impl LogicalExpression {
pub fn kind(&self) -> i16 {
match self {
LogicalExpression::Scan(_) => 0,
LogicalExpression::Filter(_) => 1,
LogicalExpression::Join(_) => 2,
}
}

/// Definitions of custom fingerprinting strategies for each kind of logical expression.
pub fn fingerprint(&self) -> i64 {
self.fingerprint_with_rewrite(&[])
}

/// Calculates the fingerprint of a given expression, but replaces all of the children group IDs
/// with a new group ID if it is listed in the input `rewrites` list.
///
/// TODO Allow each expression to implement a trait that does this.
pub fn fingerprint_with_rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> i64 {
// Closure that rewrites a group ID if needed.
let rewrite = |x: GroupId| {
if rewrites.is_empty() {
return x;
}

if let Some(i) = rewrites.iter().position(|(curr, _new)| &x == curr) {
assert_eq!(rewrites[i].0, x);
rewrites[i].1
} else {
x
}
};

let kind = self.kind() as u16 as usize;
let hash = match self {
LogicalExpression::Scan(scan) => hash(scan.table_schema.as_str()),
LogicalExpression::Filter(filter) => {
hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str())
}
LogicalExpression::Join(join) => {
hash(&rewrite(join.left).0)
^ hash(&rewrite(join.right).0)
^ hash(join.expression.as_str())
}
};

// Mask out the bottom 16 bits of `hash` and replace them with `kind`.
((hash & !0xFFFF) | kind) as i64
}
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct Scan {
table_schema: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct Filter {
child: i32,
child: GroupId,
expression: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct Join {
left: i32,
right: i32,
left: GroupId,
right: GroupId,
expression: String,
}

Expand Down Expand Up @@ -71,17 +125,18 @@ impl From<LogicalExpression> for logical_expression::Model {
}
}

let kind = value.kind();
match value {
LogicalExpression::Scan(scan) => create_logical_expression(
0,
kind,
serde_json::to_value(scan).expect("unable to serialize logical `Scan`"),
),
LogicalExpression::Filter(filter) => create_logical_expression(
1,
kind,
serde_json::to_value(filter).expect("unable to serialize logical `Filter`"),
),
LogicalExpression::Join(join) => create_logical_expression(
2,
kind,
serde_json::to_value(join).expect("unable to serialize logical `Join`"),
),
}
Expand All @@ -94,24 +149,28 @@ pub use build::*;
#[cfg(test)]
mod build {
use super::*;
use crate::expression::Expression;
use crate::expression::LogicalExpression;

pub fn scan(table_schema: String) -> Expression {
Expression::Logical(LogicalExpression::Scan(Scan { table_schema }))
pub fn scan(table_schema: String) -> LogicalExpression {
LogicalExpression::Scan(Scan { table_schema })
}

pub fn filter(child_group: i32, expression: String) -> Expression {
Expression::Logical(LogicalExpression::Filter(Filter {
pub fn filter(child_group: GroupId, expression: String) -> LogicalExpression {
LogicalExpression::Filter(Filter {
child: child_group,
expression,
}))
})
}

pub fn join(left_group: i32, right_group: i32, expression: String) -> Expression {
Expression::Logical(LogicalExpression::Join(Join {
pub fn join(
left_group: GroupId,
right_group: GroupId,
expression: String,
) -> LogicalExpression {
LogicalExpression::Join(Join {
left: left_group,
right: right_group,
expression,
}))
})
}
}
45 changes: 26 additions & 19 deletions optd-mvp/src/expression/physical_expression.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
//! Definition of physical expressions / operators in the Cascades query optimization framework.
//!
//! FIXME: All fields are placeholders, and group IDs are just represented as i32 for now.
//! FIXME: All fields are placeholders.
//!
//! TODO figure out if each operator should be in a different submodule.
//! TODO Remove dead code.
//! TODO Figure out if each operator should be in a different submodule.
//! TODO This entire file is a WIP.
use crate::entities::*;
#![allow(dead_code)]

use crate::{entities::*, memo::GroupId};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PhysicalExpression {
TableScan(TableScan),
Filter(PhysicalFilter),
HashJoin(HashJoin),
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct TableScan {
table_schema: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct PhysicalFilter {
child: i32,
child: GroupId,
expression: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct HashJoin {
left: i32,
right: i32,
left: GroupId,
right: GroupId,
expression: String,
}

Expand Down Expand Up @@ -92,24 +95,28 @@ pub use build::*;
#[cfg(test)]
mod build {
use super::*;
use crate::expression::Expression;
use crate::expression::PhysicalExpression;

pub fn table_scan(table_schema: String) -> Expression {
Expression::Physical(PhysicalExpression::TableScan(TableScan { table_schema }))
pub fn table_scan(table_schema: String) -> PhysicalExpression {
PhysicalExpression::TableScan(TableScan { table_schema })
}

pub fn filter(child_group: i32, expression: String) -> Expression {
Expression::Physical(PhysicalExpression::Filter(PhysicalFilter {
pub fn filter(child_group: GroupId, expression: String) -> PhysicalExpression {
PhysicalExpression::Filter(PhysicalFilter {
child: child_group,
expression,
}))
})
}

pub fn hash_join(left_group: i32, right_group: i32, expression: String) -> Expression {
Expression::Physical(PhysicalExpression::HashJoin(HashJoin {
pub fn hash_join(
left_group: GroupId,
right_group: GroupId,
expression: String,
) -> PhysicalExpression {
PhysicalExpression::HashJoin(HashJoin {
left: left_group,
right: right_group,
expression,
}))
})
}
}
12 changes: 0 additions & 12 deletions optd-mvp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,3 @@ pub type OptimizerResult<T> = Result<T, OptimizerError>;
pub async fn migrate(db: &DatabaseConnection) -> Result<(), DbErr> {
Migrator::refresh(db).await
}

/// Helper function for hashing expression data.
///
/// TODO remove this.
fn hash_expression(kind: i16, data: &serde_json::Value) -> i64 {
use std::hash::{DefaultHasher, Hash, Hasher};

let mut hasher = DefaultHasher::new();
kind.hash(&mut hasher);
data.hash(&mut hasher);
hasher.finish() as i64
}
Loading

0 comments on commit aaba197

Please sign in to comment.