diff --git a/diesel/src/expression/functions/aggregate_expressions.rs b/diesel/src/expression/functions/aggregate_expressions.rs new file mode 100644 index 000000000000..6bd1406b6bf1 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions.rs @@ -0,0 +1,219 @@ +use crate::backend::Backend; +use crate::expression::{is_aggregate, AsExpression, ValidGrouping}; +use crate::query_builder::{AstPass, QueryFragment, QueryId}; +use crate::sql_types::Bool; +use crate::{AppearsOnTable, Expression, QueryResult, SelectableExpression}; + +macro_rules! empty_clause { + ($name: ident) => { + #[derive(Debug, Clone, Copy, QueryId)] + pub struct $name; + + impl crate::query_builder::QueryFragment for $name { + fn walk_ast<'b>( + &'b self, + _pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + Ok(()) + } + } + }; +} + +mod aggregate_filter; +mod aggregate_order; +pub(crate) mod frame_clause; +mod over_clause; +mod partition_by; +mod prefix; +mod within_group; + +use self::aggregate_filter::{Filter, FilterDsl, NoFilter}; +use self::aggregate_order::{NoOrder, Order, OrderAggregateDsl, OrderWindowDsl}; +use self::frame_clause::{FrameDsl, NoFrame}; +use self::over_clause::{NoWindow, OverClause, OverDsl}; +use self::partition_by::{PartitionBy, PartitionByDsl}; +use self::prefix::{All, AllDsl, Distinct, DistinctDsl, NoPrefix}; +use self::within_group::{NoWithin, WithinGroup, WithinGroupDsl}; + +#[derive(QueryId)] +pub struct AggregateExpression< + Fn, + Prefix = NoPrefix, + Order = NoOrder, + Filter = NoFilter, + Within = NoWithin, + Window = NoWindow, +> { + prefix: Prefix, + function: Fn, + order: Order, + filter: Filter, + within_group: Within, + window: Window, +} + +impl QueryFragment + for AggregateExpression +where + Fn: FunctionFragment, + Prefix: QueryFragment, + Order: QueryFragment, + Filter: QueryFragment, + Within: QueryFragment, + Window: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + self.function.walk_name(pass.reborrow())?; + pass.push_sql("("); + self.prefix.walk_ast(pass.reborrow())?; + self.function.walk_arguments(pass.reborrow())?; + self.order.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + self.within_group.walk_ast(pass.reborrow())?; + self.filter.walk_ast(pass.reborrow())?; + self.window.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl ValidGrouping + for AggregateExpression +where + Fn: ValidGrouping, +{ + type IsAggregate = >::IsAggregate; +} + +impl ValidGrouping + for AggregateExpression< + Fn, + Prefix, + Order, + Filter, + Within, + OverClause, + > +where + Fn: ValidGrouping, +{ + // not sure about that, check this + type IsAggregate = is_aggregate::No; +} + +impl Expression + for AggregateExpression +where + Fn: Expression, +{ + type SqlType = ::SqlType; +} + +impl AppearsOnTable + for AggregateExpression +where + Self: Expression, + Fn: AppearsOnTable, +{ +} + +impl SelectableExpression + for AggregateExpression +where + Self: Expression, + Fn: SelectableExpression, +{ +} + +pub trait WindowFunction {} +pub trait AggregateFunction {} +pub trait FunctionFragment { + fn walk_name<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>; + + fn walk_arguments<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>; +} + +// TODO: write helper types for all functions +pub trait AggregateExpressionMethods: Sized { + fn distinct(self) -> Self::Output + where + Self: DistinctDsl, + { + ::distinct(self) + } + + fn all(self) -> Self::Output + where + Self: AllDsl, + { + ::all(self) + } + + // todo: do we want `or_filter` as well? + fn filter_aggregate

(self, f: P) -> Self::Output + where + P: AsExpression, + Self: FilterDsl, + { + >::filter(self, f.as_expression()) + } + + fn order_aggregate(self, o: O) -> Self::Output + where + Self: OrderAggregateDsl, + { + >::order(self, o) + } + + // todo: restrict this to order set aggregates + // (we don't have any in diesel yet) + fn within_group(self, o: O) -> Self::Output + where + Self: WithinGroupDsl, + { + >::within_group(self, o) + } +} + +impl AggregateExpressionMethods for T {} + +pub trait WindowExpressionMethods: Sized { + fn over(self) -> Self::Output + where + Self: OverDsl, + { + ::over(self) + } + + // todo: do we want `or_filter` as well? + fn filter_window

(self, f: P) -> Self::Output + where + P: AsExpression, + Self: FilterDsl, + { + >::filter(self, f.as_expression()) + } + + fn partition_by(self, expr: E) -> Self::Output + where + Self: PartitionByDsl, + { + >::partition_by(self, expr) + } + + fn window_order(self, expr: E) -> Self::Output + where + Self: OrderWindowDsl, + { + >::order(self, expr) + } + + fn frame_by(self, expr: E) -> Self::Output + where + Self: FrameDsl, + { + >::frame(self, expr) + } +} + +impl WindowExpressionMethods for T {} diff --git a/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs b/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs new file mode 100644 index 000000000000..c6beb213a819 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs @@ -0,0 +1,105 @@ +use super::aggregate_order::NoOrder; +use super::prefix::NoPrefix; +use super::AggregateExpression; +use super::AggregateFunction; +use super::NoWindow; +use super::NoWithin; +use crate::query_builder::where_clause::NoWhereClause; +use crate::query_builder::where_clause::WhereAnd; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::sql_types::Bool; +use crate::Expression; +use crate::QueryResult; + +empty_clause!(NoFilter); + +#[derive(QueryId, Copy, Clone)] +pub struct Filter

(P); + +impl

QueryFragment for Filter

+where + P: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" FILTER ("); + self.0.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait FilterDsl

{ + type Output; + + fn filter(self, f: P) -> Self::Output; +} + +impl FilterDsl

for T +where + T: AggregateFunction, + // todo: allow nullable bools here + P: Expression, +{ + type Output = + AggregateExpression>::Output>>; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: Filter(NoWhereClause.and(f)), + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl FilterDsl

+ for AggregateExpression, Within, Window> +where + // todo: allow nullable bools here + F: WhereAnd

, +{ + type Output = + AggregateExpression>::Output>, Within, Window>; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: Filter(WhereAnd::

::and(self.filter.0, f)), + within_group: self.within_group, + window: self.window, + } + } +} + +impl FilterDsl

+ for AggregateExpression +where + // todo: allow nullable bools here + NoWhereClause: WhereAnd

, +{ + type Output = AggregateExpression< + Fn, + Prefix, + Order, + Filter<>::Output>, + Within, + Window, + >; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: Filter(WhereAnd::

::and(NoWhereClause, f)), + within_group: self.within_group, + window: self.window, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs b/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs new file mode 100644 index 000000000000..8399c6e347ed --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs @@ -0,0 +1,108 @@ +use super::AggregateFunction; +use super::NoFilter; +use super::NoPrefix; +use super::NoWindow; +use super::NoWithin; +use super::{over_clause::OverClause, AggregateExpression}; +use crate::query_builder::order_clause::OrderClause; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::{Expression, QueryResult}; + +empty_clause!(NoOrder); + +#[derive(QueryId, Copy, Clone)] +pub struct Order(OrderClause); + +impl QueryFragment for Order +where + OrderClause: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + self.0.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +pub trait OrderAggregateDsl { + type Output; + + fn order(self, expr: E) -> Self::Output; +} + +impl OrderAggregateDsl for T +where + T: AggregateFunction, + E: Expression, +{ + type Output = AggregateExpression>; + + fn order(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: Order(OrderClause(expr)), + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl OrderAggregateDsl + for AggregateExpression +{ + type Output = AggregateExpression, Filter>; + + fn order(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: Order(OrderClause(expr)), + filter: self.filter, + within_group: self.within_group, + window: NoWindow, + } + } +} + +pub trait OrderWindowDsl { + type Output; + + fn order(self, expr: O) -> Self::Output; +} + +impl OrderWindowDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, Frame>, + >; + + fn order(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: self.window.partition_by, + order: Order(OrderClause(expr)), + frame_clause: self.window.frame_clause, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs b/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs new file mode 100644 index 000000000000..85d4442fa0e5 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs @@ -0,0 +1,282 @@ +use crate::query_builder::{QueryFragment, QueryId}; +use crate::sql_types::BigInt; + +use super::aggregate_order::{NoOrder, OrderWindowDsl}; +use super::over_clause::OverClause; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; + +empty_clause!(NoFrame); + +#[derive(QueryId, Copy, Clone)] +pub struct FrameClause(F); + +impl QueryFragment for FrameClause +where + F: QueryFragment, +{ + fn walk_ast<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + self.0.walk_ast(pass)?; + Ok(()) + } +} + +macro_rules! simple_frame_expr { + ($name: ident, $kind: expr) => { + #[derive(QueryId, Clone, Copy)] + pub struct $name; + + impl QueryFragment for $name { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + pass.push_sql($kind); + Ok(()) + } + } + }; +} + +// kinds +simple_frame_expr!(Range, " RANGE "); +simple_frame_expr!(Rows, " ROWS "); +simple_frame_expr!(Groups, " GROUPS "); + +// start & end +simple_frame_expr!(UnboundedPreceding, "UNBOUNDED PRECEDING "); +simple_frame_expr!(CurrentRow, "CURRENT ROW "); +simple_frame_expr!(UnboundedFollowing, "UNBOUNDED FOLLOWING "); + +// exclusion +simple_frame_expr!(ExcludeCurrentRow, "EXCLUDE CURRENT ROW "); +simple_frame_expr!(ExcludeGroup, "EXCLUDE GROUP "); +simple_frame_expr!(ExcludeTies, "EXCLUDE TIES "); +simple_frame_expr!(ExcludeNoOthers, "EXCLUDE NO OTHERS "); + +#[derive(QueryId, Clone, Copy)] +pub struct OffsetPreceding(i64); + +impl QueryFragment for OffsetPreceding { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + pass.push_bind_param::(&self.0)?; + pass.push_sql(" PRECEDING "); + Ok(()) + } +} + +#[derive(QueryId, Clone, Copy)] +pub struct OffsetFollowing(i64); + +impl QueryFragment for OffsetFollowing { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + pass.push_bind_param::(&self.0)?; + pass.push_sql(" FOLLOWING "); + Ok(()) + } +} + +pub trait FrameDsl { + type Output; + + fn frame(self, expr: F) -> Self::Output; +} + +impl FrameDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +where + E: FrameClauseExpression, +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause>, + >; + + fn frame(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: OverClause { + partition_by: self.window.partition_by, + order: self.window.order, + frame_clause: FrameClause(expr), + }, + } + } +} + +pub trait FrameClauseExpression {} + +pub trait FrameClauseBound {} +impl FrameClauseBound for UnboundedFollowing {} +impl FrameClauseBound for UnboundedPreceding {} +impl FrameClauseBound for CurrentRow {} +impl FrameClauseBound for OffsetFollowing {} +impl FrameClauseBound for OffsetPreceding {} + +pub trait FrameCauseExclusion {} + +impl FrameCauseExclusion for ExcludeGroup {} +impl FrameCauseExclusion for ExcludeNoOthers {} +impl FrameCauseExclusion for ExcludeTies {} +impl FrameCauseExclusion for ExcludeCurrentRow {} + +pub trait FrameBoundDsl { + fn preceding(self) -> OffsetPreceding; + fn following(self) -> OffsetFollowing; +} + +impl FrameBoundDsl for i64 { + fn preceding(self) -> OffsetPreceding { + OffsetPreceding(self) + } + + fn following(self) -> OffsetFollowing { + OffsetFollowing(self) + } +} + +empty_clause!(NoExclusion); + +#[derive(QueryId, Copy, Clone)] +pub struct StartFrame { + kind: Kind, + start: Start, + exclusion: Exclusion, +} + +impl QueryFragment for StartFrame +where + Kind: QueryFragment, + Start: QueryFragment, + Exclusion: QueryFragment, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + self.kind.walk_ast(pass.reborrow())?; + self.start.walk_ast(pass.reborrow())?; + self.exclusion.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl FrameClauseExpression for StartFrame {} + +#[derive(QueryId, Copy, Clone)] +pub struct BetweenFrame { + kind: Kind, + start: Start, + end: End, + exclusion: Exclusion, +} + +impl QueryFragment + for BetweenFrame +where + Kind: QueryFragment, + Start: QueryFragment, + End: QueryFragment, + Exclusion: QueryFragment, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>, + ) -> crate::QueryResult<()> { + self.kind.walk_ast(pass.reborrow())?; + pass.push_sql(" BETWEEN "); + self.start.walk_ast(pass.reborrow())?; + pass.push_sql(" AND "); + self.exclusion.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl FrameClauseExpression + for BetweenFrame +{ +} + +pub trait FrameClauseDslHelper: Sized {} + +pub trait FrameClauseDsl: FrameClauseDslHelper { + fn start_with(self, start: E) -> StartFrame { + StartFrame { + kind: self, + start, + exclusion: NoExclusion, + } + } + fn start_with_exclusion( + self, + start: E1, + exclusion: E2, + ) -> StartFrame { + StartFrame { + kind: self, + start, + exclusion, + } + } + + fn between( + self, + start: E1, + end: E2, + ) -> BetweenFrame { + BetweenFrame { + kind: self, + start, + end, + exclusion: NoExclusion, + } + } + fn between_with_exclusion< + E1: FrameClauseBound, + E2: FrameClauseBound, + E3: FrameCauseExclusion, + >( + self, + start: E1, + end: E2, + exclusion: E3, + ) -> BetweenFrame { + BetweenFrame { + kind: self, + start, + end, + exclusion, + } + } +} + +impl FrameClauseDsl for T where T: FrameClauseDslHelper {} + +impl FrameClauseDslHelper for Range {} +impl FrameClauseDslHelper for Rows {} +impl FrameClauseDslHelper for Groups {} diff --git a/diesel/src/expression/functions/aggregate_expressions/over_clause.rs b/diesel/src/expression/functions/aggregate_expressions/over_clause.rs new file mode 100644 index 000000000000..722038095958 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/over_clause.rs @@ -0,0 +1,85 @@ +use super::aggregate_filter::NoFilter; +use super::aggregate_order::NoOrder; +use super::partition_by::NoPartition; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; +use super::NoFrame; +use super::WindowFunction; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::QueryResult; + +empty_clause!(NoWindow); + +#[derive(Clone, Copy, QueryId)] +pub struct OverClause { + pub(crate) partition_by: Partition, + pub(crate) order: Order, + pub(crate) frame_clause: Frame, +} + +impl QueryFragment for OverClause +where + Partition: QueryFragment, + Order: QueryFragment, + Frame: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" OVER ("); + self.partition_by.walk_ast(pass.reborrow())?; + self.order.walk_ast(pass.reborrow())?; + self.frame_clause.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait OverDsl { + type Output; + + fn over(self) -> Self::Output; +} + +impl OverDsl for F +where + F: WindowFunction, +{ + type Output = AggregateExpression; + + fn over(self) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: OverClause { + partition_by: NoPartition, + order: NoOrder, + frame_clause: NoFrame, + }, + } + } +} + +impl OverDsl + for AggregateExpression +{ + type Output = AggregateExpression; + + fn over(self) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: NoPartition, + order: NoOrder, + frame_clause: NoFrame, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/partition_by.rs b/diesel/src/expression/functions/aggregate_expressions/partition_by.rs new file mode 100644 index 000000000000..daf6f6626bfd --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/partition_by.rs @@ -0,0 +1,65 @@ +use super::aggregate_order::NoOrder; +use super::over_clause::OverClause; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::QueryResult; + +empty_clause!(NoPartition); + +#[derive(QueryId, Clone, Copy)] +pub struct PartitionBy(T); + +impl QueryFragment for PartitionBy +where + T: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" PARTITION BY "); + self.0.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +pub trait PartitionByDsl { + type Output; + + fn partition_by(self, expr: E) -> Self::Output; +} + +impl PartitionByDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, Order, Frame>, + >; + + fn partition_by(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: PartitionBy(expr), + order: self.window.order, + frame_clause: self.window.frame_clause, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/prefix.rs b/diesel/src/expression/functions/aggregate_expressions/prefix.rs new file mode 100644 index 000000000000..0c356f14cc3f --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/prefix.rs @@ -0,0 +1,116 @@ +use super::AggregateExpression; +use super::AggregateFunction; +use super::NoFilter; +use super::NoOrder; +use super::NoWindow; +use super::NoWithin; +use crate::query_builder::{AstPass, QueryFragment, QueryId}; +use crate::QueryResult; + +empty_clause!(NoPrefix); + +#[derive(Debug, Clone, Copy, QueryId)] +pub struct All; + +impl QueryFragment for All { + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" ALL "); + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, QueryId)] +pub struct Distinct; + +impl QueryFragment for Distinct { + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" DISTINCT "); + Ok(()) + } +} + +pub trait DistinctDsl { + type Output; + + fn distinct(self) -> Self::Output; +} + +impl DistinctDsl for T +where + T: AggregateFunction, +{ + type Output = AggregateExpression; + + fn distinct(self) -> Self::Output { + AggregateExpression { + prefix: Distinct, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl DistinctDsl + for AggregateExpression +where + T: AggregateFunction, +{ + type Output = AggregateExpression; + + fn distinct(self) -> Self::Output { + AggregateExpression { + prefix: Distinct, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: self.window, + } + } +} + +pub trait AllDsl { + type Output; + + fn all(self) -> Self::Output; +} + +impl AllDsl for T +where + T: AggregateFunction, +{ + type Output = AggregateExpression; + + fn all(self) -> Self::Output { + AggregateExpression { + prefix: All, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl AllDsl + for AggregateExpression +where + T: AggregateFunction, +{ + type Output = AggregateExpression; + + fn all(self) -> Self::Output { + AggregateExpression { + prefix: All, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: self.window, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/within_group.rs b/diesel/src/expression/functions/aggregate_expressions/within_group.rs new file mode 100644 index 000000000000..f656f5cd4591 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/within_group.rs @@ -0,0 +1,88 @@ +use super::AggregateExpression; +use super::AggregateFunction; +use super::All; +use super::NoFilter; +use super::NoOrder; +use super::NoPrefix; +use super::NoWindow; +use crate::query_builder::order_clause::OrderClause; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::Expression; +use crate::QueryResult; + +empty_clause!(NoWithin); + +#[derive(QueryId, Copy, Clone)] +pub struct WithinGroup(OrderClause); + +impl QueryFragment for WithinGroup +where + OrderClause: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" WITHIN GROUP ("); + self.0.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait WithinGroupDsl { + type Output; + + fn within_group(self, expr: E) -> Self::Output; +} + +impl WithinGroupDsl for T +where + T: AggregateFunction, + E: Expression, +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} + +impl WithinGroupDsl + for AggregateExpression +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} + +impl WithinGroupDsl + for AggregateExpression +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} diff --git a/diesel/src/expression/functions/mod.rs b/diesel/src/expression/functions/mod.rs index db8f79e7a730..c0fb882c6d1b 100644 --- a/diesel/src/expression/functions/mod.rs +++ b/diesel/src/expression/functions/mod.rs @@ -94,6 +94,7 @@ macro_rules! no_arg_sql_function { }; } +pub(crate) mod aggregate_expressions; pub(crate) mod aggregate_folding; pub(crate) mod aggregate_ordering; pub(crate) mod date_and_time; diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index 97787238d4bd..068743903cd7 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -88,6 +88,20 @@ pub(crate) mod dsl { #[cfg(feature = "mysql_backend")] pub use crate::mysql::query_builder::DuplicatedKeys; + + pub use super::functions::aggregate_expressions::AggregateExpressionMethods; + pub use super::functions::aggregate_expressions::WindowExpressionMethods; + + pub use super::functions::aggregate_expressions::frame_clause::{ + FrameBoundDsl, FrameClauseDsl, + }; + + pub mod frame { + pub use super::super::functions::aggregate_expressions::frame_clause::{ + CurrentRow, ExcludeCurrentRow, ExcludeGroup, ExcludeNoOthers, ExcludeTies, Groups, + Range, Rows, UnboundedFollowing, UnboundedPreceding, + }; + } } #[doc(inline)] diff --git a/diesel/src/internal/mod.rs b/diesel/src/internal/mod.rs index 5ba2e38905ff..e0451a3c135b 100644 --- a/diesel/src/internal/mod.rs +++ b/diesel/src/internal/mod.rs @@ -6,4 +6,5 @@ pub mod alias_macro; pub mod derives; pub mod operators_macro; +pub mod sql_functions; pub mod table_macro; diff --git a/diesel/src/internal/sql_functions.rs b/diesel/src/internal/sql_functions.rs new file mode 100644 index 000000000000..652529b9cc4c --- /dev/null +++ b/diesel/src/internal/sql_functions.rs @@ -0,0 +1,4 @@ +#[doc(hidden)] +pub use crate::expression::functions::aggregate_expressions::{ + AggregateFunction, FunctionFragment, WindowFunction, +}; diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs index 1f76d9d9adfd..95b45d28a08a 100644 --- a/diesel_derives/src/sql_function.rs +++ b/diesel_derives/src/sql_function.rs @@ -102,6 +102,7 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping}; use diesel::query_builder::{QueryFragment, AstPass}; use diesel::sql_types::*; + use diesel::internal::sql_functions::*; use super::*; #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)] @@ -143,15 +144,20 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool } // __DieselInternal is what we call DB normally - impl #impl_generics_internal QueryFragment<__DieselInternal> + impl #impl_generics_internal FunctionFragment<__DieselInternal> for #fn_name #ty_generics where __DieselInternal: diesel::backend::Backend, #(#arg_name: QueryFragment<__DieselInternal>,)* { + + fn walk_name<'__b>(&'__b self, mut pass: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> { + pass.push_sql(#sql_name); + Ok(()) + } + #[allow(unused_assignments)] - fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{ - out.push_sql(concat!(#sql_name, "(")); + fn walk_arguments<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> { // we unroll the arguments manually here, to prevent borrow check issues let mut needs_comma = false; #( @@ -163,6 +169,21 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool needs_comma = true; } )* + Ok(()) + } + } + + // __DieselInternal is what we call DB normally + impl #impl_generics_internal QueryFragment<__DieselInternal> + for #fn_name #ty_generics + where + __DieselInternal: diesel::backend::Backend, + #(#arg_name: QueryFragment<__DieselInternal>,)* + { + fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{ + self.walk_name(out.reborrow())?; + out.push_sql("("); + self.walk_arguments(out.reborrow())?; out.push_sql(")"); Ok(()) } @@ -183,6 +204,10 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool { type IsAggregate = diesel::expression::is_aggregate::Yes; } + + impl #impl_generics AggregateFunction for #fn_name #ty_generics {} + // this might need a separate marker attribute? + impl #impl_generics WindowFunction for #fn_name #ty_generics {} }; if is_supported_on_sqlite { tokens = quote! { diff --git a/diesel_tests/tests/aggregate_expressions.rs b/diesel_tests/tests/aggregate_expressions.rs new file mode 100644 index 000000000000..aab62c8c8514 --- /dev/null +++ b/diesel_tests/tests/aggregate_expressions.rs @@ -0,0 +1,73 @@ +use crate::schema::connection_with_sean_and_tess_in_users_table; +use crate::schema::users; +use diesel::dsl::{ + self, frame, AggregateExpressionMethods, FrameBoundDsl, FrameClauseDsl, WindowExpressionMethods, +}; +use diesel::prelude::*; + +#[test] +fn test1() { + let mut conn = connection_with_sean_and_tess_in_users_table(); + + let query = users::table.select(dsl::count(users::id).filter_aggregate(users::name.eq("Sean"))); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query2 = users::table.select( + dsl::count(users::id) + .distinct() + .filter_aggregate(users::name.eq("Sean")), + ); + dbg!(diesel::debug_query::(&query2)); + let res = query2.get_result::(&mut conn).unwrap(); + dbg!(res); + + let query3 = users::table.select( + dsl::count(users::id) + .distinct() + .filter_aggregate(users::name.eq("Sean")) + .order_aggregate(users::id), + ); + dbg!(diesel::debug_query::(&query3)); + let res = query3.get_result::(&mut conn).unwrap(); + dbg!(res); + todo!() +} + +#[test] +fn test2() { + let mut conn = connection_with_sean_and_tess_in_users_table(); + + let query = users::table.select(dsl::count(users::id).over()); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 2); + + let query = users::table.select(dsl::count(users::id).over().partition_by(users::name)); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query = users::table.select(dsl::count(users::id).over().window_order(users::name)); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query = users::table.select( + dsl::count(users::id) + .over() + .window_order(users::name) + .partition_by(users::name) + .frame_by(frame::Rows.start_with(2.preceding())), + ); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + todo!() +} diff --git a/diesel_tests/tests/lib.rs b/diesel_tests/tests/lib.rs index 07743347201f..1b7d6341ddf1 100644 --- a/diesel_tests/tests/lib.rs +++ b/diesel_tests/tests/lib.rs @@ -6,6 +6,7 @@ extern crate assert_matches; #[macro_use] extern crate diesel; +mod aggregate_expressions; mod alias; #[cfg(not(feature = "sqlite"))] mod annotations;