Skip to content

Commit

Permalink
Detect identical references in extractAggregateFields to allow `dis…
Browse files Browse the repository at this point in the history
…tinctAggregator` to be used with `orderAggregate`
  • Loading branch information
shane-circuithub committed Oct 11, 2023
1 parent dc7ab10 commit ce8015a
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 51 deletions.
3 changes: 2 additions & 1 deletion opaleye.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ library
aeson >= 0.6 && < 2.3
, base >= 4.9 && < 4.19
, base16-bytestring >= 0.1.1.6 && < 1.1
, case-insensitive >= 1.2 && < 1.3
, bytestring >= 0.10 && < 0.12
, case-insensitive >= 1.2 && < 1.3
, containers >= 0.5 && < 0.8
, contravariant >= 1.2 && < 1.6
, postgresql-simple >= 0.6 && < 0.8
, pretty >= 1.1.1.0 && < 1.2
Expand Down
78 changes: 53 additions & 25 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
module Opaleye.Internal.Aggregate where

import Control.Applicative (liftA2)
import Control.Arrow ((***))
import Data.Foldable (toList)

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Foldable’ is redundant
import Data.Traversable (for)

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Traversable’ is redundant

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, gets, modify, runStateT)

import qualified Opaleye.Field as F
import qualified Opaleye.Internal.Column as C
import qualified Opaleye.Internal.Order as O
Expand Down Expand Up @@ -130,42 +137,63 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) ->
--
-- Instead of detecting when we are aggregating over a field from a
-- previous query we just create new names for all field before we
-- aggregate. On the other hand, referring to a field from a previous
-- query in an ORDER BY expression is totally fine!
-- aggregate.
--
-- Additionally, PostgreSQL imposes a limitation on aggregations using ORDER
-- BY in combination with DISTINCT - essentially the expression you pass to
-- ORDER BY must also be present in the argument list to the aggregation
-- function. This means that not only do we also have to create new
-- names for the ORDER BY expressions (if we only rewrite the function
-- arguments then they can't match and therefore ORDER BY can never be used
-- with DISTINCT), but that these names actually have to match the names
-- created for the aggregation function arguments. To accomplish this, when
-- traversing over the aggregations, we keep track of all the expressions
-- we've encountered so far, and only create new names for new expressions,
-- reusing old names where possible.
aggregateU :: Aggregator a b
-> (a, PQ.PrimQuery, T.Tag) -> (b, PQ.PrimQuery)
aggregateU agg (c0, primQ, t0) = (c1, primQ')
where (c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)
aggregateU agg (a, primQ, tag) = (b, primQ')
where
(inners, outers, b) =
runSymbols (runAggregator agg (extractAggregateFields tag) a)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
inners' = fmap (fmap HPQ.AttrExpr) inners

primQ' = PQ.Aggregate projPEs (PQ.Rebind True inners primQ)
primQ' = PQ.Aggregate outers (PQ.Rebind True inners' primQ)

extractAggregateFields
:: Traversable t
=> T.Tag
-> (t HPQ.PrimExpr)
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
-> t HPQ.PrimExpr
-> Symbols HPQ.Symbol (PQ.Bindings (t HPQ.PrimExpr)) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new

let souter = HPQ.Symbol ("result" ++ i) tag

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings
result <- mkSymbol "result" <$> lift PM.new
agg' <- traverse (HPQ.traverseSymbols (symbolize (mkSymbol "inner"))) agg
lift $ PM.write (result, agg')
pure $ HPQ.AttrExpr result
where
mkSymbol name i = HPQ.Symbol (name ++ i) tag

PM.write ((souter, agg'), toList bindings)
-- | A 'StateT' computation layered over 'PM.PM' @s@. The state is a pair:
--
-- * a 'Map' from values of type @e@ to the 'HPQ.Symbol' associated with each;
-- * a function from 'PQ.Bindings' @e@ to 'PQ.Bindings' @e@.
type Symbols e s =
  StateT (Map e HPQ.Symbol, PQ.Bindings e -> PQ.Bindings e) (PM.PM s)

pure (HPQ.AttrExpr souter)
-- | Run a 'Symbols' computation starting from an empty map and the identity
-- function. Returns, in order: the bindings obtained by applying the final
-- state's function to the empty list, the value produced by 'PM.run', and
-- the computation's own result. The final map is discarded.
runSymbols :: Symbols e [s] a -> (PQ.Bindings e, [s], a)
runSymbols action =
  let initial = (Map.empty, id)
      ((result, (_seen, prepend)), written) = PM.run (runStateT action initial)
  in (prepend [], written, result)

-- | Return the 'HPQ.Symbol' associated with the given key.
--
-- If the key is already present in the state's map, its symbol is reused.
-- Otherwise a new symbol is built by applying the supplied function to the
-- result of 'PM.new'; the key is then recorded in the map and the
-- @(symbol, key)@ pair is appended after the bindings accumulated so far.
symbolize :: Ord e =>
  (String -> HPQ.Symbol) -> e -> Symbols e s HPQ.Symbol
symbolize mkSym key =
  gets (Map.lookup key . fst) >>= maybe allocate pure
  where
    allocate = do
      fresh <- mkSym <$> lift PM.new
      let remember = Map.insert key fresh
          -- Compose on the right so that earlier bindings stay in front.
          append = (. ((fresh, key) :))
      modify (remember *** append)
      pure fresh

-- | Aggregator obtained by passing 'HPQ.AggrMax' to 'makeAggr'.
-- NOTE(review): the "unsafe" prefix presumably means no constraint checks
-- that the field type can be ordered -- confirm against the public wrapper.
unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
64 changes: 47 additions & 17 deletions src/Opaleye/Internal/HaskellDB/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
-- License : BSD-style

{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE LambdaCase #-}

module Opaleye.Internal.HaskellDB.PrimQuery where

Expand All @@ -17,7 +18,7 @@ type Name = String
type Scheme = [Attribute]
type Assoc = [(Attribute,PrimExpr)]

data Symbol = Symbol String T.Tag deriving (Read, Show)
data Symbol = Symbol String T.Tag deriving (Eq, Ord, Read, Show)

data PrimExpr = AttrExpr Symbol
| BaseTableAttrExpr Attribute
Expand All @@ -42,6 +43,29 @@ data PrimExpr = AttrExpr Symbol
| ArrayIndex PrimExpr PrimExpr
deriving (Read,Show)

-- | Apply an effectful function to every 'Symbol' held in an 'AttrExpr'
-- anywhere inside a 'PrimExpr', rebuilding the expression from the results.
--
-- 'BaseTableAttrExpr', 'ConstExpr' and 'DefaultInsertExpr' are returned
-- unchanged; for every other constructor the sub-expressions are visited
-- recursively and the remaining fields are kept as they are.
traverseSymbols :: Applicative f => (Symbol -> f Symbol) -> PrimExpr -> f PrimExpr
traverseSymbols onSymbol = walk
  where
    walk (AttrExpr sym)             = AttrExpr <$> onSymbol sym
    walk (BaseTableAttrExpr attr)   = pure (BaseTableAttrExpr attr)
    walk (CompositeExpr inner attr) = (`CompositeExpr` attr) <$> walk inner
    walk (BinExpr op lhs rhs)       = BinExpr op <$> walk lhs <*> walk rhs
    walk (UnExpr op inner)          = UnExpr op <$> walk inner
    walk (AggrExpr aggr)            = AggrExpr <$> traverse walk aggr
    walk (WndwExpr wndw part)       =
      WndwExpr <$> traverse walk wndw <*> traverse walk part
    walk (ConstExpr lit)            = pure (ConstExpr lit)
    walk (CaseExpr branches other)  =
      CaseExpr <$> traverse walkBranch branches <*> walk other
    walk (ListExpr items)           = ListExpr <$> traverse walk items
    walk (ParamExpr name inner)     = ParamExpr name <$> walk inner
    walk (FunExpr name args)        = FunExpr name <$> traverse walk args
    walk (CastExpr name inner)      = CastExpr name <$> walk inner
    walk DefaultInsertExpr          = pure DefaultInsertExpr
    walk (ArrayExpr items)          = ArrayExpr <$> traverse walk items
    walk (RangeExpr s lower upper)  =
      RangeExpr s <$> traverse walk lower <*> traverse walk upper
    walk (ArrayIndex arr ix)        = ArrayIndex <$> walk arr <*> walk ix

    -- Traverse both halves of a CASE branch, first component first.
    walkBranch (cond, result) = (,) <$> walk cond <*> walk result

data Literal = NullLit
| DefaultLit -- ^ represents a default value
| BoolLit Bool
Expand Down Expand Up @@ -119,26 +143,32 @@ data OrderOp = OrderOp { orderDirection :: OrderDirection
, orderNulls :: OrderNulls }
deriving (Show,Read)

data BoundExpr = Inclusive PrimExpr | Exclusive PrimExpr | PosInfinity | NegInfinity
deriving (Show,Read)
type BoundExpr = BoundExpr' PrimExpr

data BoundExpr' a = Inclusive a | Exclusive a | PosInfinity | NegInfinity
deriving (Foldable, Functor, Traversable, Read, Show)

type WndwOp = WndwOp' PrimExpr

data WndwOp
data WndwOp' a
= WndwRowNumber
| WndwRank
| WndwDenseRank
| WndwPercentRank
| WndwCumeDist
| WndwNtile PrimExpr
| WndwLag PrimExpr PrimExpr PrimExpr
| WndwLead PrimExpr PrimExpr PrimExpr
| WndwFirstValue PrimExpr
| WndwLastValue PrimExpr
| WndwNthValue PrimExpr PrimExpr
| WndwAggregate AggrOp [PrimExpr]
deriving (Show,Read)

data Partition = Partition
{ partitionBy :: [PrimExpr]
, orderBy :: [OrderExpr]
| WndwNtile a
| WndwLag a a a
| WndwLead a a a
| WndwFirstValue a
| WndwLastValue a
| WndwNthValue a a
| WndwAggregate AggrOp [a]
deriving (Foldable, Functor, Traversable, Show, Read)

type Partition = Partition' PrimExpr

data Partition' a = Partition
{ partitionBy :: [a]
, orderBy :: [OrderExpr' a]
}
deriving (Read, Show)
deriving (Foldable, Functor, Traversable, Read, Show)
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ data PrimQuery' a = Unit
| Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr]
-- | The subqueries to take the product of and the
-- restrictions to apply
| Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol))
| Aggregate (Bindings (HPQ.Aggregate))
(PrimQuery' a)
| Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a)
-- | Represents both @DISTINCT ON@ and @ORDER BY@
Expand Down Expand Up @@ -175,7 +175,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold
, empty :: a -> p'
, baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p'
, product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p'
, aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol)
, aggregate :: Bindings HPQ.Aggregate
-> p
-> p'
, window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p'
Expand Down
7 changes: 2 additions & 5 deletions src/Opaleye/Internal/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ product ss pes = SelectFrom $
PQ.Lateral -> Lateral
PQ.NonLateral -> NonLateral

aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol)
aggregate :: PQ.Bindings HPQ.Aggregate
-> Select
-> Select
aggregate aggrs' s =
aggregate aggrs s =
SelectFrom $ newSelect { attrs = SelectAttrs (ensureColumns (map attr aggrs))
, tables = oneTable s
, groupBy = Just (groupBy' aggrs) }
Expand All @@ -188,9 +188,6 @@ aggregate aggrs' s =
handleEmpty :: [HSql.SqlExpr] -> NEL.NonEmpty HSql.SqlExpr
handleEmpty = ensureColumnsGen SP.deliteral

aggrs :: [(Symbol, HPQ.Aggregate)]
aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs'

groupBy' :: [(symbol, HPQ.Aggregate)]
-> NEL.NonEmpty HSql.SqlExpr
groupBy' aggs = handleEmpty $ do
Expand Down
2 changes: 1 addition & 1 deletion src/Opaleye/Internal/Tag.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Opaleye.Internal.Tag where
import Control.Monad.Trans.State.Strict ( get, modify', State )

-- | Tag is for use as a source of unique IDs in QueryArr
newtype Tag = UnsafeTag Int deriving (Read, Show)
newtype Tag = UnsafeTag Int deriving (Eq, Ord, Read, Show)

-- | The initial 'Tag', wrapping the integer 1.
start :: Tag
start = UnsafeTag 1
Expand Down

0 comments on commit ce8015a

Please sign in to comment.