From 9fb2e98b1dfaef61065870b7dc1f537bf9b1da7e Mon Sep 17 00:00:00 2001 From: Shane O'Brien Date: Sat, 17 Jun 2023 11:58:11 +0100 Subject: [PATCH] Support "rank 2" `catListTable` (by "parsing" anonymous record) This is another possible "fix" to #168 (as opposed to #242). It doesn't really fix the problem, but it allows us to use two levels of `catListTable` instead of only one. Instead of trying to use Postgres's broken `.f1` syntax, we cast the anonymous record to text, remove the parentheses and quotes and unescape any escaped quotes or backslashes, and then cast the resulting text back to the appropriate type. The reason this only works one level deep is that if the type we cast the text back to is itself an anonymous record, then PostgreSQL doesn't know how to parse the text. It's kind of ugly and hacky but it does work and otherwise maintains the status quo. Comparison operators on nested lists continue to work as before and we don't need to burden `DBType` with parsing nonsense. --- src/Rel8/Query/List.hs | 31 +++++++++++------- src/Rel8/Type/Array.hs | 31 +++++++++++++++++- tests/Main.hs | 74 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 115 insertions(+), 21 deletions(-) diff --git a/src/Rel8/Query/List.hs b/src/Rel8/Query/List.hs index 8a76cdf6..55800df0 100644 --- a/src/Rel8/Query/List.hs +++ b/src/Rel8/Query/List.hs @@ -11,6 +11,7 @@ module Rel8.Query.List where -- base +import Control.Monad ((>=>)) import Data.Functor.Identity ( runIdentity ) import Data.List.NonEmpty ( NonEmpty ) import Prelude @@ -25,10 +26,11 @@ import Rel8.Expr.Opaleye ( mapPrimExpr ) import Rel8.Query ( Query ) import Rel8.Query.Aggregate (aggregate, aggregate1) import Rel8.Query.Rebind (hrebind, rebind) +import Rel8.Schema.HTable (hfield, hspecs, htabulate) import Rel8.Schema.HTable.Vectorize ( hunvectorize ) import Rel8.Schema.Null ( Sql, Unnullify ) import Rel8.Schema.Spec ( Spec( Spec, info ) ) -import Rel8.Table ( Table, fromColumns ) +import Rel8.Table (Table, fromColumns, toColumns) import Rel8.Table.Aggregate ( listAgg, nonEmptyAgg ) import Rel8.Table.List ( ListTable( ListTable ) ) import Rel8.Table.NonEmpty ( NonEmptyTable( NonEmptyTable ) ) @@ -76,8 +78,8 @@ someExpr = aggregate1 nonEmptyAggExpr -- @catListTable@ is an inverse to 'many'. catListTable :: Table Expr a => ListTable Expr a -> Query a catListTable (ListTable as) = - fmap fromColumns $ hrebind "unnest" $ runIdentity $ - hunvectorize (\Spec {info} -> pure . sunnest info) as + (>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $ + hunvectorize (\_ -> pure . unnest) as -- | Expand a 'NonEmptyTable' into a 'Query', where each row in the query is an @@ -86,8 +88,8 @@ catListTable (ListTable as) = -- @catNonEmptyTable@ is an inverse to 'some'. catNonEmptyTable :: Table Expr a => NonEmptyTable Expr a -> Query a catNonEmptyTable (NonEmptyTable as) = - fmap fromColumns $ hrebind "unnest" $ runIdentity $ - hunvectorize (\Spec {info} -> pure . sunnest info) as + (>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $ + hunvectorize (\_ -> pure . unnest) as -- | Expand an expression that contains a list into a 'Query', where each row @@ -95,7 +97,7 @@ catNonEmptyTable (NonEmptyTable as) = -- -- @catList@ is an inverse to 'manyExpr'. catList :: Sql DBType a => Expr [a] -> Query (Expr a) -catList = rebind "unnest" . sunnest typeInformation +catList = rebind "unnest" . unnest >=> extract -- | Expand an expression that contains a non-empty list into a 'Query', where @@ -103,10 +105,17 @@ catList = rebind "unnest" . sunnest typeInformation -- -- @catNonEmpty@ is an inverse to 'someExpr'. catNonEmpty :: Sql DBType a => Expr (NonEmpty a) -> Query (Expr a) -catNonEmpty = rebind "unnest" . sunnest typeInformation +catNonEmpty = rebind "unnest" . unnest >=> extract -sunnest :: TypeInformation (Unnullify a) -> Expr (list a) -> Expr a -sunnest info = mapPrimExpr $ - extractArrayElement info . - Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST") +unnest :: Expr (list a) -> Expr a +unnest = mapPrimExpr $ Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST") + + +extract :: Table Expr a => a -> Query a +extract = rebind "extract" . fromColumns . go . toColumns + where + go as = htabulate $ \field -> + case hfield as field of + a -> case hfield hspecs field of + Spec {info} -> mapPrimExpr (extractArrayElement info) a diff --git a/src/Rel8/Type/Array.hs b/src/Rel8/Type/Array.hs index 0b67d66a..e170bf1d 100644 --- a/src/Rel8/Type/Array.hs +++ b/src/Rel8/Type/Array.hs @@ -2,6 +2,7 @@ {-# language LambdaCase #-} {-# language NamedFieldPuns #-} {-# language OverloadedStrings #-} +{-# language TypeApplications #-} {-# language ViewPatterns #-} module Rel8.Type.Array @@ -96,5 +97,33 @@ encodeArrayElement info extractArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr extractArrayElement info - | isArray info = flip Opaleye.CompositeExpr "f1" + | isArray info = extract | otherwise = id + where + extract input = cast unrow + where + string = Opaleye.ConstExpr . Opaleye.StringLit + int = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger @Int + minus a b = Opaleye.BinExpr (Opaleye.:-) a b + len = Opaleye.FunExpr "length" . pure + substr s a b = Opaleye.FunExpr "substr" [s, a, b] + cast = Opaleye.CastExpr (typeName info) + text = Opaleye.CastExpr "text" input + unrow = + Opaleye.CaseExpr + [ (quoted, unquote) + ] + unparen + where + quoted = Opaleye.BinExpr Opaleye.OpLike text pattern + where + pattern = string "(\"%\")" + unparen = unwrap 1 + unwrap n = substr text (int (1 + n)) (minus (len text) (int (n * 2))) + unquote = unescape '"' $ unescape '\\' $ unwrap 2 + where + unescape char a = + Opaleye.FunExpr "replace" [a, pattern, replacement] + where + pattern = string [char, char] + replacement = string [char] diff --git a/tests/Main.hs b/tests/Main.hs index e4b43491..b0973670 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -2,7 +2,7 @@ {-# language BlockArguments #-} {-# language DeriveAnyClass #-} {-# language DeriveGeneric #-} -{-# language DerivingStrategies #-} +{-# language DerivingVia #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} {-# language MonoLocalBinds #-} @@ -23,6 +23,7 @@ import Control.Applicative ( empty, liftA2, liftA3 ) import Control.Exception ( bracket, throwIO ) import Control.Monad ( (>=>), void ) import Data.Bifunctor ( bimap ) +import Data.Fixed (Fixed (MkFixed)) import Data.Foldable ( for_ ) import Data.Int ( Int32, Int64 ) import Data.List ( nub, sort ) @@ -57,6 +58,10 @@ import qualified Hedgehog.Range as Range -- mmorph import Control.Monad.Morph ( hoist ) +-- network-ip +import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) +import Data.DoubleWord (Word128(..)) + -- rel8 import Rel8 ( Result ) import qualified Rel8 @@ -87,9 +92,6 @@ import qualified Database.Postgres.Temp as TmpPostgres -- uuid import qualified Data.UUID --- ip -import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) -import Data.DoubleWord (Word128(..)) main :: IO () main = defaultMain tests @@ -143,6 +145,7 @@ tests = sql "CREATE TABLE test_table ( column1 text not null, column2 bool not null )" sql "CREATE TABLE unique_table ( \"key\" text not null unique, \"value\" text not null )" sql "CREATE SEQUENCE test_seq" + sql "CREATE TYPE composite AS (\"bool\" bool, \"char\" char, \"array\" int4[])" return db @@ -409,21 +412,37 @@ testAp = databasePropertyTest "Cartesian product (<*>)" \transaction -> do sort result === sort (liftA2 (,) rows1 rows2) +data Composite = Composite + { bool :: !Bool + , char :: !Char + , array :: ![Int32] + } + deriving (Eq, Show, Generic) + deriving (Rel8.DBType) via Rel8.Composite Composite + + +instance Rel8.DBComposite Composite where + compositeTypeName = "composite" + compositeFields = Rel8.namesFromLabels + + testDBType :: IO TmpPostgres.DB -> TestTree testDBType getTestDatabase = testGroup "DBType instances" [ dbTypeTest "Bool" Gen.bool , dbTypeTest "ByteString" $ Gen.bytes (Range.linear 0 128) + , dbTypeTest "CalendarDiffTime" genCalendarDiffTime , dbTypeTest "CI Lazy Text" $ mk . Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "CI Text" $ mk <$> Gen.text (Range.linear 0 10) Gen.unicode + , dbTypeTest "Composite" genComposite , dbTypeTest "Day" genDay - , dbTypeTest "Double" $ (/10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) - , dbTypeTest "Float" $ (/10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Double" $ (/ 10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Float" $ (/ 10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Int32" $ Gen.integral @_ @Int32 Range.linearBounded , dbTypeTest "Int64" $ Gen.integral @_ @Int64 Range.linearBounded , dbTypeTest "Lazy ByteString" $ Data.ByteString.Lazy.fromStrict <$> Gen.bytes (Range.linear 0 128) , dbTypeTest "Lazy Text" $ Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "LocalTime" genLocalTime - , dbTypeTest "Scientific" $ (/10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Scientific" $ (/ 10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Text" $ Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "TimeOfDay" genTimeOfDay , dbTypeTest "UTCTime" $ UTCTime <$> genDay <*> genDiffTime @@ -432,24 +451,53 @@ testDBType getTestDatabase = testGroup "DBType instances" ] where - dbTypeTest :: (Eq a, Show a, Rel8.DBType a) => TestName -> Gen a -> TestTree + dbTypeTest :: (Eq a, Show a, Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => TestName -> Gen a -> TestTree dbTypeTest name generator = testGroup name [ databasePropertyTest name (t generator) getTestDatabase , databasePropertyTest ("Maybe " <> name) (t (Gen.maybe generator)) getTestDatabase ] - t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a) + t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => Gen a -> (TestT Transaction () -> PropertyT IO b) -> PropertyT IO b t generator transaction = do x <- forAll generator + y <- forAll generator transaction do [res] <- lift do statement () $ Rel8.select do pure (Rel8.litExpr x) diff res (==) x + [res'] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res' (==) [[x, y]] + [res3] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res3 (==) [[[x, y]]] + res'' <- lift do + statement () $ Rel8.select do + xs <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]) + Rel8.catListTable xs + diff res'' (==) [x, y] +{- + res''' <- lift do + statement () $ Rel8.select do + xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) + xs <- Rel8.catListTable xss + Rel8.catListTable xs + diff res''' (==) [x, y] +-} + + genComposite :: Gen Composite + genComposite = do + bool <- Gen.bool + char <- Gen.unicode + array <- Gen.list (Range.linear 0 10) (Gen.int32 (Range.linear (-10000) 10000)) + pure Composite {..} genDay :: Gen Day genDay = do @@ -458,6 +506,14 @@ testDBType getTestDatabase = testGroup "DBType instances" day <- Gen.integral (Range.linear 1 31) Gen.just $ pure $ fromGregorianValid year month day + genCalendarDiffTime :: Gen CalendarDiffTime + genCalendarDiffTime = do + -- hardcoded to 0 because Hasql's 'interval' decoder needs to return a + -- CalendarDiffTime for this to be properly round-trippable + months <- pure 0 -- Gen.integral (Range.linear 0 120) + diffTime <- secondsToNominalDiffTime . MkFixed . (* 1000000) <$> Gen.integral (Range.linear 0 2147483647999999) + pure $ CalendarDiffTime months diffTime + genDiffTime :: Gen DiffTime genDiffTime = secondsToDiffTime <$> Gen.integral (Range.linear 0 86401)