Skip to content

Commit

Permalink
Support "rank 2" catListTable (by "parsing" anonymous record)
Browse files Browse the repository at this point in the history
This is another possible "fix" to #168 (as opposed to #242). It doesn't really fix the problem, but it allows us to use two levels of `catListTable` instead of only one. Instead of trying to use Postgres's broken `.f1` syntax, we cast the anonymous record to text, remove the parentheses and quotes and unescape any escaped quotes or backslashes, and then cast the resulting text back to the appropriate type. The reason this only works one level deep is that if the type we cast the text back to is itself an anonymous record, then PostgreSQL doesn't know how to parse the text.

It's kind of ugly and hacky but it does work and otherwise maintains the status quo. Comparison operators on nested lists continue to work as before and we don't need to burden `DBType` with parsing nonsense.
  • Loading branch information
shane-circuithub committed Jun 18, 2023
1 parent 144366d commit 68ebda4
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 21 deletions.
31 changes: 20 additions & 11 deletions src/Rel8/Query/List.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module Rel8.Query.List
where

-- base
import Control.Monad ((>=>))
import Data.Functor.Identity ( runIdentity )
import Data.List.NonEmpty ( NonEmpty )
import Prelude
Expand All @@ -25,10 +26,11 @@ import Rel8.Expr.Opaleye ( mapPrimExpr )
import Rel8.Query ( Query )
import Rel8.Query.Aggregate (aggregate, aggregate1)
import Rel8.Query.Rebind (hrebind, rebind)
import Rel8.Schema.HTable (hfield, hspecs, htabulate)
import Rel8.Schema.HTable.Vectorize ( hunvectorize )
import Rel8.Schema.Null ( Sql, Unnullify )
import Rel8.Schema.Spec ( Spec( Spec, info ) )
import Rel8.Table ( Table, fromColumns )
import Rel8.Table (Table, fromColumns, toColumns)
import Rel8.Table.Aggregate ( listAgg, nonEmptyAgg )
import Rel8.Table.List ( ListTable( ListTable ) )
import Rel8.Table.NonEmpty ( NonEmptyTable( NonEmptyTable ) )
Expand Down Expand Up @@ -76,8 +78,8 @@ someExpr = aggregate1 nonEmptyAggExpr
-- @catListTable@ is an inverse to 'many'.
catListTable :: Table Expr a => ListTable Expr a -> Query a
catListTable (ListTable as) =
fmap fromColumns $ hrebind "unnest" $ runIdentity $
hunvectorize (\Spec {info} -> pure . sunnest info) as
(>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $
hunvectorize (\_ -> pure . unnest) as


-- | Expand a 'NonEmptyTable' into a 'Query', where each row in the query is an
Expand All @@ -86,27 +88,34 @@ catListTable (ListTable as) =
-- @catNonEmptyTable@ is an inverse to 'some'.
catNonEmptyTable :: Table Expr a => NonEmptyTable Expr a -> Query a
catNonEmptyTable (NonEmptyTable as) =
fmap fromColumns $ hrebind "unnest" $ runIdentity $
hunvectorize (\Spec {info} -> pure . sunnest info) as
(>>= extract) $ fmap fromColumns $ hrebind "unnest" $ runIdentity $
hunvectorize (\_ -> pure . unnest) as


-- | Expand an expression that contains a list into a 'Query', where each row
-- in the query is an element of the given list.
--
-- @catList@ is an inverse to 'manyExpr'.
catList :: Sql DBType a => Expr [a] -> Query (Expr a)
catList = rebind "unnest" . sunnest typeInformation
catList = rebind "unnest" . unnest >=> extract


-- | Expand an expression that contains a non-empty list into a 'Query', where
-- each row in the query is an element of the given list.
--
-- @catNonEmpty@ is an inverse to 'someExpr'.
catNonEmpty :: Sql DBType a => Expr (NonEmpty a) -> Query (Expr a)
catNonEmpty = rebind "unnest" . sunnest typeInformation
catNonEmpty = rebind "unnest" . unnest >=> extract


sunnest :: TypeInformation (Unnullify a) -> Expr (list a) -> Expr a
sunnest info = mapPrimExpr $
extractArrayElement info .
Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST")
unnest :: Expr (list a) -> Expr a
unnest = mapPrimExpr $ Opaleye.UnExpr (Opaleye.UnOpOther "UNNEST")


extract :: Table Expr a => a -> Query a
extract = rebind "extract" . fromColumns . go . toColumns
where
go as = htabulate $ \field ->
case hfield as field of
a -> case hfield hspecs field of
Spec {info} -> mapPrimExpr (extractArrayElement info) a
31 changes: 30 additions & 1 deletion src/Rel8/Type/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
{-# language OverloadedStrings #-}
{-# language TypeApplications #-}
{-# language ViewPatterns #-}

module Rel8.Type.Array
Expand Down Expand Up @@ -96,5 +97,33 @@ encodeArrayElement info

extractArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr
extractArrayElement info
| isArray info = flip Opaleye.CompositeExpr "f1"
| isArray info = extract
| otherwise = id
where
extract input = cast unrow
where
string = Opaleye.ConstExpr . Opaleye.StringLit
int = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger @Int
minus a b = Opaleye.BinExpr (Opaleye.:-) a b
len = Opaleye.FunExpr "length" . pure
substr s a b = Opaleye.FunExpr "substr" [s, a, b]
cast = Opaleye.CastExpr (typeName info)
text = Opaleye.CastExpr "text" input
unrow =
Opaleye.CaseExpr
[ (quoted, unquote)
]
unparen
where
quoted = Opaleye.BinExpr Opaleye.OpLike text pattern
where
pattern = string "(\"%\")"
unparen = unwrap 1
unwrap n = substr text (int (1 + n)) (minus (len text) (int (n * 2)))
unquote = unescape '"' $ unescape '\\' $ unwrap 2
where
unescape char a =
Opaleye.FunExpr "replace" [a, pattern, replacement]
where
pattern = string [char, char]
replacement = string [char]
72 changes: 63 additions & 9 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{-# language BlockArguments #-}
{-# language DeriveAnyClass #-}
{-# language DeriveGeneric #-}
{-# language DerivingStrategies #-}
{-# language DerivingVia #-}
{-# language FlexibleContexts #-}
{-# language FlexibleInstances #-}
{-# language MonoLocalBinds #-}
Expand All @@ -23,6 +23,7 @@ import Control.Applicative ( empty, liftA2, liftA3 )
import Control.Exception ( bracket, throwIO )
import Control.Monad ( (>=>), void )
import Data.Bifunctor ( bimap )
import Data.Fixed (Fixed (MkFixed))
import Data.Foldable ( for_ )
import Data.Int ( Int32, Int64 )
import Data.List ( nub, sort )
Expand Down Expand Up @@ -57,6 +58,10 @@ import qualified Hedgehog.Range as Range
-- mmorph
import Control.Monad.Morph ( hoist )

-- network-ip
import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr)
import Data.DoubleWord (Word128(..))

-- rel8
import Rel8 ( Result )
import qualified Rel8
Expand Down Expand Up @@ -87,9 +92,6 @@ import qualified Database.Postgres.Temp as TmpPostgres
-- uuid
import qualified Data.UUID

-- ip
import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr)
import Data.DoubleWord (Word128(..))

main :: IO ()
main = defaultMain tests
Expand Down Expand Up @@ -143,6 +145,7 @@ tests =
sql "CREATE TABLE test_table ( column1 text not null, column2 bool not null )"
sql "CREATE TABLE unique_table ( \"key\" text not null unique, \"value\" text not null )"
sql "CREATE SEQUENCE test_seq"
sql "CREATE TYPE composite AS (\"bool\" bool, \"char\" char, \"array\" int4[])"

return db

Expand Down Expand Up @@ -409,21 +412,37 @@ testAp = databasePropertyTest "Cartesian product (<*>)" \transaction -> do
sort result === sort (liftA2 (,) rows1 rows2)


data Composite = Composite
{ bool :: !Bool
, char :: !Char
, array :: ![Int32]
}
deriving (Eq, Show, Generic)
deriving (Rel8.DBType) via Rel8.Composite Composite


instance Rel8.DBComposite Composite where
compositeTypeName = "composite"
compositeFields = Rel8.namesFromLabels


testDBType :: IO TmpPostgres.DB -> TestTree
testDBType getTestDatabase = testGroup "DBType instances"
[ dbTypeTest "Bool" Gen.bool
, dbTypeTest "ByteString" $ Gen.bytes (Range.linear 0 128)
, dbTypeTest "CalendarDiffTime" genCalendarDiffTime
, dbTypeTest "CI Lazy Text" $ mk . Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode
, dbTypeTest "CI Text" $ mk <$> Gen.text (Range.linear 0 10) Gen.unicode
, dbTypeTest "Composite" genComposite
, dbTypeTest "Day" genDay
, dbTypeTest "Double" $ (/10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Float" $ (/10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Double" $ (/ 10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Float" $ (/ 10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Int32" $ Gen.integral @_ @Int32 Range.linearBounded
, dbTypeTest "Int64" $ Gen.integral @_ @Int64 Range.linearBounded
, dbTypeTest "Lazy ByteString" $ Data.ByteString.Lazy.fromStrict <$> Gen.bytes (Range.linear 0 128)
, dbTypeTest "Lazy Text" $ Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode
, dbTypeTest "LocalTime" genLocalTime
, dbTypeTest "Scientific" $ (/10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Scientific" $ (/ 10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100)
, dbTypeTest "Text" $ Gen.text (Range.linear 0 10) Gen.unicode
, dbTypeTest "TimeOfDay" genTimeOfDay
, dbTypeTest "UTCTime" $ UTCTime <$> genDay <*> genDiffTime
Expand All @@ -432,24 +451,51 @@ testDBType getTestDatabase = testGroup "DBType instances"
]

where
dbTypeTest :: (Eq a, Show a, Rel8.DBType a) => TestName -> Gen a -> TestTree
dbTypeTest :: (Eq a, Show a, Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => TestName -> Gen a -> TestTree
dbTypeTest name generator = testGroup name
[ databasePropertyTest name (t generator) getTestDatabase
, databasePropertyTest ("Maybe " <> name) (t (Gen.maybe generator)) getTestDatabase
]

t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a)
t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a)
=> Gen a
-> (TestT Transaction () -> PropertyT IO b)
-> PropertyT IO b
t generator transaction = do
x <- forAll generator
y <- forAll generator

transaction do
[res] <- lift do
statement () $ Rel8.select do
pure (Rel8.litExpr x)
diff res (==) x
[res'] <- lift do
statement () $ Rel8.select $ Rel8.many $ Rel8.many do
Rel8.values [Rel8.litExpr x, Rel8.litExpr y]
diff res' (==) [[x, y]]
[res3] <- lift do
statement () $ Rel8.select $ Rel8.many $ Rel8.many $ Rel8.many do
Rel8.values [Rel8.litExpr x, Rel8.litExpr y]
diff res3 (==) [[[x, y]]]
res'' <- lift do
statement () $ Rel8.select do
xs <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]])
Rel8.catListTable xs
diff res'' (==) [x, y]
res''' <- lift do
statement () $ Rel8.select do
xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]])
xs <- Rel8.catListTable xss
Rel8.catListTable xs
diff res''' (==) [x, y]

genComposite :: Gen Composite
genComposite = do
bool <- Gen.bool
char <- Gen.unicode
array <- Gen.list (Range.linear 0 10) (Gen.int32 (Range.linear (-10000) 10000))
pure Composite {..}

genDay :: Gen Day
genDay = do
Expand All @@ -458,6 +504,14 @@ testDBType getTestDatabase = testGroup "DBType instances"
day <- Gen.integral (Range.linear 1 31)
Gen.just $ pure $ fromGregorianValid year month day

genCalendarDiffTime :: Gen CalendarDiffTime
genCalendarDiffTime = do
-- hardcoded to 0 because Hasql's 'interval' decoder needs to return a
-- CalendarDiffTime for this to be properly round-trippable
months <- pure 0 -- Gen.integral (Range.linear 0 120)
diffTime <- secondsToNominalDiffTime . MkFixed . (* 1000000) <$> Gen.integral (Range.linear 0 2147483647999999)
pure $ CalendarDiffTime months diffTime

genDiffTime :: Gen DiffTime
genDiffTime = secondsToDiffTime <$> Gen.integral (Range.linear 0 86401)

Expand Down

0 comments on commit 68ebda4

Please sign in to comment.