Skip to content

Commit

Permalink
Add ordered set aggregation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
shane-circuithub committed Oct 9, 2023
1 parent d0ba116 commit c061533
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 21 deletions.
9 changes: 8 additions & 1 deletion src/Rel8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ module Rel8
, groupBy, groupByOn
, listAgg, listAggOn, listAggExpr, listAggExprOn
, listCat, listCatOn, listCatExpr, listCatExprOn
, mode
, nonEmptyAgg, nonEmptyAggOn, nonEmptyAggExpr, nonEmptyAggExprOn
, nonEmptyCat, nonEmptyCatOn, nonEmptyCatExpr, nonEmptyCatExprOn
, DBMax, max, maxOn
Expand All @@ -295,6 +294,14 @@ module Rel8
, and, andOn
, or, orOn

, mode, modeOn
, percentile, percentileOn
, percentileContinuous, percentileContinuousOn
, hypotheticalRank
, hypotheticalDenseRank
, hypotheticalPercentRank
, hypotheticalCumeDist

-- ** Ordering
, orderBy
, Order
Expand Down
145 changes: 142 additions & 3 deletions src/Rel8/Expr/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ module Rel8.Expr.Aggregate
, sum, sumOn, sumWhere
, avg, avgOn
, stringAgg, stringAggOn
, mode, modeOn
, percentile, percentileOn
, percentileContinuous, percentileContinuousOn
, hypotheticalRank
, hypotheticalDenseRank
, hypotheticalPercentRank
, hypotheticalCumeDist
, groupByExpr, groupByExprOn
, distinctAggregate
, filterWhereExplicit
Expand All @@ -28,6 +35,7 @@ module Rel8.Expr.Aggregate
where

-- base
import Data.Functor.Contravariant ((>$<))
import Data.Int ( Int64 )
import Data.List.NonEmpty ( NonEmpty )
import Data.String (IsString)
Expand All @@ -36,6 +44,7 @@ import Prelude hiding (and, max, min, null, or, show, sum)
-- opaleye
import qualified Opaleye.Aggregate as Opaleye
import qualified Opaleye.Internal.Aggregate as Opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Operators as Opaleye

-- profunctors
Expand All @@ -59,17 +68,22 @@ import Rel8.Expr.Opaleye
, fromPrimExpr
, toColumn
, toPrimExpr
, unsafeCastExpr
)
import Rel8.Expr.Order (asc)
import Rel8.Expr.Read (sread)
import Rel8.Expr.Show (show)
import qualified Rel8.Expr.Text as Text
import Rel8.Order (Order (Order))
import Rel8.Schema.Null ( Sql, Unnullify )
import Rel8.Table.Opaleye (fromOrder, unpackspec)
import Rel8.Table.Order (ascTable)
import Rel8.Type ( DBType, typeInformation )
import Rel8.Type.Array (arrayTypeName, encodeArrayElement)
import Rel8.Type.Eq ( DBEq )
import Rel8.Type.Information (TypeInformation)
import Rel8.Type.Num ( DBNum )
import Rel8.Type.Ord ( DBMax, DBMin )
import Rel8.Type.Num (DBFractional, DBNum)
import Rel8.Type.Ord (DBMax, DBMin, DBOrd)
import Rel8.Type.String ( DBString )
import Rel8.Type.Sum ( DBSum )

Expand Down Expand Up @@ -239,6 +253,131 @@ stringAggOn :: (Sql IsString a, Sql DBString a)
stringAggOn delimiter f = lmap f (stringAgg delimiter)


-- | Corresponds to @mode() WITHIN GROUP (ORDER BY _)@.
mode :: Sql DBOrd a => Aggregator1 (Expr a) (Expr a)
mode =
unsafeMakeAggregator
id
(fromPrimExpr . fromColumn)
Empty
(Opaleye.withinGroup ((\(Order o) -> o) ascTable)
(lmap mempty
(Opaleye.makeAggrExplicit (pure ()) (Opaleye.AggrOther "mode"))))


-- | Applies 'mode' to the column selected by the given function.
modeOn :: Sql DBOrd a => (i -> Expr a) -> Aggregator1 i (Expr a)
modeOn f = lmap f mode


-- | Corresponds to @percentile_disc(_) WITHIN GROUP (ORDER BY _)@.
percentile :: Sql DBOrd a => Expr Double -> Aggregator1 (Expr a) (Expr a)
percentile fraction =
unsafeMakeAggregator
(const fraction)
(castExpr . fromPrimExpr . fromColumn)
Empty
(Opaleye.withinGroup ((\(Order o) -> o) ascTable)
(Opaleye.makeAggrExplicit
unpackspec
(Opaleye.AggrOther "percentile_disc")))


-- | Applies 'percentile' to the column selected by the given function.
percentileOn ::
Sql DBOrd a =>
Expr Double ->
(i -> Expr a) ->
Aggregator1 i (Expr a)
percentileOn fraction f = lmap f (percentile fraction)


-- | Corresponds to @percentile_cont(_) WITHIN GROUP (ORDER BY _)@.
percentileContinuous ::
Sql DBFractional a =>
Expr Double ->
Aggregator1 (Expr a) (Expr a)
percentileContinuous fraction =
unsafeMakeAggregator
(const fraction)
(castExpr . fromPrimExpr . fromColumn)
Empty
(Opaleye.withinGroup ((\(Order o) -> o) (unsafeCastExpr @Double >$< asc))
(Opaleye.makeAggrExplicit
unpackspec
(Opaleye.AggrOther "percentile_cont")))


-- | Applies 'percentileContinuous' to the column selected by the given
-- function.
percentileContinuousOn ::
Sql DBFractional a =>
Expr Double ->
(i -> Expr a) ->
Aggregator1 i (Expr a)
percentileContinuousOn fraction f = lmap f (percentileContinuous fraction)


-- | Corresponds to @rank(_) WITHIN GROUP (ORDER BY _)@.
hypotheticalRank ::
Order a ->
a ->
Aggregator' fold a (Expr Int64)
hypotheticalRank (Order order) args =
unsafeMakeAggregator
(const args)
(castExpr . fromPrimExpr . fromColumn)
(Fallback 1)
(Opaleye.withinGroup order
(Opaleye.makeAggrExplicit (fromOrder order)
(Opaleye.AggrOther "rank")))


-- | Corresponds to @dense_rank(_) WITHIN GROUP (ORDER BY _)@.
hypotheticalDenseRank ::
Order a ->
a ->
Aggregator' fold a (Expr Int64)
hypotheticalDenseRank (Order order) args =
unsafeMakeAggregator
(const args)
(castExpr . fromPrimExpr . fromColumn)
(Fallback 1)
(Opaleye.withinGroup order
(Opaleye.makeAggrExplicit (fromOrder order)
(Opaleye.AggrOther "dense_rank")))


-- | Corresponds to @percent_rank(_) WITHIN GROUP (ORDER BY _)@.
hypotheticalPercentRank ::
Order a ->
a ->
Aggregator' fold a (Expr Double)
hypotheticalPercentRank (Order order) args =
unsafeMakeAggregator
(const args)
(castExpr . fromPrimExpr . fromColumn)
(Fallback 0)
(Opaleye.withinGroup order
(Opaleye.makeAggrExplicit (fromOrder order)
(Opaleye.AggrOther "percent_rank")))


-- | Corresponds to @cume_dist(_) WITHIN GROUP (ORDER BY _)@.
hypotheticalCumeDist ::
Order a ->
a ->
Aggregator' fold a (Expr Double)
hypotheticalCumeDist (Order order) args =
unsafeMakeAggregator
(const args)
(castExpr . fromPrimExpr . fromColumn)
(Fallback 1)
(Opaleye.withinGroup order
(Opaleye.makeAggrExplicit (fromOrder order)
(Opaleye.AggrOther "cume_dist")))


-- | Aggregate a value by grouping by it.
groupByExpr :: Sql DBEq a => Aggregator1 (Expr a) (Expr a)
groupByExpr =
Expand All @@ -249,7 +388,7 @@ groupByExpr =
Opaleye.groupBy


-- | Applies 'groupByExprOn' to the column selected by the given function.
-- | Applies 'groupByExpr' to the column selected by the given function.
groupByExprOn :: Sql DBEq a => (i -> Expr a) -> Aggregator1 i (Expr a)
groupByExprOn f = lmap f groupByExpr

Expand Down
2 changes: 1 addition & 1 deletion src/Rel8/Expr/Num.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fromIntegral :: (Sql DBIntegral a, Sql DBNum b, Homonullable a b)
fromIntegral (Expr a) = castExpr (Expr a)


-- | Cast 'DBNum' types to 'DBFractional' types. For example, his can be useful
-- | Cast 'DBNum' types to 'DBFractional' types. For example, this can be useful
-- to convert @Expr Float@ to @Expr Double@.
realToFrac :: (Sql DBNum a, Sql DBFractional b, Homonullable a b)
=> Expr a -> Expr b
Expand Down
15 changes: 0 additions & 15 deletions src/Rel8/Query/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ module Rel8.Query.Aggregate
( aggregate
, aggregate1
, countRows
, mode
)
where

-- base
import Control.Applicative (liftA2)
import Data.Functor.Contravariant ( (>$<) )
import Data.Int ( Int64 )
import Prelude

Expand All @@ -24,15 +22,10 @@ import Rel8.Aggregate (Aggregator' (Aggregator), Aggregator)
import Rel8.Aggregate.Fold (Fallback (Fallback))
import Rel8.Expr ( Expr )
import Rel8.Expr.Aggregate ( countStar )
import Rel8.Expr.Order ( desc )
import Rel8.Query ( Query )
import Rel8.Query.Limit ( limit )
import Rel8.Query.Maybe ( optional )
import Rel8.Query.Opaleye ( mapOpaleye )
import Rel8.Query.Order ( orderBy )
import Rel8.Table (Table)
import Rel8.Table.Aggregate (groupBy)
import Rel8.Table.Eq (EqTable)
import Rel8.Table.Maybe (fromMaybeTable)


Expand All @@ -55,11 +48,3 @@ aggregate1 (Aggregator _ aggregator) = mapOpaleye (Opaleye.aggregate aggregator)
-- will return @0@.
countRows :: Query a -> Query (Expr Int64)
countRows = aggregate countStar


-- | Return the most common row in a query.
mode :: forall a. EqTable a => Query a -> Query a
mode rows =
limit 1 $ fmap snd $
orderBy (fst >$< desc) $ do
aggregate1 (liftA2 (,) countStar groupBy) rows
11 changes: 11 additions & 0 deletions src/Rel8/Table/Opaleye.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ module Rel8.Table.Opaleye
, valuesspec
, view
, castTable
, fromOrder
)
where

-- base
import Data.Foldable (traverse_)
import Data.Functor.Const ( Const( Const ), getConst )
import Data.List.NonEmpty ( NonEmpty )
import Prelude
Expand All @@ -36,6 +38,9 @@ import qualified Opaleye.Adaptors as Opaleye
import qualified Opaleye.Field as Opaleye ( Field_ )
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Operators as Opaleye
import qualified Opaleye.Internal.Order as Opaleye
import qualified Opaleye.Internal.PackMap as Opaleye
import qualified Opaleye.Internal.Unpackspec as Opaleye
import qualified Opaleye.Internal.Values as Opaleye
import qualified Opaleye.Table as Opaleye

Expand Down Expand Up @@ -153,3 +158,9 @@ castTable (toColumns -> as) = fromColumns $ htabulate \field ->
case hfield hspecs field of
Spec {info} -> case hfield as field of
expr -> scastExpr info expr


fromOrder :: Opaleye.Order a -> Opaleye.Unpackspec a a
fromOrder (Opaleye.Order o) =
Opaleye.Unpackspec $ Opaleye.PackMap $ \f a ->
a <$ traverse_ (f . snd) (o a)
5 changes: 4 additions & 1 deletion src/Rel8/Type/Eq.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ where
import Data.Aeson ( Value )

-- base
import Data.List.NonEmpty ( NonEmpty )
import Data.Fixed (Fixed)
import Data.Int ( Int16, Int32, Int64 )
import Data.Kind ( Constraint, Type )
import Data.List.NonEmpty ( NonEmpty )
import Prelude

-- bytestring
Expand All @@ -29,6 +30,7 @@ import Data.CaseInsensitive ( CI )
-- rel8
import Rel8.Schema.Null ( Sql )
import Rel8.Type ( DBType )
import Rel8.Type.Decimal (PowerOf10)

-- scientific
import Data.Scientific ( Scientific )
Expand Down Expand Up @@ -58,6 +60,7 @@ instance DBEq Char
instance DBEq Int16
instance DBEq Int32
instance DBEq Int64
instance PowerOf10 n => DBEq (Fixed n)
instance DBEq Float
instance DBEq Double
instance DBEq Scientific
Expand Down
4 changes: 4 additions & 0 deletions src/Rel8/Type/Num.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ module Rel8.Type.Num
where

-- base
import Data.Fixed (Fixed)
import Data.Int ( Int16, Int32, Int64 )
import Data.Kind ( Constraint, Type )
import Prelude

-- rel8
import Rel8.Type ( DBType )
import Rel8.Type.Decimal (PowerOf10)
import Rel8.Type.Ord ( DBOrd )

-- scientific
Expand All @@ -31,6 +33,7 @@ class DBType a => DBNum a
instance DBNum Int16
instance DBNum Int32
instance DBNum Int64
instance PowerOf10 n => DBNum (Fixed n)
instance DBNum Float
instance DBNum Double
instance DBNum Scientific
Expand All @@ -49,6 +52,7 @@ instance DBIntegral Int64
-- | The class of database types that support the @/@ operator.
type DBFractional :: Type -> Constraint
class DBNum a => DBFractional a
instance PowerOf10 n => DBFractional (Fixed n)
instance DBFractional Float
instance DBFractional Double
instance DBFractional Scientific
Expand Down
Loading

0 comments on commit c061533

Please sign in to comment.