Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First draft for supporting window functions and other aggregate function expressions #4322

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions diesel/src/expression/functions/aggregate_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
use crate::backend::Backend;
use crate::expression::{is_aggregate, AsExpression, ValidGrouping};
use crate::query_builder::{AstPass, QueryFragment, QueryId};
use crate::sql_types::Bool;
use crate::{AppearsOnTable, Expression, QueryResult, SelectableExpression};

macro_rules! empty_clause {
($name: ident) => {
#[derive(Debug, Clone, Copy, QueryId)]
pub struct $name;

impl crate::query_builder::QueryFragment<diesel::pg::Pg> for $name {
fn walk_ast<'b>(
&'b self,
_pass: crate::query_builder::AstPass<'_, 'b, diesel::pg::Pg>,
) -> crate::QueryResult<()> {
Ok(())
}
}
};
}

mod aggregate_filter;
mod aggregate_order;
pub(crate) mod frame_clause;
mod over_clause;
mod partition_by;
mod prefix;
mod within_group;

use self::aggregate_filter::{Filter, FilterDsl, NoFilter};
use self::aggregate_order::{NoOrder, Order, OrderAggregateDsl, OrderWindowDsl};
use self::frame_clause::{FrameDsl, NoFrame};
use self::over_clause::{NoWindow, OverClause, OverDsl};
use self::partition_by::{PartitionBy, PartitionByDsl};
use self::prefix::{All, AllDsl, Distinct, DistinctDsl, NoPrefix};
use self::within_group::{NoWithin, WithinGroup, WithinGroupDsl};

#[derive(QueryId)]
pub struct AggregateExpression<
Fn,
Prefix = NoPrefix,
Order = NoOrder,
Filter = NoFilter,
Within = NoWithin,
Window = NoWindow,
> {
prefix: Prefix,
function: Fn,
order: Order,
filter: Filter,
within_group: Within,
window: Window,
}

impl<Fn, Prefix, Order, Filter, Within, Window> QueryFragment<diesel::pg::Pg>
for AggregateExpression<Fn, Prefix, Order, Filter, Within, Window>
where
Fn: FunctionFragment<diesel::pg::Pg>,
Prefix: QueryFragment<diesel::pg::Pg>,
Order: QueryFragment<diesel::pg::Pg>,
Filter: QueryFragment<diesel::pg::Pg>,
Within: QueryFragment<diesel::pg::Pg>,
Window: QueryFragment<diesel::pg::Pg>,
{
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> {
self.function.walk_name(pass.reborrow())?;
pass.push_sql("(");
self.prefix.walk_ast(pass.reborrow())?;
self.function.walk_arguments(pass.reborrow())?;
self.order.walk_ast(pass.reborrow())?;
pass.push_sql(")");
self.within_group.walk_ast(pass.reborrow())?;
self.filter.walk_ast(pass.reborrow())?;
self.window.walk_ast(pass.reborrow())?;
Ok(())
}
}

impl<Fn, Prefix, Order, Filter, Within, GB> ValidGrouping<GB>
for AggregateExpression<Fn, Prefix, Order, Filter, Within>
where
Fn: ValidGrouping<GB>,
{
type IsAggregate = <Fn as ValidGrouping<GB>>::IsAggregate;
}

impl<Fn, Prefix, Order, Filter, Within, GB, Partition, WindowOrder, Frame> ValidGrouping<GB>
for AggregateExpression<
Fn,
Prefix,
Order,
Filter,
Within,
OverClause<Partition, WindowOrder, Frame>,
>
where
Fn: ValidGrouping<GB>,
{
// not sure about that, check this
type IsAggregate = is_aggregate::No;
}

impl<Fn, Prefix, Order, Filter, Within, Window> Expression
for AggregateExpression<Fn, Prefix, Order, Filter, Within, Window>
where
Fn: Expression,
{
type SqlType = <Fn as Expression>::SqlType;
}

impl<Fn, Prefix, Order, Filter, Within, Window, QS> AppearsOnTable<QS>
for AggregateExpression<Fn, Prefix, Order, Filter, Within, Window>
where
Self: Expression,
Fn: AppearsOnTable<QS>,
{
}

impl<Fn, Prefix, Order, Filter, Within, Window, QS> SelectableExpression<QS>
for AggregateExpression<Fn, Prefix, Order, Filter, Within, Window>
where
Self: Expression,
Fn: SelectableExpression<QS>,
{
}

pub trait WindowFunction {}
pub trait AggregateFunction {}
pub trait FunctionFragment<DB: Backend> {
fn walk_name<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>;

fn walk_arguments<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>;
}

// TODO: write helper types for all functions
pub trait AggregateExpressionMethods: Sized {
fn distinct(self) -> Self::Output
where
Self: DistinctDsl,
{
<Self as DistinctDsl>::distinct(self)
}

fn all(self) -> Self::Output
where
Self: AllDsl,
{
<Self as AllDsl>::all(self)
}

// todo: do we want `or_filter` as well?
fn filter_aggregate<P>(self, f: P) -> Self::Output
where
P: AsExpression<Bool>,
Self: FilterDsl<P::Expression>,
{
<Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
}

fn order_aggregate<O>(self, o: O) -> Self::Output
where
Self: OrderAggregateDsl<O>,
{
<Self as OrderAggregateDsl<O>>::order(self, o)
}

// todo: restrict this to order set aggregates
// (we don't have any in diesel yet)
fn within_group<O>(self, o: O) -> Self::Output
where
Self: WithinGroupDsl<O>,
{
<Self as WithinGroupDsl<O>>::within_group(self, o)
}
}

impl<T> AggregateExpressionMethods for T {}

pub trait WindowExpressionMethods: Sized {
fn over(self) -> Self::Output
where
Self: OverDsl,
{
<Self as OverDsl>::over(self)
}

// todo: do we want `or_filter` as well?
fn filter_window<P>(self, f: P) -> Self::Output
where
P: AsExpression<Bool>,
Self: FilterDsl<P::Expression>,
{
<Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
}

fn partition_by<E>(self, expr: E) -> Self::Output
where
Self: PartitionByDsl<E>,
{
<Self as PartitionByDsl<E>>::partition_by(self, expr)
}

fn window_order<E>(self, expr: E) -> Self::Output
where
Self: OrderWindowDsl<E>,
{
<Self as OrderWindowDsl<E>>::order(self, expr)
}

fn frame_by<E>(self, expr: E) -> Self::Output
where
Self: FrameDsl<E>,
{
<Self as FrameDsl<E>>::frame(self, expr)
}
}

impl<T> WindowExpressionMethods for T {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use super::aggregate_order::NoOrder;
use super::prefix::NoPrefix;
use super::AggregateExpression;
use super::AggregateFunction;
use super::NoWindow;
use super::NoWithin;
use crate::query_builder::where_clause::NoWhereClause;
use crate::query_builder::where_clause::WhereAnd;
use crate::query_builder::QueryFragment;
use crate::query_builder::{AstPass, QueryId};
use crate::sql_types::Bool;
use crate::Expression;
use crate::QueryResult;

empty_clause!(NoFilter);

#[derive(QueryId, Copy, Clone)]
pub struct Filter<P>(P);

impl<P> QueryFragment<diesel::pg::Pg> for Filter<P>
where
P: QueryFragment<diesel::pg::Pg>,
{
fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> {
pass.push_sql(" FILTER (");
self.0.walk_ast(pass.reborrow())?;
pass.push_sql(")");
Ok(())
}
}

pub trait FilterDsl<P> {
type Output;

fn filter(self, f: P) -> Self::Output;
}

impl<P, T> FilterDsl<P> for T
where
T: AggregateFunction,
// todo: allow nullable bools here
P: Expression<SqlType = Bool>,
{
type Output =
AggregateExpression<T, NoPrefix, NoOrder, Filter<<NoWhereClause as WhereAnd<P>>::Output>>;

fn filter(self, f: P) -> Self::Output {
AggregateExpression {
prefix: NoPrefix,
function: self,
order: NoOrder,
filter: Filter(NoWhereClause.and(f)),
within_group: NoWithin,
window: NoWindow,
}
}
}

impl<Fn, P, Prefix, Order, F, Within, Window> FilterDsl<P>
for AggregateExpression<Fn, Prefix, Order, Filter<F>, Within, Window>
where
// todo: allow nullable bools here
F: WhereAnd<P>,
{
type Output =
AggregateExpression<Fn, Prefix, Order, Filter<<F as WhereAnd<P>>::Output>, Within, Window>;

fn filter(self, f: P) -> Self::Output {
AggregateExpression {
prefix: self.prefix,
function: self.function,
order: self.order,
filter: Filter(WhereAnd::<P>::and(self.filter.0, f)),
within_group: self.within_group,
window: self.window,
}
}
}

impl<Fn, P, Prefix, Order, Within, Window> FilterDsl<P>
for AggregateExpression<Fn, Prefix, Order, NoFilter, Within, Window>
where
// todo: allow nullable bools here
NoWhereClause: WhereAnd<P>,
{
type Output = AggregateExpression<
Fn,
Prefix,
Order,
Filter<<NoWhereClause as WhereAnd<P>>::Output>,
Within,
Window,
>;

fn filter(self, f: P) -> Self::Output {
AggregateExpression {
prefix: self.prefix,
function: self.function,
order: self.order,
filter: Filter(WhereAnd::<P>::and(NoWhereClause, f)),
within_group: self.within_group,
window: self.window,
}
}
}
Loading
Loading