From 0ebb70f95f9a1bec939bd53d36448809fda79bc6 Mon Sep 17 00:00:00 2001 From: Shane Date: Wed, 27 Sep 2023 11:41:27 +0100 Subject: [PATCH] Add array concenation aggregators (#270) --- .../20230826_034632_shane.obrien_array_cat.md | 40 +++++++++++++ rel8.cabal | 2 + src/Rel8.hs | 2 + src/Rel8/Expr/Aggregate.hs | 58 ++++++++++++++++++- src/Rel8/Expr/Opaleye.hs | 16 ++--- src/Rel8/Expr/Read.hs | 29 ++++++++++ src/Rel8/Expr/Show.hs | 18 ++++++ src/Rel8/Schema/HTable.hs | 6 ++ src/Rel8/Schema/HTable/Vectorize.hs | 17 ++++++ src/Rel8/Table/Aggregate.hs | 36 +++++++++++- src/Rel8/Type/Array.hs | 9 ++- tests/Main.hs | 19 +++++- 12 files changed, 236 insertions(+), 16 deletions(-) create mode 100644 changelog.d/20230826_034632_shane.obrien_array_cat.md create mode 100644 src/Rel8/Expr/Read.hs create mode 100644 src/Rel8/Expr/Show.hs diff --git a/changelog.d/20230826_034632_shane.obrien_array_cat.md b/changelog.d/20230826_034632_shane.obrien_array_cat.md new file mode 100644 index 00000000..6c975b25 --- /dev/null +++ b/changelog.d/20230826_034632_shane.obrien_array_cat.md @@ -0,0 +1,40 @@ + + + +### Added + +- Added aggregators `listCat` and `nonEmptyCat` for folding a collection of lists into a single list by concatenation. + + + + + diff --git a/rel8.cabal b/rel8.cabal index 54d24e38..07e9bc5b 100644 --- a/rel8.cabal +++ b/rel8.cabal @@ -92,8 +92,10 @@ library Rel8.Expr.Opaleye Rel8.Expr.Ord Rel8.Expr.Order + Rel8.Expr.Read Rel8.Expr.Sequence Rel8.Expr.Serialize + Rel8.Expr.Show Rel8.Expr.Window Rel8.FCF diff --git a/src/Rel8.hs b/src/Rel8.hs index 74573e64..7c13e966 100644 --- a/src/Rel8.hs +++ b/src/Rel8.hs @@ -279,8 +279,10 @@ module Rel8 , countRows , groupBy, groupByOn , listAgg, listAggOn, listAggExpr, listAggExprOn + , listCat, listCatOn, listCatExpr, listCatExprOn , mode , nonEmptyAgg, nonEmptyAggOn, nonEmptyAggExpr, nonEmptyAggExprOn + , nonEmptyCat, nonEmptyCatOn, nonEmptyCatExpr, nonEmptyCatExprOn , DBMax, max, maxOn , DBMin, min, minOn , DBSum, sum, sumOn, sumWhere, avg, avgOn diff --git a/src/Rel8/Expr/Aggregate.hs b/src/Rel8/Expr/Aggregate.hs index 5b87fdb7..d3c74488 100644 --- a/src/Rel8/Expr/Aggregate.hs +++ b/src/Rel8/Expr/Aggregate.hs @@ -1,5 +1,7 @@ {-# language DataKinds #-} +{-# language DisambiguateRecordFields #-} {-# language FlexibleContexts #-} +{-# language NamedFieldPuns #-} {-# language OverloadedStrings #-} {-# language ScopedTypeVariables #-} {-# language TypeFamilies #-} @@ -19,7 +21,9 @@ module Rel8.Expr.Aggregate , distinctAggregate , filterWhereExplicit , listAggExpr, listAggExprOn, nonEmptyAggExpr, nonEmptyAggExprOn + , listCatExpr, listCatExprOn, nonEmptyCatExpr, nonEmptyCatExprOn , slistAggExpr, snonEmptyAggExpr + , slistCatExpr, snonEmptyCatExpr ) where @@ -27,7 +31,7 @@ where import Data.Int ( Int64 ) import Data.List.NonEmpty ( NonEmpty ) import Data.String (IsString) -import Prelude hiding ( and, max, min, null, or, sum ) +import Prelude hiding (and, max, min, null, or, show, sum) -- opaleye import qualified Opaleye.Aggregate as Opaleye @@ -48,6 +52,7 @@ import Rel8.Aggregate.Fold (Fallback (Empty, Fallback)) import Rel8.Expr ( Expr ) import Rel8.Expr.Array (sempty) import Rel8.Expr.Bool (false, true) +import Rel8.Expr.Eq ((/=.)) import Rel8.Expr.Opaleye ( castExpr , fromColumn @@ -55,11 +60,14 @@ import Rel8.Expr.Opaleye , toColumn , toPrimExpr ) +import Rel8.Expr.Read (sread) +import Rel8.Expr.Show (show) +import qualified Rel8.Expr.Text as Text import Rel8.Schema.Null ( Sql, Unnullify ) import Rel8.Type ( DBType, typeInformation ) -import Rel8.Type.Array ( encodeArrayElement ) +import Rel8.Type.Array (arrayTypeName, encodeArrayElement) import Rel8.Type.Eq ( DBEq ) -import Rel8.Type.Information ( TypeInformation ) +import Rel8.Type.Information (TypeInformation) import Rel8.Type.Num ( DBNum ) import Rel8.Type.Ord ( DBMax, DBMin ) import Rel8.Type.String ( DBString ) @@ -267,6 +275,29 @@ nonEmptyAggExprOn :: Sql DBType a nonEmptyAggExprOn f = lmap f nonEmptyAggExpr +-- | Concatenate lists into a single list. +listCatExpr :: Sql DBType a => Aggregator' fold (Expr [a]) (Expr [a]) +listCatExpr = slistCatExpr typeInformation + + +-- | Applies 'listCatExpr' to the column selected by the given function. +listCatExprOn :: Sql DBType a + => (i -> Expr [a]) -> Aggregator' fold i (Expr [a]) +listCatExprOn f = lmap f listCatExpr + + +-- | Concatenate non-empty lists into a single non-empty list. +nonEmptyCatExpr :: Sql DBType a + => Aggregator1 (Expr (NonEmpty a)) (Expr (NonEmpty a)) +nonEmptyCatExpr = snonEmptyCatExpr typeInformation + + +-- | Applies 'nonEmptyCatExpr' to the column selected by the given function. +nonEmptyCatExprOn :: Sql DBType a + => (i -> Expr (NonEmpty a)) -> Aggregator1 i (Expr (NonEmpty a)) +nonEmptyCatExprOn f = lmap f nonEmptyCatExpr + + -- | 'distinctAggregate' modifies an 'Aggregator' to consider only distinct -- values of a particular column. distinctAggregate :: Sql DBEq a @@ -295,6 +326,27 @@ snonEmptyAggExpr info = Opaleye.arrayAgg +slistCatExpr :: () + => TypeInformation (Unnullify a) -> Aggregator' fold (Expr [a]) (Expr [a]) +slistCatExpr info = dimap (unbracket . show) (sread name . bracket) agg + where + bracket a = "{" <> a <> "}" + unbracket a = Text.substr a 2 (Just (Text.length a - 2)) + agg = filterWhereExplicit ifPP (/=. "") (stringAgg ",") + name = arrayTypeName info + + +snonEmptyCatExpr :: () + => TypeInformation (Unnullify a) + -> Aggregator1 (Expr (NonEmpty a)) (Expr (NonEmpty a)) +snonEmptyCatExpr info = dimap (unbracket . show) (sread name . bracket) agg + where + bracket a = "{" <> a <> "}" + unbracket a = Text.substr a 2 (Just (Text.length a - 2)) + agg = filterWhereExplicit ifPP (/=. "") (stringAgg ",") + name = arrayTypeName info + + ifPP :: Opaleye.IfPP (Expr a) (Expr a) ifPP = dimap from to Opaleye.ifPPField where diff --git a/src/Rel8/Expr/Opaleye.hs b/src/Rel8/Expr/Opaleye.hs index afd17823..30e7c89a 100644 --- a/src/Rel8/Expr/Opaleye.hs +++ b/src/Rel8/Expr/Opaleye.hs @@ -1,6 +1,7 @@ {-# language FlexibleContexts #-} {-# language NamedFieldPuns #-} {-# language ScopedTypeVariables #-} +{-# language TypeApplications #-} {-# language TypeFamilies #-} {-# options_ghc -fno-warn-redundant-constraints #-} @@ -26,7 +27,7 @@ import {-# SOURCE #-} Rel8.Expr ( Expr( Expr ) ) import Rel8.Schema.Null ( Unnullify, Sql ) import Rel8.Type ( DBType, typeInformation ) import Rel8.Type.Information ( TypeInformation(..) ) -import Rel8.Type.Name (showTypeName) +import Rel8.Type.Name (TypeName, showTypeName) -- profunctors import Data.Profunctor ( Profunctor, dimap ) @@ -38,18 +39,19 @@ castExpr = scastExpr typeInformation -- | Cast an expression to a different type. Corresponds to a @CAST()@ function -- call. -unsafeCastExpr :: Sql DBType b => Expr a -> Expr b -unsafeCastExpr = sunsafeCastExpr typeInformation +unsafeCastExpr :: forall b a. Sql DBType b => Expr a -> Expr b +unsafeCastExpr = case typeInformation @(Unnullify b) of + TypeInformation {typeName} -> sunsafeCastExpr typeName scastExpr :: TypeInformation (Unnullify a) -> Expr a -> Expr a -scastExpr = sunsafeCastExpr +scastExpr TypeInformation {typeName} = sunsafeCastExpr typeName sunsafeCastExpr :: () - => TypeInformation (Unnullify b) -> Expr a -> Expr b -sunsafeCastExpr TypeInformation {typeName} = - fromPrimExpr . Opaleye.CastExpr (showTypeName typeName) . toPrimExpr + => TypeName -> Expr a -> Expr b +sunsafeCastExpr name = + fromPrimExpr . Opaleye.CastExpr (showTypeName name) . toPrimExpr -- | Unsafely construct an expression from literal SQL. diff --git a/src/Rel8/Expr/Read.hs b/src/Rel8/Expr/Read.hs new file mode 100644 index 00000000..7dea9511 --- /dev/null +++ b/src/Rel8/Expr/Read.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE MonoLocalBinds #-} + +module Rel8.Expr.Read + ( read + , sread + ) +where + +-- base +import Prelude () + +-- rel8 +import Rel8.Expr (Expr) +import Rel8.Expr.Opaleye (unsafeCastExpr, sunsafeCastExpr) +import Rel8.Schema.Null (Sql) +import Rel8.Type (DBType) +import Rel8.Type.Name (TypeName) + +-- text +import Data.Text (Text) + + +read :: Sql DBType a => Expr Text -> Expr a +read = unsafeCastExpr + + +sread :: TypeName -> Expr Text -> Expr a +sread = sunsafeCastExpr \ No newline at end of file diff --git a/src/Rel8/Expr/Show.hs b/src/Rel8/Expr/Show.hs new file mode 100644 index 00000000..abb385f1 --- /dev/null +++ b/src/Rel8/Expr/Show.hs @@ -0,0 +1,18 @@ +module Rel8.Expr.Show + ( show + ) +where + +-- base +import Prelude () + +-- rel8 +import Rel8.Expr (Expr) +import Rel8.Expr.Opaleye (unsafeCastExpr) + +-- text +import Data.Text (Text) + + +show :: Expr a -> Expr Text +show = unsafeCastExpr \ No newline at end of file diff --git a/src/Rel8/Schema/HTable.hs b/src/Rel8/Schema/HTable.hs index 62e8a76a..36e9fb22 100644 --- a/src/Rel8/Schema/HTable.hs +++ b/src/Rel8/Schema/HTable.hs @@ -135,23 +135,29 @@ htabulateA :: (HTable t, Apply m) htabulateA f = htraverse getCompose $ htabulate $ Compose . f {-# INLINABLE htabulateA #-} + newtype ApplyP p a b = ApplyP { unApplyP :: p a b } + instance Profunctor p => Functor (ApplyP p a) where fmap f = ApplyP . rmap f . unApplyP + instance ProductProfunctor p => Apply (ApplyP p a) where ApplyP f <.> ApplyP x = ApplyP (rmap id f **** x) + htraverseP :: (HTable t, ProductProfunctor p) => (forall a. p (f a) (g a)) -> p (t f) (t g) htraverseP f = htraversePWithField (const f) + htraversePWithField :: (HTable t, ProductProfunctor p) => (forall a. HField t a -> p (f a) (g a)) -> p (t f) (t g) htraversePWithField f = unApplyP $ htabulateA $ \field -> ApplyP $ lmap (flip hfield field) (f field) + type GHField :: K.HTable -> Type -> Type newtype GHField t a = GHField (HField (GHColumns (Rep (t Proxy))) a) diff --git a/src/Rel8/Schema/HTable/Vectorize.hs b/src/Rel8/Schema/HTable/Vectorize.hs index ea58771a..a0fc7ee3 100644 --- a/src/Rel8/Schema/HTable/Vectorize.hs +++ b/src/Rel8/Schema/HTable/Vectorize.hs @@ -25,6 +25,7 @@ module Rel8.Schema.HTable.Vectorize , hnullify , happend, hempty , hproject + , htraverseVectorP , hcolumn , First (..) ) @@ -37,12 +38,19 @@ import qualified Data.Semigroup as Base import GHC.Generics (Generic) import Prelude +-- product-profunctors +import Data.Profunctor.Product (ProductProfunctor) + +-- profunctors +import Data.Profunctor (dimap) + -- rel8 import Rel8.FCF ( Eval, Exp ) import Rel8.Schema.Dict ( Dict( Dict ) ) import qualified Rel8.Schema.Kind as K import Rel8.Schema.HTable ( HField, HTable, hfield, htabulate, htabulateA, hspecs + , htraversePWithField ) import Rel8.Schema.HTable.Identity ( HIdentity( HIdentity ) ) import Rel8.Schema.HTable.MapTable @@ -161,6 +169,15 @@ hproject :: () hproject f (HVectorize a) = HVectorize (HMapTable.hproject f a) +htraverseVectorP :: (HTable t, ProductProfunctor p) + => (forall a. HField t a -> p (f (list a)) (g (list' a))) + -> p (HVectorize list t f) (HVectorize list' t g) +htraverseVectorP f = + dimap (\(HVectorize (HMapTable a)) -> a) (HVectorize . HMapTable) $ + htraversePWithField $ \field -> + dimap (\(Precompose a) -> a) Precompose (f field) + + hcolumn :: HVectorize list (HIdentity a) context -> context (list a) hcolumn (HVectorize (HMapTable (HIdentity (Precompose a)))) = a diff --git a/src/Rel8/Table/Aggregate.hs b/src/Rel8/Table/Aggregate.hs index 50185550..8437d69a 100644 --- a/src/Rel8/Table/Aggregate.hs +++ b/src/Rel8/Table/Aggregate.hs @@ -8,6 +8,7 @@ module Rel8.Table.Aggregate ( groupBy, groupByOn , listAgg, listAggOn, nonEmptyAgg, nonEmptyAggOn + , listCat, listCatOn, nonEmptyCat, nonEmptyCatOn , filterWhere, filterWhereOptional , orderAggregateBy , optionalAggregate @@ -34,13 +35,15 @@ import Rel8.Expr.Aggregate ( filterWhereExplicit , groupByExprOn , slistAggExpr + , slistCatExpr , snonEmptyAggExpr + , snonEmptyCatExpr ) import Rel8.Expr.Opaleye (toColumn, toPrimExpr) import Rel8.Order (Order (Order)) import Rel8.Schema.Dict ( Dict( Dict ) ) -import Rel8.Schema.HTable (HTable, hfield, htabulateA) -import Rel8.Schema.HTable.Vectorize (hvectorizeA) +import Rel8.Schema.HTable (HTable, hfield, hspecs, htabulateA) +import Rel8.Schema.HTable.Vectorize (htraverseVectorP, hvectorizeA) import Rel8.Schema.Null ( Sql ) import Rel8.Schema.Spec ( Spec( Spec, info ) ) import Rel8.Table (Table, toColumns, fromColumns) @@ -146,6 +149,35 @@ nonEmptyAggOn :: Table Expr a nonEmptyAggOn f = lmap f nonEmptyAgg +-- | Concatenate lists into a single list. +listCat :: Table Expr a + => Aggregator' fold (ListTable Expr a) (ListTable Expr a) +listCat = dimap toColumns fromColumns $ + htraverseVectorP (\field -> case hfield hspecs field of + Spec {info} -> slistCatExpr info) + + +-- | Applies 'listCat' to the list selected by the given function. +listCatOn :: Table Expr a + => (i -> ListTable Expr a) -> Aggregator' fold i (ListTable Expr a) +listCatOn f = lmap f listCat + + +-- | Concatenate non-empty lists into a single non-empty list. +nonEmptyCat :: Table Expr a + => Aggregator1 (NonEmptyTable Expr a) (NonEmptyTable Expr a) +nonEmptyCat = dimap toColumns fromColumns $ + htraverseVectorP (\field -> case hfield hspecs field of + Spec {info} -> snonEmptyCatExpr info) + + +-- | Applies 'nonEmptyCat' to the non-empty list selected by the given +-- function. +nonEmptyCatOn :: Table Expr a + => (i -> NonEmptyTable Expr a) -> Aggregator1 i (NonEmptyTable Expr a) +nonEmptyCatOn f = lmap f nonEmptyCat + + -- | Order the values within each aggregation in an `Aggregator` using the -- given ordering. This is only relevant for aggregations that depend on the -- order they get their elements, like `Rel8.listAgg` and `Rel8.stringAgg`. diff --git a/src/Rel8/Type/Array.hs b/src/Rel8/Type/Array.hs index da6bcf0c..304a2984 100644 --- a/src/Rel8/Type/Array.hs +++ b/src/Rel8/Type/Array.hs @@ -7,6 +7,7 @@ module Rel8.Type.Array ( array, encodeArrayElement, extractArrayElement + , arrayTypeName , listTypeInformation , nonEmptyTypeInformation , head, last, length @@ -75,7 +76,7 @@ listTypeInformation nullity info@TypeInformation {encode, decode} = NotNull -> Opaleye.ArrayExpr . fmap (encodeArrayElement info . encode) - , typeName = (arrayType info) {arrayDepth = 1} + , typeName = arrayTypeName info } where null = Opaleye.ConstExpr Opaleye.NullLit @@ -92,6 +93,10 @@ nonEmptyTypeInformation nullity = message = "failed to decode NonEmptyList: got empty list" +arrayTypeName :: TypeInformation a -> TypeName +arrayTypeName info = (arrayType info) {arrayDepth = 1} + + isArray :: TypeInformation a -> Bool isArray = (> 0) . arrayDepth . typeName @@ -111,7 +116,7 @@ decodeArrayElement info encodeArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr encodeArrayElement info - | isArray info = Opaleye.CastExpr "text" + | isArray info = Opaleye.CastExpr "text" . Opaleye.CastExpr (showTypeName (typeName info)) | otherwise = id diff --git a/tests/Main.hs b/tests/Main.hs index d37faafd..1dc6eafa 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -466,6 +466,8 @@ testDBType getTestDatabase = testGroup "DBType instances" t generator transaction = do x <- forAll generator y <- forAll generator + xss <- forAll $ Gen.list (Range.linear 0 10) (Gen.list (Range.linear 0 10) generator) + xsss <- forAll $ Gen.list (Range.linear 0 10) (Gen.list (Range.linear 0 10) (Gen.list (Range.linear 0 10) generator)) transaction do res <- lift do @@ -487,10 +489,23 @@ testDBType getTestDatabase = testGroup "DBType instances" diff res'' (==) [x, y] res''' <- lift do statement () $ Rel8.run $ Rel8.select do - xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) - xs <- Rel8.catListTable xss + xss' <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) + xs <- Rel8.catListTable xss' Rel8.catListTable xs diff res''' (==) [x, y] + res'''' <- lift do + statement () $ Rel8.run1 $ Rel8.select $ + Rel8.aggregate Rel8.listCatExpr $ + Rel8.values $ map Rel8.litExpr xss + diff res'''' (==) (concat xss) + res''''' <- lift do + statement () $ Rel8.run1 $ Rel8.select $ + Rel8.aggregate Rel8.listCatExpr $ + Rel8.values $ map Rel8.litExpr xsss + diff res''''' (==) (concat xsss) + + + genComposite :: Gen Composite genComposite = do