Skip to content

Commit

Permalink
remove top-level Expression type and rename to Default<>Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
connortsui20 committed Dec 6, 2024
1 parent e075d8f commit ffae66e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 100 deletions.
54 changes: 27 additions & 27 deletions optd-mvp/src/expression/logical_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ use fxhash::hash;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug)]
pub enum LogicalExpression {
pub enum DefaultLogicalExpression {
Scan(Scan),
Filter(Filter),
Join(Join),
}

impl LogicalExpression {
impl DefaultLogicalExpression {
pub fn kind(&self) -> i16 {
match self {
LogicalExpression::Scan(_) => 0,
LogicalExpression::Filter(_) => 1,
LogicalExpression::Join(_) => 2,
DefaultLogicalExpression::Scan(_) => 0,
DefaultLogicalExpression::Filter(_) => 1,
DefaultLogicalExpression::Join(_) => 2,
}
}

Expand All @@ -46,11 +46,11 @@ impl LogicalExpression {

let kind = self.kind() as u16 as usize;
let hash = match self {
LogicalExpression::Scan(scan) => hash(scan.table.as_str()),
LogicalExpression::Filter(filter) => {
DefaultLogicalExpression::Scan(scan) => hash(scan.table.as_str()),
DefaultLogicalExpression::Filter(filter) => {
hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str())
}
LogicalExpression::Join(join) => {
DefaultLogicalExpression::Join(join) => {
// Make sure that there is a difference between `Join(A, B)` and `Join(B, A)`.
hash(&(rewrite(join.left).0 + 1))
^ hash(&(rewrite(join.right).0 + 2))
Expand Down Expand Up @@ -80,14 +80,14 @@ impl LogicalExpression {
};

match (self, other) {
(LogicalExpression::Scan(scan_left), LogicalExpression::Scan(scan_right)) => {
(DefaultLogicalExpression::Scan(scan_left), DefaultLogicalExpression::Scan(scan_right)) => {
scan_left.table == scan_right.table
}
(LogicalExpression::Filter(filter_left), LogicalExpression::Filter(filter_right)) => {
(DefaultLogicalExpression::Filter(filter_left), DefaultLogicalExpression::Filter(filter_right)) => {
rewrite(filter_left.child) == rewrite(filter_right.child)
&& filter_left.expression == filter_right.expression
}
(LogicalExpression::Join(join_left), LogicalExpression::Join(join_right)) => {
(DefaultLogicalExpression::Join(join_left), DefaultLogicalExpression::Join(join_right)) => {
rewrite(join_left.left) == rewrite(join_right.left)
&& rewrite(join_left.right) == rewrite(join_right.right)
&& join_left.expression == join_right.expression
Expand All @@ -98,9 +98,9 @@ impl LogicalExpression {

pub fn children(&self) -> Vec<GroupId> {
match self {
LogicalExpression::Scan(_) => vec![],
LogicalExpression::Filter(filter) => vec![filter.child],
LogicalExpression::Join(join) => vec![join.left, join.right],
DefaultLogicalExpression::Scan(_) => vec![],
DefaultLogicalExpression::Filter(filter) => vec![filter.child],
DefaultLogicalExpression::Join(join) => vec![join.left, join.right],
}
}
}
Expand All @@ -124,7 +124,7 @@ pub struct Join {
}

/// TODO Use a macro.
impl From<logical_expression::Model> for LogicalExpression {
impl From<logical_expression::Model> for DefaultLogicalExpression {
fn from(value: logical_expression::Model) -> Self {
match value.kind {
0 => Self::Scan(
Expand All @@ -145,8 +145,8 @@ impl From<logical_expression::Model> for LogicalExpression {
}

/// TODO Use a macro.
impl From<LogicalExpression> for logical_expression::Model {
fn from(value: LogicalExpression) -> logical_expression::Model {
impl From<DefaultLogicalExpression> for logical_expression::Model {
fn from(value: DefaultLogicalExpression) -> logical_expression::Model {
fn create_logical_expression(
kind: i16,
data: serde_json::Value,
Expand All @@ -161,15 +161,15 @@ impl From<LogicalExpression> for logical_expression::Model {

let kind = value.kind();
match value {
LogicalExpression::Scan(scan) => create_logical_expression(
DefaultLogicalExpression::Scan(scan) => create_logical_expression(
kind,
serde_json::to_value(scan).expect("unable to serialize logical `Scan`"),
),
LogicalExpression::Filter(filter) => create_logical_expression(
DefaultLogicalExpression::Filter(filter) => create_logical_expression(
kind,
serde_json::to_value(filter).expect("unable to serialize logical `Filter`"),
),
LogicalExpression::Join(join) => create_logical_expression(
DefaultLogicalExpression::Join(join) => create_logical_expression(
kind,
serde_json::to_value(join).expect("unable to serialize logical `Join`"),
),
Expand All @@ -183,16 +183,16 @@ pub use build::*;
#[cfg(test)]
mod build {
use super::*;
use crate::expression::LogicalExpression;
use crate::expression::DefaultLogicalExpression;

pub fn scan(table_schema: String) -> LogicalExpression {
LogicalExpression::Scan(Scan {
pub fn scan(table_schema: String) -> DefaultLogicalExpression {
DefaultLogicalExpression::Scan(Scan {
table: table_schema,
})
}

pub fn filter(child_group: GroupId, expression: String) -> LogicalExpression {
LogicalExpression::Filter(Filter {
pub fn filter(child_group: GroupId, expression: String) -> DefaultLogicalExpression {
DefaultLogicalExpression::Filter(Filter {
child: child_group,
expression,
})
Expand All @@ -202,8 +202,8 @@ mod build {
left_group: GroupId,
right_group: GroupId,
expression: String,
) -> LogicalExpression {
LogicalExpression::Join(Join {
) -> DefaultLogicalExpression {
DefaultLogicalExpression::Join(Join {
left: left_group,
right: right_group,
expression,
Expand Down
52 changes: 0 additions & 52 deletions optd-mvp/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,3 @@ pub use logical_expression::*;
mod physical_expression;
pub use physical_expression::*;

/// The representation of an expression.
///
/// TODO more docs.
#[derive(Clone, Debug)]
pub enum Expression {
Logical(LogicalExpression),
Physical(PhysicalExpression),
}

/// Converts the database / JSON representation of a logical expression into an in-memory one.
impl From<crate::entities::logical_expression::Model> for Expression {
fn from(value: crate::entities::logical_expression::Model) -> Self {
Self::Logical(value.into())
}
}

/// Converts the in-memory representation of a logical expression into the database / JSON version.
///
/// # Panics
///
/// This will panic if the [`Expression`] is [`Expression::Physical`].
impl From<Expression> for crate::entities::logical_expression::Model {
fn from(value: Expression) -> Self {
let Expression::Logical(expr) = value else {
panic!("Attempted to convert an in-memory physical expression into a logical database / JSON expression");
};

expr.into()
}
}

/// Converts the database / JSON representation of a physical expression into an in-memory one.
impl From<crate::entities::physical_expression::Model> for Expression {
fn from(value: crate::entities::physical_expression::Model) -> Self {
Self::Physical(value.into())
}
}

/// Converts the in-memory representation of a physical expression into the database / JSON version.
///
/// # Panics
///
/// This will panic if the [`Expression`] is [`Expression::Physical`].
impl From<Expression> for crate::entities::physical_expression::Model {
fn from(value: Expression) -> Self {
let Expression::Physical(expr) = value else {
panic!("Attempted to convert an in-memory logical expression into a physical database / JSON expression");
};

expr.into()
}
}
20 changes: 10 additions & 10 deletions optd-mvp/src/expression/physical_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{entities::*, memo::GroupId};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PhysicalExpression {
pub enum DefaultPhysicalExpression {
TableScan(TableScan),
Filter(PhysicalFilter),
HashJoin(HashJoin),
Expand All @@ -34,7 +34,7 @@ pub struct HashJoin {
}

/// TODO Use a macro.
impl From<physical_expression::Model> for PhysicalExpression {
impl From<physical_expression::Model> for DefaultPhysicalExpression {
fn from(value: physical_expression::Model) -> Self {
match value.kind {
0 => Self::TableScan(
Expand All @@ -55,8 +55,8 @@ impl From<physical_expression::Model> for PhysicalExpression {
}

/// TODO Use a macro.
impl From<PhysicalExpression> for physical_expression::Model {
fn from(value: PhysicalExpression) -> physical_expression::Model {
impl From<DefaultPhysicalExpression> for physical_expression::Model {
fn from(value: DefaultPhysicalExpression) -> physical_expression::Model {
fn create_physical_expression(
kind: i16,
data: serde_json::Value,
Expand All @@ -70,15 +70,15 @@ impl From<PhysicalExpression> for physical_expression::Model {
}

match value {
PhysicalExpression::TableScan(scan) => create_physical_expression(
DefaultPhysicalExpression::TableScan(scan) => create_physical_expression(
0,
serde_json::to_value(scan).expect("unable to serialize physical `TableScan`"),
),
PhysicalExpression::Filter(filter) => create_physical_expression(
DefaultPhysicalExpression::Filter(filter) => create_physical_expression(
1,
serde_json::to_value(filter).expect("unable to serialize physical `Filter`"),
),
PhysicalExpression::HashJoin(join) => create_physical_expression(
DefaultPhysicalExpression::HashJoin(join) => create_physical_expression(
2,
serde_json::to_value(join).expect("unable to serialize physical `HashJoin`"),
),
Expand All @@ -92,9 +92,9 @@ pub use build::*;
#[cfg(test)]
mod build {
use super::*;
use crate::expression::PhysicalExpression;
use crate::expression::DefaultPhysicalExpression;

pub fn table_scan(table_schema: String) -> PhysicalExpression {
PhysicalExpression::TableScan(TableScan { table_schema })
pub fn table_scan(table_schema: String) -> DefaultPhysicalExpression {
DefaultPhysicalExpression::TableScan(TableScan { table_schema })
}
}
20 changes: 10 additions & 10 deletions optd-mvp/src/memo/persistent/implementation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use super::PersistentMemo;
use crate::{
entities::*,
expression::{LogicalExpression, PhysicalExpression},
expression::{DefaultLogicalExpression, DefaultPhysicalExpression},
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
OptimizerResult, DATABASE_URL,
};
Expand Down Expand Up @@ -147,7 +147,7 @@ impl PersistentMemo {
pub async fn get_physical_expression(
&self,
physical_expression_id: PhysicalExpressionId,
) -> OptimizerResult<(GroupId, PhysicalExpression)> {
) -> OptimizerResult<(GroupId, DefaultPhysicalExpression)> {
// Lookup the entity in the database via the unique expression ID.
let model = physical_expression::Entity::find_by_id(physical_expression_id.0)
.one(&self.db)
Expand All @@ -167,7 +167,7 @@ impl PersistentMemo {
pub async fn get_logical_expression(
&self,
logical_expression_id: LogicalExpressionId,
) -> OptimizerResult<(GroupId, LogicalExpression)> {
) -> OptimizerResult<(GroupId, DefaultLogicalExpression)> {
// Lookup the entity in the database via the unique expression ID.
let model = logical_expression::Entity::find_by_id(logical_expression_id.0)
.one(&self.db)
Expand Down Expand Up @@ -288,7 +288,7 @@ impl PersistentMemo {
pub async fn add_logical_expression_to_group(
&self,
group_id: GroupId,
logical_expression: LogicalExpression,
logical_expression: DefaultLogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Result<LogicalExpressionId, (GroupId, LogicalExpressionId)>> {
// Check if the expression already exists anywhere in the memo table.
Expand Down Expand Up @@ -323,7 +323,7 @@ impl PersistentMemo {
.await?;

// Finally, insert the fingerprint of the logical expression as well.
let new_expr: LogicalExpression = new_model.into();
let new_expr: DefaultLogicalExpression = new_model.into();
let kind = new_expr.kind();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
Expand Down Expand Up @@ -359,7 +359,7 @@ impl PersistentMemo {
pub async fn add_physical_expression_to_group(
&self,
group_id: GroupId,
physical_expression: PhysicalExpression,
physical_expression: DefaultPhysicalExpression,
children: &[GroupId],
) -> OptimizerResult<PhysicalExpressionId> {
// Check if the group actually exists.
Expand Down Expand Up @@ -399,7 +399,7 @@ impl PersistentMemo {
/// expression should _not_ have G2 as a child, and should be replaced with G1.
pub async fn is_duplicate_logical_expression(
&self,
logical_expression: &LogicalExpression,
logical_expression: &DefaultLogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Option<(GroupId, LogicalExpressionId)>> {
let model: logical_expression::Model = logical_expression.clone().into();
Expand Down Expand Up @@ -473,7 +473,7 @@ impl PersistentMemo {
/// expression, returning brand new IDs for both.
pub async fn add_group(
&self,
logical_expression: LogicalExpression,
logical_expression: DefaultLogicalExpression,
children: &[GroupId],
) -> OptimizerResult<Result<(GroupId, LogicalExpressionId), (GroupId, LogicalExpressionId)>>
{
Expand Down Expand Up @@ -517,7 +517,7 @@ impl PersistentMemo {
.await?;

// Finally, insert the fingerprint of the logical expression as well.
let new_logical_expression: LogicalExpression = new_expression.into();
let new_logical_expression: DefaultLogicalExpression = new_expression.into();
let kind = new_logical_expression.kind();

// In order to calculate a correct fingerprint, we will want to use the IDs of the root
Expand Down Expand Up @@ -606,7 +606,7 @@ impl PersistentMemo {
seen.insert(expr_id);
}

let logical_expression: LogicalExpression = model.into();
let logical_expression: DefaultLogicalExpression = model.into();
let hash = logical_expression.fingerprint_with_rewrite(&rewrites);

let fingerprint = fingerprint::ActiveModel {
Expand Down
2 changes: 1 addition & 1 deletion optd-mvp/src/memo/persistent/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async fn test_simple_tree() {
memo.cleanup().await;

// Create two scan groups.
let scan1: LogicalExpression = scan("t1".to_string());
let scan1: DefaultLogicalExpression = scan("t1".to_string());
let scan2 = scan("t2".to_string());
let (scan_id_1, scan_expr_id_1) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap();
let (scan_id_2, scan_expr_id_2) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap();
Expand Down

0 comments on commit ffae66e

Please sign in to comment.