Skip to content

Commit

Permalink
Add ordered set aggregation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
shane-circuithub committed Oct 11, 2023
1 parent d0ba116 commit fb96ffc
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 27 deletions.
3 changes: 3 additions & 0 deletions changelog.d/20231009_170616_shane.obrien_mode.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Added

- Add support for ordered-set aggregation functions, including `mode`, `percentile`, `percentileContinuous`, `hypotheticalRank`, `hypotheticalDenseRank`, `hypotheticalPercentRank` and `hypotheticalCumeDist`.
9 changes: 8 additions & 1 deletion src/Rel8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ module Rel8
, groupBy, groupByOn
, listAgg, listAggOn, listAggExpr, listAggExprOn
, listCat, listCatOn, listCatExpr, listCatExprOn
, mode
, nonEmptyAgg, nonEmptyAggOn, nonEmptyAggExpr, nonEmptyAggExprOn
, nonEmptyCat, nonEmptyCatOn, nonEmptyCatExpr, nonEmptyCatExprOn
, DBMax, max, maxOn
Expand All @@ -295,6 +294,14 @@ module Rel8
, and, andOn
, or, orOn

, mode, modeOn
, percentile, percentileOn
, percentileContinuous, percentileContinuousOn
, hypotheticalRank
, hypotheticalDenseRank
, hypotheticalPercentRank
, hypotheticalCumeDist

-- ** Ordering
, orderBy
, Order
Expand Down
147 changes: 144 additions & 3 deletions src/Rel8/Expr/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# language NamedFieldPuns #-}
{-# language OverloadedStrings #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeFamilies #-}

{-# options_ghc -fno-warn-redundant-constraints #-}
Expand All @@ -17,6 +18,13 @@ module Rel8.Expr.Aggregate
, sum, sumOn, sumWhere
, avg, avgOn
, stringAgg, stringAggOn
, mode, modeOn
, percentile, percentileOn
, percentileContinuous, percentileContinuousOn
, hypotheticalRank
, hypotheticalDenseRank
, hypotheticalPercentRank
, hypotheticalCumeDist
, groupByExpr, groupByExprOn
, distinctAggregate
, filterWhereExplicit
Expand All @@ -28,6 +36,7 @@ module Rel8.Expr.Aggregate
where

-- base
import Data.Functor.Contravariant ((>$<))
import Data.Int ( Int64 )
import Data.List.NonEmpty ( NonEmpty )
import Data.String (IsString)
Expand All @@ -36,6 +45,7 @@ import Prelude hiding (and, max, min, null, or, show, sum)
-- opaleye
import qualified Opaleye.Aggregate as Opaleye
import qualified Opaleye.Internal.Aggregate as Opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Operators as Opaleye

-- profunctors
Expand All @@ -59,17 +69,22 @@ import Rel8.Expr.Opaleye
, fromPrimExpr
, toColumn
, toPrimExpr
, unsafeCastExpr
)
import Rel8.Expr.Order (asc)
import Rel8.Expr.Read (sread)
import Rel8.Expr.Show (show)
import qualified Rel8.Expr.Text as Text
import Rel8.Order (Order (Order))
import Rel8.Schema.Null ( Sql, Unnullify )
import Rel8.Table.Opaleye (fromOrder, unpackspec)
import Rel8.Table.Order (ascTable)
import Rel8.Type ( DBType, typeInformation )
import Rel8.Type.Array (arrayTypeName, encodeArrayElement)
import Rel8.Type.Eq ( DBEq )
import Rel8.Type.Information (TypeInformation)
import Rel8.Type.Num ( DBNum )
import Rel8.Type.Ord ( DBMax, DBMin )
import Rel8.Type.Num (DBFractional, DBNum)
import Rel8.Type.Ord (DBMax, DBMin, DBOrd)
import Rel8.Type.String ( DBString )
import Rel8.Type.Sum ( DBSum )

Expand Down Expand Up @@ -239,6 +254,132 @@ stringAggOn :: (Sql IsString a, Sql DBString a)
stringAggOn delimiter f = lmap f (stringAgg delimiter)


-- | Corresponds to @mode() WITHIN GROUP (ORDER BY _)@.
--
-- The input expression is used only as the @ORDER BY@ key (ascending, via
-- 'ascTable'); no direct arguments are passed to @mode@.
mode :: Sql DBOrd a => Aggregator1 (Expr a) (Expr a)
mode =
  unsafeMakeAggregator
    id
    (fromPrimExpr . fromColumn)
    Empty
    ( Opaleye.withinGroup
        (unwrapOrder ascTable)
        (Opaleye.makeAggrExplicit (pure ()) (Opaleye.AggrOther "mode"))
    )
  where
    -- Strip the Rel8 'Order' wrapper to get at the underlying Opaleye order.
    unwrapOrder (Order o) = o


-- | Like 'mode', but aggregates the column that the given function selects
-- from the input.
modeOn :: Sql DBOrd a => (i -> Expr a) -> Aggregator1 i (Expr a)
modeOn selector = lmap selector mode


-- | Corresponds to @percentile_disc(_) WITHIN GROUP (ORDER BY _)@.
--
-- The given fraction is passed as the direct argument of the aggregate,
-- while the input expression supplies the (ascending) @ORDER BY@ key.
percentile :: Sql DBOrd a => Expr Double -> Aggregator1 (Expr a) (Expr a)
percentile fraction =
  unsafeMakeAggregator
    withFraction
    (castExpr . fromPrimExpr . fromColumn)
    Empty
    ( Opaleye.withinGroup
        (unwrapOrder (snd >$< ascTable))
        ( Opaleye.makeAggrExplicit
            (lmap fst unpackspec)
            (Opaleye.AggrOther "percentile_disc")
        )
    )
  where
    -- Pair every input with the fraction: 'fst' feeds the direct argument,
    -- 'snd' feeds the ordering.
    withFraction a = (fraction, a)
    unwrapOrder (Order o) = o


-- | Like 'percentile', but aggregates the column that the given function
-- selects from the input.
percentileOn ::
  Sql DBOrd a =>
  Expr Double ->
  (i -> Expr a) ->
  Aggregator1 i (Expr a)
percentileOn fraction selector = lmap selector (percentile fraction)


-- | Corresponds to @percentile_cont(_) WITHIN GROUP (ORDER BY _)@.
--
-- The given fraction is passed as the direct argument of the aggregate. The
-- input expression is cast to 'Double' for the (ascending) @ORDER BY@ key,
-- and the aggregate's result is cast back to the input type.
percentileContinuous ::
  Sql DBFractional a =>
  Expr Double ->
  Aggregator1 (Expr a) (Expr a)
percentileContinuous fraction =
  unsafeMakeAggregator
    (\a -> (fraction, a))
    (castExpr . fromPrimExpr . fromColumn)
    Empty
    ( Opaleye.withinGroup
        ((\(Order o) -> o) (unsafeCastExpr @Double . snd >$< asc))
        ( Opaleye.makeAggrExplicit
            (lmap fst unpackspec)
            -- Must be the continuous (interpolating) variant; using
            -- "percentile_disc" here would make this identical to 'percentile'.
            (Opaleye.AggrOther "percentile_cont")
        )
    )



-- | Like 'percentileContinuous', but aggregates the column that the given
-- function selects from the input.
percentileContinuousOn ::
  Sql DBFractional a =>
  Expr Double ->
  (i -> Expr a) ->
  Aggregator1 i (Expr a)
percentileContinuousOn fraction selector =
  lmap selector (percentileContinuous fraction)


-- | Corresponds to @rank(_) WITHIN GROUP (ORDER BY _)@.
--
-- Each aggregated row is paired with @args@: the expressions that the
-- ordering extracts from @args@ become the direct arguments of @rank@, and
-- the same ordering applied to the aggregated row becomes the @ORDER BY@.
-- Falls back to @1@ when used as a fold over no rows.
hypotheticalRank ::
  Order a ->
  a ->
  Aggregator' fold a (Expr Int64)
hypotheticalRank (Order order) args =
  unsafeMakeAggregator
    pairWithArgs
    (castExpr . fromPrimExpr . fromColumn)
    (Fallback 1)
    ( Opaleye.withinGroup
        (snd >$< order)
        ( Opaleye.makeAggrExplicit
            (fromOrder (fst >$< order))
            (Opaleye.AggrOther "rank")
        )
    )
  where
    -- 'fst' is the hypothetical row, 'snd' is the row being aggregated.
    pairWithArgs row = (args, row)


-- | Corresponds to @dense_rank(_) WITHIN GROUP (ORDER BY _)@.
--
-- Each aggregated row is paired with @args@: the expressions that the
-- ordering extracts from @args@ become the direct arguments of
-- @dense_rank@, and the same ordering applied to the aggregated row becomes
-- the @ORDER BY@. Falls back to @1@ when used as a fold over no rows.
hypotheticalDenseRank ::
  Order a ->
  a ->
  Aggregator' fold a (Expr Int64)
hypotheticalDenseRank (Order order) args =
  unsafeMakeAggregator
    -- Keep the aggregated row; replacing it with @args@ would order the
    -- group by the hypothetical row itself and ignore the actual input.
    (\a -> (args, a))
    (castExpr . fromPrimExpr . fromColumn)
    (Fallback 1)
    ( Opaleye.withinGroup
        (snd >$< order)
        ( Opaleye.makeAggrExplicit
            (fromOrder (fst >$< order))
            (Opaleye.AggrOther "dense_rank")
        )
    )


-- | Corresponds to @percent_rank(_) WITHIN GROUP (ORDER BY _)@.
--
-- Each aggregated row is paired with @args@: the expressions that the
-- ordering extracts from @args@ become the direct arguments of
-- @percent_rank@, and the same ordering applied to the aggregated row
-- becomes the @ORDER BY@. Falls back to @0@ when used as a fold over no rows.
hypotheticalPercentRank ::
  Order a ->
  a ->
  Aggregator' fold a (Expr Double)
hypotheticalPercentRank (Order order) args =
  unsafeMakeAggregator
    -- Keep the aggregated row; replacing it with @args@ would order the
    -- group by the hypothetical row itself and ignore the actual input.
    (\a -> (args, a))
    (castExpr . fromPrimExpr . fromColumn)
    (Fallback 0)
    ( Opaleye.withinGroup
        (snd >$< order)
        ( Opaleye.makeAggrExplicit
            (fromOrder (fst >$< order))
            (Opaleye.AggrOther "percent_rank")
        )
    )


-- | Corresponds to @cume_dist(_) WITHIN GROUP (ORDER BY _)@.
--
-- Each aggregated row is paired with @args@: the expressions that the
-- ordering extracts from @args@ become the direct arguments of @cume_dist@,
-- and the same ordering applied to the aggregated row becomes the
-- @ORDER BY@. Falls back to @1@ when used as a fold over no rows.
hypotheticalCumeDist ::
  Order a ->
  a ->
  Aggregator' fold a (Expr Double)
hypotheticalCumeDist (Order order) args =
  unsafeMakeAggregator
    -- Keep the aggregated row; replacing it with @args@ would order the
    -- group by the hypothetical row itself and ignore the actual input.
    (\a -> (args, a))
    (castExpr . fromPrimExpr . fromColumn)
    (Fallback 1)
    ( Opaleye.withinGroup
        (snd >$< order)
        ( Opaleye.makeAggrExplicit
            (fromOrder (fst >$< order))
            (Opaleye.AggrOther "cume_dist")
        )
    )


-- | Aggregate a value by grouping by it.
groupByExpr :: Sql DBEq a => Aggregator1 (Expr a) (Expr a)
groupByExpr =
Expand All @@ -249,7 +390,7 @@ groupByExpr =
Opaleye.groupBy


-- | Applies 'groupByExprOn' to the column selected by the given function.
-- | Like 'groupByExpr', but groups by the column that the given function
-- selects from the input.
groupByExprOn :: Sql DBEq a => (i -> Expr a) -> Aggregator1 i (Expr a)
groupByExprOn selector = lmap selector groupByExpr

Expand Down
2 changes: 1 addition & 1 deletion src/Rel8/Expr/Num.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fromIntegral :: (Sql DBIntegral a, Sql DBNum b, Homonullable a b)
fromIntegral (Expr a) = castExpr (Expr a)


-- | Cast 'DBNum' types to 'DBFractional' types. For example, his can be useful
-- | Cast 'DBNum' types to 'DBFractional' types. For example, this can be useful
-- to convert @Expr Float@ to @Expr Double@.
realToFrac :: (Sql DBNum a, Sql DBFractional b, Homonullable a b)
=> Expr a -> Expr b
Expand Down
15 changes: 0 additions & 15 deletions src/Rel8/Query/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ module Rel8.Query.Aggregate
( aggregate
, aggregate1
, countRows
, mode
)
where

-- base
import Control.Applicative (liftA2)
import Data.Functor.Contravariant ( (>$<) )
import Data.Int ( Int64 )
import Prelude

Expand All @@ -24,15 +22,10 @@ import Rel8.Aggregate (Aggregator' (Aggregator), Aggregator)
import Rel8.Aggregate.Fold (Fallback (Fallback))
import Rel8.Expr ( Expr )
import Rel8.Expr.Aggregate ( countStar )
import Rel8.Expr.Order ( desc )
import Rel8.Query ( Query )
import Rel8.Query.Limit ( limit )
import Rel8.Query.Maybe ( optional )
import Rel8.Query.Opaleye ( mapOpaleye )
import Rel8.Query.Order ( orderBy )
import Rel8.Table (Table)
import Rel8.Table.Aggregate (groupBy)
import Rel8.Table.Eq (EqTable)
import Rel8.Table.Maybe (fromMaybeTable)


Expand All @@ -55,11 +48,3 @@ aggregate1 (Aggregator _ aggregator) = mapOpaleye (Opaleye.aggregate aggregator)
-- will return @0@.
countRows :: Query a -> Query (Expr Int64)
countRows = aggregate countStar


-- | Return the most common row in a query.
--
-- Rows are grouped and counted, the groups are sorted by descending count,
-- and only the first group's row is kept.
mode :: forall a. EqTable a => Query a -> Query a
mode rows = limit 1 (snd <$> orderBy (fst >$< desc) counted)
  where
    -- Every distinct row paired with its number of occurrences.
    counted = aggregate1 (liftA2 (,) countStar groupBy) rows
24 changes: 20 additions & 4 deletions src/Rel8/Schema/HTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

module Rel8.Schema.HTable
( HTable (HField, HConstrainTable)
, hfield, htabulate, htraverse, hdicts, hspecs
, hfoldMap, hmap, htabulateA, htabulateP, htraverseP, htraversePWithField
, hfield, htabulate, hdicts, hspecs
, hfoldMap, hmap, htabulateA, htabulateP
, htraverse, htraverse_, htraverseP, htraversePWithField
)
where

-- base
import Data.Functor (void)
import Data.Functor.Compose ( Compose( Compose ), getCompose )
import Data.Functor.Const ( Const( Const ), getConst )
import Data.Kind ( Constraint, Type )
import Data.Functor.Compose ( Compose( Compose ), getCompose )
import Data.Proxy ( Proxy )
import GHC.Generics
( (:*:)( (:*:) )
Expand All @@ -46,7 +48,7 @@ import Rel8.Schema.HTable.Product ( HProduct( HProduct ) )
import qualified Rel8.Schema.Kind as K

-- semigroupoids
import Data.Functor.Apply ( Apply, (<.>) )
import Data.Functor.Apply (Apply, (<.>), liftF2)

-- | A @HTable@ is a functor-indexed/higher-kinded data type that is
-- representable ('htabulate'/'hfield'), constrainable ('hdicts'), and
Expand Down Expand Up @@ -130,6 +132,20 @@ hmap :: HTable t
hmap f a = htabulate $ \field -> f (hfield a field)


-- | Wrapper around a functor application @f a@; it exists to carry a
-- 'Semigroup' instance that combines the wrapped actions.
newtype Ap f a = Ap {getAp :: f a}


-- | Combines two wrapped actions with 'liftF2', merging their results with
-- the underlying '<>'.
instance (Apply f, Semigroup a) => Semigroup (Ap f a) where
  Ap lhs <> Ap rhs = Ap $ liftF2 (<>) lhs rhs


-- | Run an action on every field of an 'HTable' and discard the results.
--
-- Each field's action is 'void'ed and the actions are combined through the
-- 'Semigroup' instance of 'Ap' via 'hfoldMap'.
htraverse_ :: (HTable t, Apply f)
  => (forall a. context a -> f b) -> t context -> f ()
htraverse_ f = getAp . hfoldMap (Ap . void . f)


-- | Effectful variant of 'htabulate': builds a table of 'Compose'd actions,
-- one per field, then sequences them with 'htraverse'.
htabulateA :: (HTable t, Apply m)
  => (forall a. HField t a -> m (context a)) -> m (t context)
htabulateA f = htraverse getCompose (htabulate (Compose . f))
Expand Down
17 changes: 15 additions & 2 deletions src/Rel8/Table/Opaleye.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ module Rel8.Table.Opaleye
, valuesspec
, view
, castTable
, fromOrder
)
where

-- base
import Data.Foldable (traverse_)
import Data.Functor.Const ( Const( Const ), getConst )
import Data.List.NonEmpty ( NonEmpty )
import Prelude
Expand All @@ -36,6 +38,9 @@ import qualified Opaleye.Adaptors as Opaleye
import qualified Opaleye.Field as Opaleye ( Field_ )
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Operators as Opaleye
import qualified Opaleye.Internal.Order as Opaleye
import qualified Opaleye.Internal.PackMap as Opaleye
import qualified Opaleye.Internal.Unpackspec as Opaleye
import qualified Opaleye.Internal.Values as Opaleye
import qualified Opaleye.Table as Opaleye

Expand All @@ -48,8 +53,10 @@ import Rel8.Expr.Opaleye
( fromPrimExpr, toPrimExpr
, scastExpr, traverseFieldP
)
import Rel8.Schema.HTable ( htabulateA, hfield, hspecs, htabulate,
htraverseP, htraversePWithField )
import Rel8.Schema.HTable
( htabulateA, hfield, hspecs, htabulate
, htraverseP, htraversePWithField
)
import Rel8.Schema.Name ( Name( Name ), Selects, ppColumn )
import Rel8.Schema.QualifiedName (QualifiedName (QualifiedName))
import Rel8.Schema.Spec ( Spec(..) )
Expand Down Expand Up @@ -153,3 +160,9 @@ castTable (toColumns -> as) = fromColumns $ htabulate \field ->
case hfield hspecs field of
Spec {info} -> case hfield as field of
expr -> scastExpr info expr


-- | Turn an Opaleye 'Opaleye.Order' into an 'Opaleye.Unpackspec'.
--
-- The resulting spec visits the second component of every entry that the
-- ordering produces for the input and returns the input unchanged.
fromOrder :: Opaleye.Order a -> Opaleye.Unpackspec a a
fromOrder (Opaleye.Order o) =
  Opaleye.Unpackspec
    ( Opaleye.PackMap
        -- Kept as an inline lambda: 'Opaleye.PackMap' needs a function that
        -- is polymorphic in the applicative.
        (\visit row -> row <$ traverse_ (visit . snd) (o row))
    )
Loading

0 comments on commit fb96ffc

Please sign in to comment.