diff --git a/src/Rel8/Query/List.hs b/src/Rel8/Query/List.hs index 8a76cdf6..55800df0 100644 --- a/src/Rel8/Query/List.hs +++ b/src/Rel8/Query/List.hs @@ -11,6 +11,7 @@ module Rel8.Query.List where -- base +import Control.Monad ((>=>)) import Data.Functor.Identity ( runIdentity ) import Data.List.NonEmpty ( NonEmpty ) import Prelude @@ -25,10 +26,11 @@ import Rel8.Expr.Opaleye ( mapPrimExpr ) import Rel8.Query ( Query ) import Rel8.Query.Aggregate (aggregate, aggregate1) import Rel8.Query.Rebind (hrebind, rebind) +import Rel8.Schema.HTable (hfield, hspecs, htabulate) import Rel8.Schema.HTable.Vectorize ( hunvectorize ) import Rel8.Schema.Null ( Sql, Unnullify ) import Rel8.Schema.Spec ( Spec( Spec, info ) ) -import Rel8.Table ( Table, fromColumns ) +import Rel8.Table (Table, fromColumns, toColumns) import Rel8.Table.Aggregate ( listAgg, nonEmptyAgg ) import Rel8.Table.List ( ListTable( ListTable ) ) import Rel8.Table.NonEmpty ( NonEmptyTable( NonEmptyTable ) ) @@ -76,8 +78,8 @@ someExpr = aggregate1 nonEmptyAggExpr -- @catListTable@ is an inverse to 'many'. catListTable :: Table Expr a => ListTable Expr a -> Query a catListTable (ListTable as) = - fmap fromColumns $ hrebind "unnest" $ runIdentity $ - hunvectorize (\Spec {info} -> pure . sunnest info) as + (>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $ + hunvectorize (\_ -> pure . unnest) as -- | Expand a 'NonEmptyTable' into a 'Query', where each row in the query is an @@ -86,8 +88,8 @@ catListTable (ListTable as) = -- @catNonEmptyTable@ is an inverse to 'some'. catNonEmptyTable :: Table Expr a => NonEmptyTable Expr a -> Query a catNonEmptyTable (NonEmptyTable as) = - fmap fromColumns $ hrebind "unnest" $ runIdentity $ - hunvectorize (\Spec {info} -> pure . sunnest info) as + (>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $ + hunvectorize (\_ -> pure . unnest) as -- | Expand an expression that contains a list into a 'Query', where each row @@ -95,7 +97,7 @@ catNonEmptyTable (NonEmptyTable as) = -- -- @catList@ is an inverse to 'manyExpr'. catList :: Sql DBType a => Expr [a] -> Query (Expr a) -catList = rebind "unnest" . sunnest typeInformation +catList = rebind "unnest" . unnest >=> extract -- | Expand an expression that contains a non-empty list into a 'Query', where @@ -103,10 +105,17 @@ catList = rebind "unnest" . sunnest typeInformation -- -- @catNonEmpty@ is an inverse to 'someExpr'. catNonEmpty :: Sql DBType a => Expr (NonEmpty a) -> Query (Expr a) -catNonEmpty = rebind "unnest" . sunnest typeInformation +catNonEmpty = rebind "unnest" . unnest >=> extract -sunnest :: TypeInformation (Unnullify a) -> Expr (list a) -> Expr a -sunnest info = mapPrimExpr $ - extractArrayElement info . - Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST") +unnest :: Expr (list a) -> Expr a +unnest = mapPrimExpr $ Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST") + + +extract :: Table Expr a => a -> Query a +extract = rebind "extract" . fromColumns . go . toColumns + where + go as = htabulate $ \field -> + case hfield as field of + a -> case hfield hspecs field of + Spec {info} -> mapPrimExpr (extractArrayElement info) a diff --git a/src/Rel8/Type/Array.hs b/src/Rel8/Type/Array.hs index 0b67d66a..e170bf1d 100644 --- a/src/Rel8/Type/Array.hs +++ b/src/Rel8/Type/Array.hs @@ -2,6 +2,7 @@ {-# language LambdaCase #-} {-# language NamedFieldPuns #-} {-# language OverloadedStrings #-} +{-# language TypeApplications #-} {-# language ViewPatterns #-} module Rel8.Type.Array @@ -96,5 +97,33 @@ encodeArrayElement info extractArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr extractArrayElement info - | isArray info = flip Opaleye.CompositeExpr "f1" + | isArray info = extract | otherwise = id + where + extract input = cast unrow + where + string = Opaleye.ConstExpr . Opaleye.StringLit + int = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger @Int + minus a b = Opaleye.BinExpr (Opaleye.:-) a b + len = Opaleye.FunExpr "length" . pure + substr s a b = Opaleye.FunExpr "substr" [s, a, b] + cast = Opaleye.CastExpr (typeName info) + text = Opaleye.CastExpr "text" input + unrow = + Opaleye.CaseExpr + [ (quoted, unquote) + ] + unparen + where + quoted = Opaleye.BinExpr Opaleye.OpLike text pattern + where + pattern = string "(\"%\")" + unparen = unwrap 1 + unwrap n = substr text (int (1 + n)) (minus (len text) (int (n * 2))) + unquote = unescape '"' $ unescape '\\' $ unwrap 2 + where + unescape char a = + Opaleye.FunExpr "replace" [a, pattern, replacement] + where + pattern = string [char, char] + replacement = string [char] diff --git a/tests/Main.hs b/tests/Main.hs index e4b43491..b0973670 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -2,7 +2,7 @@ {-# language BlockArguments #-} {-# language DeriveAnyClass #-} {-# language DeriveGeneric #-} -{-# language DerivingStrategies #-} +{-# language DerivingVia #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} {-# language MonoLocalBinds #-} @@ -23,6 +23,7 @@ import Control.Applicative ( empty, liftA2, liftA3 ) import Control.Exception ( bracket, throwIO ) import Control.Monad ( (>=>), void ) import Data.Bifunctor ( bimap ) +import Data.Fixed (Fixed (MkFixed)) import Data.Foldable ( for_ ) import Data.Int ( Int32, Int64 ) import Data.List ( nub, sort ) @@ -57,6 +58,10 @@ import qualified Hedgehog.Range as Range -- mmorph import Control.Monad.Morph ( hoist ) +-- network-ip +import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) +import Data.DoubleWord (Word128(..)) + -- rel8 import Rel8 ( Result ) import qualified Rel8 @@ -87,9 +92,6 @@ import qualified Database.Postgres.Temp as TmpPostgres -- uuid import qualified Data.UUID --- ip -import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) -import Data.DoubleWord (Word128(..)) main :: IO () main = defaultMain tests @@ -143,6 +145,7 @@ tests = sql "CREATE TABLE test_table ( column1 text not null, column2 bool not null )" sql "CREATE TABLE unique_table ( \"key\" text not null unique, \"value\" text not null )" sql "CREATE SEQUENCE test_seq" + sql "CREATE TYPE composite AS (\"bool\" bool, \"char\" char, \"array\" int4[])" return db @@ -409,21 +412,37 @@ testAp = databasePropertyTest "Cartesian product (<*>)" \transaction -> do sort result === sort (liftA2 (,) rows1 rows2) +data Composite = Composite + { bool :: !Bool + , char :: !Char + , array :: ![Int32] + } + deriving (Eq, Show, Generic) + deriving (Rel8.DBType) via Rel8.Composite Composite + + +instance Rel8.DBComposite Composite where + compositeTypeName = "composite" + compositeFields = Rel8.namesFromLabels + + testDBType :: IO TmpPostgres.DB -> TestTree testDBType getTestDatabase = testGroup "DBType instances" [ dbTypeTest "Bool" Gen.bool , dbTypeTest "ByteString" $ Gen.bytes (Range.linear 0 128) + , dbTypeTest "CalendarDiffTime" genCalendarDiffTime , dbTypeTest "CI Lazy Text" $ mk . Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "CI Text" $ mk <$> Gen.text (Range.linear 0 10) Gen.unicode + , dbTypeTest "Composite" genComposite , dbTypeTest "Day" genDay - , dbTypeTest "Double" $ (/10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) - , dbTypeTest "Float" $ (/10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Double" $ (/ 10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Float" $ (/ 10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Int32" $ Gen.integral @_ @Int32 Range.linearBounded , dbTypeTest "Int64" $ Gen.integral @_ @Int64 Range.linearBounded , dbTypeTest "Lazy ByteString" $ Data.ByteString.Lazy.fromStrict <$> Gen.bytes (Range.linear 0 128) , dbTypeTest "Lazy Text" $ Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "LocalTime" genLocalTime - , dbTypeTest "Scientific" $ (/10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Scientific" $ (/ 10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Text" $ Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "TimeOfDay" genTimeOfDay , dbTypeTest "UTCTime" $ UTCTime <$> genDay <*> genDiffTime @@ -432,24 +451,53 @@ testDBType getTestDatabase = testGroup "DBType instances" ] where - dbTypeTest :: (Eq a, Show a, Rel8.DBType a) => TestName -> Gen a -> TestTree + dbTypeTest :: (Eq a, Show a, Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => TestName -> Gen a -> TestTree dbTypeTest name generator = testGroup name [ databasePropertyTest name (t generator) getTestDatabase , databasePropertyTest ("Maybe " <> name) (t (Gen.maybe generator)) getTestDatabase ] - t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a) + t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => Gen a -> (TestT Transaction () -> PropertyT IO b) -> PropertyT IO b t generator transaction = do x <- forAll generator + y <- forAll generator transaction do [res] <- lift do statement () $ Rel8.select do pure (Rel8.litExpr x) diff res (==) x + [res'] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res' (==) [[x, y]] + [res3] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res3 (==) [[[x, y]]] + res'' <- lift do + statement () $ Rel8.select do + xs <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]) + Rel8.catListTable xs + diff res'' (==) [x, y] +{- + res''' <- lift do + statement () $ Rel8.select do + xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) + xs <- Rel8.catListTable xss + Rel8.catListTable xs + diff res''' (==) [x, y] +-} + + genComposite :: Gen Composite + genComposite = do + bool <- Gen.bool + char <- Gen.unicode + array <- Gen.list (Range.linear 0 10) (Gen.int32 (Range.linear (-10000) 10000)) + pure Composite {..} genDay :: Gen Day genDay = do @@ -458,6 +506,14 @@ testDBType getTestDatabase = testGroup "DBType instances" day <- Gen.integral (Range.linear 1 31) Gen.just $ pure $ fromGregorianValid year month day + genCalendarDiffTime :: Gen CalendarDiffTime + genCalendarDiffTime = do + -- hardcoded to 0 because Hasql's 'interval' decoder needs to return a + -- CalendarDiffTime for this to be properly round-trippable + months <- pure 0 -- Gen.integral (Range.linear 0 120) + diffTime <- secondsToNominalDiffTime . MkFixed . (* 1000000) <$> Gen.integral (Range.linear 0 2147483647999999) + pure $ CalendarDiffTime months diffTime + genDiffTime :: Gen DiffTime genDiffTime = secondsToDiffTime <$> Gen.integral (Range.linear 0 86401)