Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug preventing orderedAggregate and distinctAggregator from being used together (alternative approach) #578

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import qualified Data.Aeson as Json
import qualified Data.Function as F
import Data.Int (Int32)
import qualified Data.List as L
import qualified Data.List.NonEmpty as NE
import qualified Data.Ord as Ord
import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP
Expand Down Expand Up @@ -617,6 +618,21 @@ testStringArrayAggregateOrdered = it "" $ q `selectShouldReturnSorted` expected
]
sortedData = L.sortBy (Ord.comparing snd) table7data


testStringArrayAggregateOrderedDistinct :: Test
testStringArrayAggregateOrderedDistinct = it "" $ q `selectShouldReturnSorted` expected
where q =
O.aggregateOrdered
(O.asc snd)
(PP.p2 (O.arrayAgg, O.distinctAggregator . O.stringAgg . O.sqlString $ ","))
table7Q
expected = [ ( map fst sortedData
, L.intercalate "," $ map NE.head $ NE.group $ map snd sortedData
)
]
sortedData = L.sortBy (Ord.comparing snd) table7data


-- | Using orderAggregate you can apply different orderings to
-- different aggregates.

Expand Down Expand Up @@ -1523,6 +1539,7 @@ main = do
testOverwriteAggregateOrdered
testMultipleAggregateOrdered
testStringArrayAggregateOrdered
testStringArrayAggregateOrderedDistinct
testDistinctAndAggregate
testDoubleAggregate
describe "distinct" $ do
Expand Down
3 changes: 2 additions & 1 deletion opaleye.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ library
aeson >= 0.6 && < 2.3
, base >= 4.9 && < 4.20
, base16-bytestring >= 0.1.1.6 && < 1.1
, case-insensitive >= 1.2 && < 1.3
, bytestring >= 0.10 && < 0.12
, case-insensitive >= 1.2 && < 1.3
, containers >= 0.5 && < 0.8
, contravariant >= 1.2 && < 1.6
, postgresql-simple >= 0.6 && < 0.8
, pretty >= 1.1.1.0 && < 1.2
Expand Down
78 changes: 53 additions & 25 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
module Opaleye.Internal.Aggregate where

import Control.Applicative (liftA2)
import Control.Arrow ((***))
import Data.Foldable (toList)

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Foldable’ is redundant

Check warning on line 6 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Foldable’ is redundant
import Data.Traversable (for)

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.2)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.10)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 9.0)

The import of ‘Data.Traversable’ is redundant

Check warning on line 7 in src/Opaleye/Internal/Aggregate.hs

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest, 8.8)

The import of ‘Data.Traversable’ is redundant

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, gets, modify, runStateT)

import qualified Opaleye.Field as F
import qualified Opaleye.Internal.Column as C
import qualified Opaleye.Internal.Order as O
Expand Down Expand Up @@ -130,42 +137,63 @@
--
-- Instead of detecting when we are aggregating over a field from a
-- previous query we just create new names for all field before we
-- aggregate. On the other hand, referring to a field from a previous
-- query in an ORDER BY expression is totally fine!
-- aggregate.
--
-- Additionally, PostgreSQL imposes a limitation on aggregations using ORDER
-- BY in combination with DISTINCT - essentially the expression you pass to
-- ORDER BY must also be present in the argument list to the aggregation
-- function. This means that not only do we also have to also create new
-- names for the ORDER BY expressions (if we only rewrite the function
-- arguments then they can't match and therefore ORDER BY can never be used
-- with DISTINCT), but that these names actually have to match the names
-- created for the aggregation function arguments. To accomplish this, when
-- traversing over the aggregations, we keep track of all the expressions
-- we've encountered so far, and only create new names for new expressions,
-- reusing old names where possible.
aggregateU :: Aggregator a b
-> (a, PQ.PrimQuery, T.Tag) -> (b, PQ.PrimQuery)
aggregateU agg (c0, primQ, t0) = (c1, primQ')
where (c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)
aggregateU agg (a, primQ, tag) = (b, primQ')
where
(inners, outers, b) =
runSymbols (runAggregator agg (extractAggregateFields tag) a)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
inners' = fmap (fmap HPQ.AttrExpr) inners

primQ' = PQ.Aggregate projPEs (PQ.Rebind True inners primQ)
primQ' = PQ.Aggregate outers (PQ.Rebind True inners' primQ)

extractAggregateFields
:: Traversable t
=> T.Tag
-> (t HPQ.PrimExpr)
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
-> t HPQ.PrimExpr
-> Symbols HPQ.Symbol (PQ.Bindings (t HPQ.PrimExpr)) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new

let souter = HPQ.Symbol ("result" ++ i) tag

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings
result <- mkSymbol "result" <$> lift PM.new
agg' <- traverse (HPQ.traverseSymbols (symbolize (mkSymbol "inner"))) agg
lift $ PM.write (result, agg')
pure $ HPQ.AttrExpr result
where
mkSymbol name i = HPQ.Symbol (name ++ i) tag

PM.write ((souter, agg'), toList bindings)
type Symbols e s =
StateT
(Map e HPQ.Symbol, PQ.Bindings e -> PQ.Bindings e)
(PM.PM s)

pure (HPQ.AttrExpr souter)
runSymbols :: Symbols e [s] a -> (PQ.Bindings e, [s], a)
runSymbols m = (dlist [], outers, a)
where
((a, (_, dlist)), outers) = PM.run $ runStateT m (Map.empty, id)

symbolize :: Ord e =>
(String -> HPQ.Symbol) -> e -> Symbols e s HPQ.Symbol
symbolize f expr = do
msymbol <- gets (Map.lookup expr . fst)
case msymbol of
Just symbol -> pure symbol
Nothing -> do
symbol <- f <$> lift PM.new
modify (Map.insert expr symbol *** (. ((symbol, expr) :)))
pure symbol

unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
64 changes: 47 additions & 17 deletions src/Opaleye/Internal/HaskellDB/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
-- License : BSD-style

{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE LambdaCase #-}

module Opaleye.Internal.HaskellDB.PrimQuery where

Expand All @@ -17,7 +18,7 @@ type Name = String
type Scheme = [Attribute]
type Assoc = [(Attribute,PrimExpr)]

data Symbol = Symbol String T.Tag deriving (Read, Show)
data Symbol = Symbol String T.Tag deriving (Eq, Ord, Read, Show)

data PrimExpr = AttrExpr Symbol
| BaseTableAttrExpr Attribute
Expand All @@ -42,6 +43,29 @@ data PrimExpr = AttrExpr Symbol
| ArrayIndex PrimExpr PrimExpr
deriving (Read,Show)

traverseSymbols :: Applicative f => (Symbol -> f Symbol) -> PrimExpr -> f PrimExpr
traverseSymbols f = go
where
go = \case
AttrExpr symbol -> AttrExpr <$> f symbol
BaseTableAttrExpr attribute -> pure $ BaseTableAttrExpr attribute
CompositeExpr a attribute -> CompositeExpr <$> go a <*> pure attribute
BinExpr op a b -> BinExpr op <$> go a <*> go b
UnExpr op a -> UnExpr op <$> go a
AggrExpr aggr -> AggrExpr <$> traverse go aggr
WndwExpr wndw partition -> WndwExpr <$> traverse go wndw <*> traverse go partition
ConstExpr literal -> pure $ ConstExpr literal
CaseExpr conds a -> CaseExpr <$> traverse (bitraverse go go) conds <*> go a
ListExpr as -> ListExpr <$> traverse go as
ParamExpr name a -> ParamExpr name <$> go a
FunExpr name args -> FunExpr name <$> traverse go args
CastExpr name a -> CastExpr name <$> go a
DefaultInsertExpr -> pure DefaultInsertExpr
ArrayExpr as -> ArrayExpr <$> traverse go as
RangeExpr s a b -> RangeExpr s <$> traverse go a <*> traverse go b
ArrayIndex a b -> ArrayIndex <$> go a <*> go b
bitraverse g h (a, b) = (,) <$> g a <*> h b

data Literal = NullLit
| DefaultLit -- ^ represents a default value
| BoolLit Bool
Expand Down Expand Up @@ -119,26 +143,32 @@ data OrderOp = OrderOp { orderDirection :: OrderDirection
, orderNulls :: OrderNulls }
deriving (Show,Read)

data BoundExpr = Inclusive PrimExpr | Exclusive PrimExpr | PosInfinity | NegInfinity
deriving (Show,Read)
type BoundExpr = BoundExpr' PrimExpr

data BoundExpr' a = Inclusive a | Exclusive a | PosInfinity | NegInfinity
deriving (Foldable, Functor, Traversable, Read, Show)

type WndwOp = WndwOp' PrimExpr

data WndwOp
data WndwOp' a
= WndwRowNumber
| WndwRank
| WndwDenseRank
| WndwPercentRank
| WndwCumeDist
| WndwNtile PrimExpr
| WndwLag PrimExpr PrimExpr PrimExpr
| WndwLead PrimExpr PrimExpr PrimExpr
| WndwFirstValue PrimExpr
| WndwLastValue PrimExpr
| WndwNthValue PrimExpr PrimExpr
| WndwAggregate AggrOp [PrimExpr]
deriving (Show,Read)

data Partition = Partition
{ partitionBy :: [PrimExpr]
, orderBy :: [OrderExpr]
| WndwNtile a
| WndwLag a a a
| WndwLead a a a
| WndwFirstValue a
| WndwLastValue a
| WndwNthValue a a
| WndwAggregate AggrOp [a]
deriving (Foldable, Functor, Traversable, Show, Read)

type Partition = Partition' PrimExpr

data Partition' a = Partition
{ partitionBy :: [a]
, orderBy :: [OrderExpr' a]
}
deriving (Read, Show)
deriving (Foldable, Functor, Traversable, Read, Show)
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ data PrimQuery' a = Unit
| Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr]
-- | The subqueries to take the product of and the
-- restrictions to apply
| Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol))
| Aggregate (Bindings (HPQ.Aggregate))
(PrimQuery' a)
| Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a)
-- | Represents both @DISTINCT ON@ and @ORDER BY@
Expand Down Expand Up @@ -178,7 +178,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold
, empty :: a -> p'
, baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p'
, product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p'
, aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol)
, aggregate :: Bindings HPQ.Aggregate
-> p
-> p'
, window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p'
Expand Down
7 changes: 2 additions & 5 deletions src/Opaleye/Internal/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ product ss pes = SelectFrom $
PQ.Lateral -> Lateral
PQ.NonLateral -> NonLateral

aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol)
aggregate :: PQ.Bindings HPQ.Aggregate
-> Select
-> Select
aggregate aggrs' s =
aggregate aggrs s =
SelectFrom $ newSelect { attrs = SelectAttrs (ensureColumns (map attr aggrs))
, tables = oneTable s
, groupBy = Just (groupBy' aggrs) }
Expand All @@ -190,9 +190,6 @@ aggregate aggrs' s =
handleEmpty :: [HSql.SqlExpr] -> NEL.NonEmpty HSql.SqlExpr
handleEmpty = ensureColumnsGen SP.deliteral

aggrs :: [(Symbol, HPQ.Aggregate)]
aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs'

groupBy' :: [(symbol, HPQ.Aggregate)]
-> NEL.NonEmpty HSql.SqlExpr
groupBy' aggs = handleEmpty $ do
Expand Down
2 changes: 1 addition & 1 deletion src/Opaleye/Internal/Tag.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Opaleye.Internal.Tag where
import Control.Monad.Trans.State.Strict ( get, modify', State )

-- | Tag is for use as a source of unique IDs in QueryArr
newtype Tag = UnsafeTag Int deriving (Read, Show)
newtype Tag = UnsafeTag Int deriving (Eq, Ord, Read, Show)

start :: Tag
start = UnsafeTag 1
Expand Down
Loading