From 3e2b1bef84c3db321f914aacaa76d78d0ceaab3c Mon Sep 17 00:00:00 2001 From: Shane O'Brien Date: Sat, 17 Jun 2023 11:58:11 +0100 Subject: [PATCH] Support nested `catListTable` (by representing nested arrays as text) This is one possible "fix" to #168. With this we can `catListTable` arbitrarily deep trees of `ListTable`s. It comes at a relatively high cost, however. Currently we represent nested arrays with anonymous records. This works reasonably well, except that we can't extract the field from the anonymous record when we need it (PostgreSQL [theoretically](https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.16.5.6) supports `.f1` syntax since PG13 but it only works in very limited situations). But it does mean we can decode the results using Hasql's binary decoders, and ordering works how we expect ('array[row(array[9])] < array[row(array[10])]'). What this PR does is instead represent nested arrays as text. To be able to decode this, we need each 'DBType' to supply a text parser in addition to a binary decoder. It also means that ordering is no longer intuitive, because `array[array[9]::text] > array[array[10]::text]`. However, it does mean we can nest `catListTable`s to our heart's content and it will always just work. 
--- rel8.cabal | 8 ++ src/Rel8/Expr/Serialize.hs | 7 +- src/Rel8/Type.hs | 173 ++++++++++++++++++++++++----- src/Rel8/Type/Array.hs | 78 ++++++++++--- src/Rel8/Type/Composite.hs | 70 +++++++++++- src/Rel8/Type/Decoder.hs | 64 +++++++++++ src/Rel8/Type/Enum.hs | 17 ++- src/Rel8/Type/Information.hs | 12 +- src/Rel8/Type/JSONBEncoded.hs | 17 ++- src/Rel8/Type/Parser.hs | 26 +++++ src/Rel8/Type/Parser/ByteString.hs | 54 +++++++++ src/Rel8/Type/Parser/Time.hs | 156 ++++++++++++++++++++++++++ tests/Main.hs | 72 ++++++++++-- 13 files changed, 679 insertions(+), 75 deletions(-) create mode 100644 src/Rel8/Type/Decoder.hs create mode 100644 src/Rel8/Type/Parser.hs create mode 100644 src/Rel8/Type/Parser/ByteString.hs create mode 100644 src/Rel8/Type/Parser/Time.hs diff --git a/rel8.cabal b/rel8.cabal index f1c1f14e..7c126f3e 100644 --- a/rel8.cabal +++ b/rel8.cabal @@ -20,7 +20,9 @@ source-repository head library build-depends: aeson + , attoparsec , base ^>= 4.14 || ^>=4.15 || ^>=4.16 || ^>=4.17 + , base16 , bifunctors , bytestring , case-insensitive @@ -38,6 +40,8 @@ library , text , these , time + , transformers + , utf8-string , uuid default-language: Haskell2010 @@ -190,6 +194,7 @@ library Rel8.Type Rel8.Type.Array Rel8.Type.Composite + Rel8.Type.Decoder Rel8.Type.Eq Rel8.Type.Enum Rel8.Type.Information @@ -198,6 +203,9 @@ library Rel8.Type.Monoid Rel8.Type.Num Rel8.Type.Ord + Rel8.Type.Parser + Rel8.Type.Parser.ByteString + Rel8.Type.Parser.Time Rel8.Type.ReadShow Rel8.Type.Semigroup Rel8.Type.String diff --git a/src/Rel8/Expr/Serialize.hs b/src/Rel8/Expr/Serialize.hs index a2c66578..5812ad32 100644 --- a/src/Rel8/Expr/Serialize.hs +++ b/src/Rel8/Expr/Serialize.hs @@ -23,6 +23,7 @@ import {-# SOURCE #-} Rel8.Expr ( Expr( Expr ) ) import Rel8.Expr.Opaleye ( scastExpr ) import Rel8.Schema.Null ( Unnullify, Nullity( Null, NotNull ), Sql, nullable ) import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Information ( 
TypeInformation(..) ) @@ -44,6 +45,6 @@ slitExpr nullity info@TypeInformation {encode} = sparseValue :: Nullity a -> TypeInformation (Unnullify a) -> Hasql.Row a -sparseValue nullity TypeInformation {decode} = case nullity of - Null -> Hasql.column $ Hasql.nullable decode - NotNull -> Hasql.column $ Hasql.nonNullable decode +sparseValue nullity TypeInformation {decode = Decoder {binary}} = case nullity of + Null -> Hasql.column $ Hasql.nullable binary + NotNull -> Hasql.column $ Hasql.nonNullable binary diff --git a/src/Rel8/Type.hs b/src/Rel8/Type.hs index 43fd8398..3d23a614 100644 --- a/src/Rel8/Type.hs +++ b/src/Rel8/Type.hs @@ -1,7 +1,9 @@ +{-# language LambdaCase #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} {-# language MonoLocalBinds #-} {-# language MultiWayIf #-} +{-# language OverloadedStrings #-} {-# language StandaloneKindSignatures #-} {-# language UndecidableInstances #-} @@ -13,15 +15,21 @@ where -- aeson import Data.Aeson ( Value ) import qualified Data.Aeson as Aeson +import qualified Data.Aeson.Parser as Aeson + +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A -- base +import Control.Applicative ((<|>)) import Data.Int ( Int16, Int32, Int64 ) import Data.List.NonEmpty ( NonEmpty ) import Data.Kind ( Constraint, Type ) import Prelude -- bytestring -import Data.ByteString ( ByteString ) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as Lazy ( ByteString ) import qualified Data.ByteString.Lazy as ByteString ( fromStrict, toStrict ) @@ -32,6 +40,9 @@ import qualified Data.CaseInsensitive as CI -- hasql import qualified Hasql.Decoders as Hasql +-- network-ip +import qualified Network.IP.Addr as IP + -- opaleye import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye import qualified Opaleye.Internal.HaskellDB.Sql.Default as Opaleye ( quote ) @@ -39,7 +50,11 @@ import qualified Opaleye.Internal.HaskellDB.Sql.Default as Opaleye ( 
quote ) -- rel8 import Rel8.Schema.Null ( NotNull, Sql, nullable ) import Rel8.Type.Array ( listTypeInformation, nonEmptyTypeInformation ) +import Rel8.Type.Decoder ( Decoder(..) ) import Rel8.Type.Information ( TypeInformation(..), mapTypeInformation ) +import Rel8.Type.Parser (parse) +import Rel8.Type.Parser.ByteString (bytestring) +import qualified Rel8.Type.Parser.Time as Time -- scientific import Data.Scientific ( Scientific ) @@ -47,26 +62,28 @@ import Data.Scientific ( Scientific ) -- text import Data.Text ( Text ) import qualified Data.Text as Text +import qualified Data.Text.Encoding as Text (decodeUtf8) import qualified Data.Text.Lazy as Lazy ( Text, unpack ) import qualified Data.Text.Lazy as Text ( fromStrict, toStrict ) import qualified Data.Text.Lazy.Encoding as Lazy ( decodeUtf8 ) -- time -import Data.Time.Calendar ( Day ) -import Data.Time.Clock ( UTCTime ) +import Data.Time.Calendar (Day) +import Data.Time.Clock (UTCTime) import Data.Time.LocalTime - ( CalendarDiffTime( CalendarDiffTime ) + ( CalendarDiffTime (CalendarDiffTime) , LocalTime , TimeOfDay ) -import Data.Time.Format ( formatTime, defaultTimeLocale ) +import Data.Time.Format (formatTime, defaultTimeLocale) + +-- utf8 +import qualified Data.ByteString.UTF8 as UTF8 -- uuid import Data.UUID ( UUID ) import qualified Data.UUID as UUID --- ip -import Network.IP.Addr (NetAddr, IP, printNetAddr) -- | Haskell types that can be represented as expressions in a database. There -- should be an instance of @DBType@ for all column types in your database @@ -85,7 +102,15 @@ class NotNull a => DBType a where instance DBType Bool where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . 
Opaleye.BoolLit - , decode = Hasql.bool + , decode = + Decoder + { binary = Hasql.bool + , parser = \case + "t" -> pure True + "f" -> pure False + input -> Left $ "bool: bad bool " <> show input + , delimiter = ',' + } , typeName = "bool" } @@ -94,7 +119,14 @@ instance DBType Bool where instance DBType Char where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . pure - , decode = Hasql.char + , decode = + Decoder + { binary = Hasql.char + , parser = \input -> case UTF8.uncons input of + Just (char, rest) | BS.null rest -> pure char + _ -> Left $ "char: bad char " <> show input + , delimiter = ',' + } , typeName = "char" } @@ -103,7 +135,12 @@ instance DBType Char where instance DBType Int16 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger - , decode = Hasql.int2 + , decode = + Decoder + { binary = Hasql.int2 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int2" } @@ -112,7 +149,12 @@ instance DBType Int16 where instance DBType Int32 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger - , decode = Hasql.int4 + , decode = + Decoder + { binary = Hasql.int4 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int4" } @@ -121,7 +163,12 @@ instance DBType Int32 where instance DBType Int64 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . 
toInteger - , decode = Hasql.int8 + , decode = + Decoder + { binary = Hasql.int8 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int8" } @@ -134,7 +181,12 @@ instance DBType Float where | isNaN x -> Opaleye.OtherLit "'NaN'" | x == (-1 / 0) -> Opaleye.OtherLit "'-Infinity'" | otherwise -> Opaleye.NumericLit $ realToFrac x - , decode = Hasql.float4 + , decode = + Decoder + { binary = Hasql.float4 + , parser = parse (floating (realToFrac <$> A.double)) + , delimiter = ',' + } , typeName = "float4" } @@ -147,7 +199,12 @@ instance DBType Double where | isNaN x -> Opaleye.OtherLit "'NaN'" | x == (-1 / 0) -> Opaleye.OtherLit "'-Infinity'" | otherwise -> Opaleye.NumericLit $ realToFrac x - , decode = Hasql.float8 + , decode = + Decoder + { binary = Hasql.float8 + , parser = parse (floating A.double) + , delimiter = ',' + } , typeName = "float8" } @@ -156,7 +213,12 @@ instance DBType Double where instance DBType Scientific where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.NumericLit - , decode = Hasql.numeric + , decode = + Decoder + { binary = Hasql.numeric + , parser = parse A.scientific + , delimiter = ',' + } , typeName = "numeric" } @@ -167,7 +229,12 @@ instance DBType UTCTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%FT%T%QZ'" - , decode = Hasql.timestamptz + , decode = + Decoder + { binary = Hasql.timestamptz + , parser = parse Time.utcTime + , delimiter = ',' + } , typeName = "timestamptz" } @@ -178,7 +245,12 @@ instance DBType Day where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%F'" - , decode = Hasql.date + , decode = + Decoder + { binary = Hasql.date + , parser = parse Time.day + , delimiter = ',' + } , typeName = "date" } @@ -189,7 +261,12 @@ instance DBType LocalTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . 
formatTime defaultTimeLocale "'%FT%T%Q'" - , decode = Hasql.timestamp + , decode = + Decoder + { binary = Hasql.timestamp + , parser = parse Time.localTime + , delimiter = ',' + } , typeName = "timestamp" } @@ -200,7 +277,12 @@ instance DBType TimeOfDay where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%T%Q'" - , decode = Hasql.time + , decode = + Decoder + { binary = Hasql.time + , parser = parse Time.timeOfDay + , delimiter = ',' + } , typeName = "time" } @@ -211,7 +293,12 @@ instance DBType CalendarDiffTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%bmon %0Es'" - , decode = CalendarDiffTime 0 . realToFrac <$> Hasql.interval + , decode = + Decoder + { binary = CalendarDiffTime 0 . realToFrac <$> Hasql.interval + , parser = parse Time.calendarDiffTime + , delimiter = ',' + } , typeName = "interval" } @@ -220,7 +307,12 @@ instance DBType CalendarDiffTime where instance DBType Text where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . Text.unpack - , decode = Hasql.text + , decode = + Decoder + { binary = Hasql.text + , parser = pure . Text.decodeUtf8 + , delimiter = ',' + } , typeName = "text" } @@ -249,7 +341,12 @@ instance DBType (CI Lazy.Text) where instance DBType ByteString where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.ByteStringLit - , decode = Hasql.bytea + , decode = + Decoder + { binary = Hasql.bytea + , parser = parse bytestring + , delimiter = ',' + } , typeName = "bytea" } @@ -265,7 +362,14 @@ instance DBType Lazy.ByteString where instance DBType UUID where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . 
UUID.toString - , decode = Hasql.uuid + , decode = + Decoder + { binary = Hasql.uuid + , parser = \input -> case UUID.fromASCIIBytes input of + Just a -> pure a + Nothing -> Left $ "uuid: bad UUID " <> show input + , delimiter = ',' + } , typeName = "uuid" } @@ -277,16 +381,27 @@ instance DBType Value where Opaleye.ConstExpr . Opaleye.OtherLit . Opaleye.quote . Lazy.unpack . Lazy.decodeUtf8 . Aeson.encode - , decode = Hasql.jsonb + , decode = + Decoder + { binary = Hasql.jsonb + , parser = parse Aeson.value + , delimiter = ',' + } , typeName = "jsonb" } + -- | Corresponds to @inet@ -instance DBType (NetAddr IP) where +instance DBType (IP.NetAddr IP.IP) where typeInformation = TypeInformation { encode = - Opaleye.ConstExpr . Opaleye.StringLit . printNetAddr - , decode = Hasql.inet + Opaleye.ConstExpr . Opaleye.StringLit . IP.printNetAddr + , decode = + Decoder + { binary = Hasql.inet + , parser = parse IP.netParser + , delimiter = ',' + } , typeName = "inet" } @@ -297,3 +412,7 @@ instance Sql DBType a => DBType [a] where instance Sql DBType a => DBType (NonEmpty a) where typeInformation = nonEmptyTypeInformation nullable typeInformation + + +floating :: Floating a => A.Parser a -> A.Parser a +floating p = p <|> A.signed (1.0 / 0 <$ "Infinity") <|> 0.0 / 0 <$ "NaN" diff --git a/src/Rel8/Type/Array.hs b/src/Rel8/Type/Array.hs index 0b67d66a..f1254f74 100644 --- a/src/Rel8/Type/Array.hs +++ b/src/Rel8/Type/Array.hs @@ -11,11 +11,20 @@ module Rel8.Type.Array ) where +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + -- base -import Data.Foldable ( toList ) +import Control.Applicative ((<|>), many) +import Data.Bifunctor (first) +import Data.Foldable (fold, toList) import Data.List.NonEmpty ( NonEmpty, nonEmpty ) import Prelude hiding ( null, repeat, zipWith ) +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + -- hasql import qualified Hasql.Decoders as Hasql @@ -24,7 +33,12 @@ import qualified 
Opaleye.Internal.HaskellDB.PrimQuery as Opaleye -- rel8 import Rel8.Schema.Null ( Unnullify, Nullity( Null, NotNull ) ) +import Rel8.Type.Decoder (Decoder (..), NullableOrNot (..), Parser) import Rel8.Type.Information ( TypeInformation(..), parseTypeInformation ) +import Rel8.Type.Parser (parse) + +-- text +import qualified Data.Text as Text array :: Foldable f @@ -41,11 +55,16 @@ listTypeInformation :: () -> TypeInformation [a] listTypeInformation nullity info@TypeInformation {encode, decode} = TypeInformation - { decode = case nullity of - Null -> - Hasql.listArray (decodeArrayElement info (Hasql.nullable decode)) - NotNull -> - Hasql.listArray (decodeArrayElement info (Hasql.nonNullable decode)) + { decode = + Decoder + { binary = Hasql.listArray $ case nullity of + Null -> Hasql.nullable (decodeArrayElement info decode) + NotNull -> Hasql.nonNullable (decodeArrayElement info decode) + , parser = case nullity of + Null -> arrayParser (Nullable decode) + NotNull -> arrayParser (NonNullable decode) + , delimiter = ',' + } , encode = case nullity of Null -> Opaleye.ArrayExpr . @@ -64,9 +83,9 @@ nonEmptyTypeInformation :: () -> TypeInformation (Unnullify a) -> TypeInformation (NonEmpty a) nonEmptyTypeInformation nullity = - parseTypeInformation parse toList . listTypeInformation nullity + parseTypeInformation fromList toList . listTypeInformation nullity where - parse = maybe (Left message) Right . nonEmpty + fromList = maybe (Left message) Right . nonEmpty message = "failed to decode NonEmptyList: got empty list" @@ -78,23 +97,54 @@ isArray = \case arrayType :: TypeInformation a -> String arrayType info - | isArray info = "record" + | isArray info = "text" | otherwise = typeName info -decodeArrayElement :: TypeInformation a -> Hasql.NullableOrNot Hasql.Value x -> Hasql.NullableOrNot Hasql.Value x +decodeArrayElement :: TypeInformation a -> Decoder x -> Hasql.Value x decodeArrayElement info - | isArray info = Hasql.nonNullable . Hasql.composite . 
Hasql.field - | otherwise = id + | isArray info = \decoder -> + Hasql.refine (first Text.pack . parser decoder) Hasql.bytea + | otherwise = binary encodeArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr encodeArrayElement info - | isArray info = Opaleye.UnExpr (Opaleye.UnOpOther "ROW") + | isArray info = Opaleye.CastExpr "text" | otherwise = id extractArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr extractArrayElement info - | isArray info = flip Opaleye.CompositeExpr "f1" + | isArray info = Opaleye.CastExpr (typeName info <> "[]") | otherwise = id + + +parseArray :: Char -> ByteString -> Either String [Maybe ByteString] +parseArray delimiter = parse $ do + A.char '{' *> A.sepBy element (A.char delimiter) <* A.char '}' + where + element = null <|> nonNull + where + null = Nothing <$ A.string "NULL" + nonNull = Just <$> (quoted <|> unquoted) + where + unquoted = A.takeWhile1 (A.notInClass (delimiter : "\"{}")) + quoted = A.char '"' *> contents <* A.char '"' + where + contents = fold <$> many (unquote <|> unescape) + where + unquote = A.takeWhile1 (A.notInClass "\"\\") + unescape = A.char '\\' *> do + BS.singleton <$> do + A.char '\\' <|> A.char '"' + + +arrayParser :: NullableOrNot Decoder a -> Parser [a] +arrayParser = \case + Nullable Decoder {parser, delimiter} -> \input -> do + elements <- parseArray delimiter input + traverse (traverse parser) elements + NonNullable Decoder {parser, delimiter} -> \input -> do + elements <- parseArray delimiter input + traverse (maybe (Left "array: unexpected null") parser) elements diff --git a/src/Rel8/Type/Composite.hs b/src/Rel8/Type/Composite.hs index 3482e0b1..eaecb82c 100644 --- a/src/Rel8/Type/Composite.hs +++ b/src/Rel8/Type/Composite.hs @@ -1,9 +1,11 @@ {-# language AllowAmbiguousTypes #-} {-# language BlockArguments #-} {-# language DataKinds #-} +{-# language DisambiguateRecordFields #-} {-# language FlexibleContexts #-} {-# language GADTs #-} {-# language 
NamedFieldPuns #-} +{-# language OverloadedStrings #-} {-# language ScopedTypeVariables #-} {-# language StandaloneKindSignatures #-} {-# language TypeApplications #-} @@ -18,12 +20,22 @@ module Rel8.Type.Composite ) where +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + -- base -import Data.Functor.Const ( Const( Const ), getConst ) -import Data.Functor.Identity ( Identity( Identity ) ) +import Control.Applicative ((<|>), many, optional) +import Data.Foldable (fold) +import Data.Functor.Const (Const (Const), getConst) +import Data.Functor.Identity (Identity (Identity)) import Data.Kind ( Constraint, Type ) +import Data.List (uncons) import Prelude +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + -- hasql import qualified Hasql.Decoders as Hasql @@ -45,13 +57,20 @@ import Rel8.Table.Ord ( OrdTable ) import Rel8.Table.Rel8able () import Rel8.Table.Serialize ( litHTable ) import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (Decoder), Parser) +import qualified Rel8.Type.Decoder as Decoder import Rel8.Type.Eq ( DBEq ) import Rel8.Type.Information ( TypeInformation(..) ) import Rel8.Type.Ord ( DBOrd, DBMax, DBMin ) +import Rel8.Type.Parser (parse) -- semigroupoids import Data.Functor.Apply ( WrappedApplicative(..) ) +-- transformers +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (StateT (StateT), runStateT) + -- | A deriving-via helper type for column types that store a Haskell product -- type in a single Postgres column using a Postgres composite type. @@ -68,7 +87,12 @@ newtype Composite a = Composite instance DBComposite a => DBType (Composite a) where typeInformation = TypeInformation - { decode = Hasql.composite (Composite . fromResult @_ @(HKD a Expr) <$> decoder) + { decode = + Decoder + { binary = Hasql.composite (Composite . fromResult @_ @(HKD a Expr) <$> decoder) + , parser = fmap (Composite . fromResult @_ @(HKD a Expr)) . 
parser + , delimiter = ',' + } , encode = encoder . litHTable . toResult @_ @(HKD a Expr) . unComposite , typeName = compositeTypeName @a } @@ -124,8 +148,8 @@ decoder = unwrapApplicative $ htabulateA \field -> case hfield hspecs field of Spec {nullity, info} -> WrapApplicative $ Identity <$> case nullity of - Null -> Hasql.field $ Hasql.nullable $ decode info - NotNull -> Hasql.field $ Hasql.nonNullable $ decode info + Null -> Hasql.field $ Hasql.nullable $ Decoder.binary $ decode info + NotNull -> Hasql.field $ Hasql.nonNullable $ Decoder.binary $ decode info encoder :: HTable t => t Expr -> Opaleye.PrimExpr @@ -133,3 +157,39 @@ encoder a = Opaleye.FunExpr "ROW" exprs where exprs = getConst $ htabulateA \field -> case hfield a field of expr -> Const [toPrimExpr expr] + + +parser :: HTable t => Parser (t Result) +parser input = do + fields <- parseRow input + (a, rest) <- runStateT go fields + case rest of + [] -> pure a + _ -> Left "composite: too many fields" + where + go = htabulateA \field -> do + mbytes <- StateT $ maybe missing pure . 
uncons + lift $ Identity <$> case hfield hspecs field of + Spec {nullity, info} -> case nullity of + Null -> traverse (Decoder.parser (decode info)) mbytes + NotNull -> case mbytes of + Nothing -> Left "composite: unexpected null" + Just bytes -> Decoder.parser (decode info) bytes + missing = Left "composite: missing fields" + + +parseRow :: ByteString -> Either String [Maybe ByteString] +parseRow = parse $ do + A.char '(' *> A.sepBy element (A.char ',') <* A.char ')' + where + element = optional (quoted <|> unquoted) + where + unquoted = A.takeWhile1 (A.notInClass ",\"()") + quoted = A.char '"' *> contents <* A.char '"' + where + contents = fold <$> many (unquote <|> unescape) + where + unquote = A.takeWhile1 (A.notInClass "\"\\") + unescape = A.char '\\' *> do + BS.singleton <$> do + A.char '\\' <|> A.char '"' diff --git a/src/Rel8/Type/Decoder.hs b/src/Rel8/Type/Decoder.hs new file mode 100644 index 00000000..5322e7c5 --- /dev/null +++ b/src/Rel8/Type/Decoder.hs @@ -0,0 +1,64 @@ +{-# language DerivingStrategies #-} +{-# language DeriveFunctor #-} +{-# language GADTs #-} +{-# language NamedFieldPuns #-} +{-# language StandaloneKindSignatures #-} + +module Rel8.Type.Decoder ( + Decoder (..), + NullableOrNot (..), + Parser, + parseDecoder, +) where + +-- base +import Control.Monad ((>=>)) +import Data.Bifunctor (first) +import Data.Kind (Type) +import Prelude + +-- bytestring +import Data.ByteString (ByteString) + +-- hasql +import qualified Hasql.Decoders as Hasql + +-- text +import qualified Data.Text as Text + + +type Parser :: Type -> Type +type Parser a = ByteString -> Either String a + + +type Decoder :: Type -> Type +data Decoder a = Decoder + { binary :: Hasql.Value a + -- ^ How to deserialize from PostgreSQL's binary format. + , parser :: Parser a + -- ^ How to deserialize from PostgreSQL's text format. + , delimiter :: Char + -- ^ The delimiter that is used in PostgreSQL's text format in arrays of + -- this type (this is almost always ','). 
+ } + deriving stock (Functor) + + +-- | Apply a parser to 'Decoder'. +-- +-- This can be used if the data stored in the database should only be subset of +-- a given 'Decoder'. The parser is applied when deserializing rows +-- returned. +parseDecoder :: (a -> Either String b) -> Decoder a -> Decoder b +parseDecoder f Decoder {binary, parser, delimiter} = + Decoder + { binary = Hasql.refine (first Text.pack . f) binary + , parser = parser >=> f + , delimiter + } + + +type NullableOrNot :: (Type -> Type) -> Type -> Type +data NullableOrNot decoder a where + NonNullable :: decoder a -> NullableOrNot decoder a + Nullable :: decoder a -> NullableOrNot decoder (Maybe a) diff --git a/src/Rel8/Type/Enum.hs b/src/Rel8/Type/Enum.hs index 3324079e..27a67f10 100644 --- a/src/Rel8/Type/Enum.hs +++ b/src/Rel8/Type/Enum.hs @@ -39,12 +39,14 @@ import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye -- rel8 import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Eq ( DBEq ) import Rel8.Type.Information ( TypeInformation(..) ) import Rel8.Type.Ord ( DBOrd, DBMax, DBMin ) -- text -import Data.Text ( pack ) +import Data.Text (pack) +import Data.Text.Encoding (decodeUtf8) -- | A deriving-via helper type for column types that store an \"enum\" type @@ -66,10 +68,15 @@ newtype Enum a = Enum instance DBEnum a => DBType (Enum a) where typeInformation = TypeInformation { decode = - Hasql.enum $ - flip lookup $ - map ((pack . enumValue &&& Enum) . to) $ - genumerate @(Rep a) + let + mapping = (pack . enumValue &&& Enum) . to <$> genumerate @(Rep a) + unrecognised = Left "enum: unrecognised value" + in + Decoder + { binary = Hasql.enum (`lookup` mapping) + , parser = maybe unrecognised pure . (`lookup` mapping) . decodeUtf8 + , delimiter = ',' + } , encode = Opaleye.ConstExpr . Opaleye.StringLit . 
diff --git a/src/Rel8/Type/Information.hs b/src/Rel8/Type/Information.hs index 47651677..e378d1b3 100644 --- a/src/Rel8/Type/Information.hs +++ b/src/Rel8/Type/Information.hs @@ -10,18 +10,14 @@ module Rel8.Type.Information where -- base -import Data.Bifunctor ( first ) import Data.Kind ( Type ) import Prelude --- hasql -import qualified Hasql.Decoders as Hasql - -- opaleye import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye --- text -import qualified Data.Text as Text +-- rel8 +import Rel8.Type.Decoder (Decoder, parseDecoder) -- | @TypeInformation@ describes how to encode and decode a Haskell type to and @@ -31,7 +27,7 @@ type TypeInformation :: Type -> Type data TypeInformation a = TypeInformation { encode :: a -> Opaleye.PrimExpr -- ^ How to encode a single Haskell value as a SQL expression. - , decode :: Hasql.Value a + , decode :: Decoder a -- ^ How to deserialize a single result back to Haskell. , typeName :: String -- ^ The name of the SQL type. @@ -62,6 +58,6 @@ parseTypeInformation :: () parseTypeInformation to from TypeInformation {encode, decode, typeName} = TypeInformation { encode = encode . from - , decode = Hasql.refine (first Text.pack . to) decode + , decode = parseDecoder to decode , typeName } diff --git a/src/Rel8/Type/JSONBEncoded.hs b/src/Rel8/Type/JSONBEncoded.hs index 7530f0d5..dd776a3e 100644 --- a/src/Rel8/Type/JSONBEncoded.hs +++ b/src/Rel8/Type/JSONBEncoded.hs @@ -1,10 +1,13 @@ {-# language StandaloneKindSignatures #-} -module Rel8.Type.JSONBEncoded ( JSONBEncoded(..) ) where +module Rel8.Type.JSONBEncoded + ( JSONBEncoded(..) + ) +where -- aeson -import Data.Aeson ( FromJSON, ToJSON, parseJSON, toJSON ) -import Data.Aeson.Types ( parseEither ) +import Data.Aeson (FromJSON, ToJSON, eitherDecodeStrict, parseJSON, toJSON) +import Data.Aeson.Types (parseEither) -- base import Data.Bifunctor ( first ) @@ -16,6 +19,7 @@ import qualified Hasql.Decoders as Hasql -- rel8 import Rel8.Type ( DBType(..) 
) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Information ( TypeInformation(..) ) -- text @@ -30,6 +34,11 @@ newtype JSONBEncoded a = JSONBEncoded { fromJSONBEncoded :: a } instance (FromJSON a, ToJSON a) => DBType (JSONBEncoded a) where typeInformation = TypeInformation { encode = encode typeInformation . toJSON . fromJSONBEncoded - , decode = Hasql.refine (first pack . fmap JSONBEncoded . parseEither parseJSON) Hasql.jsonb + , decode = + Decoder + { binary = Hasql.refine (first pack . fmap JSONBEncoded . parseEither parseJSON) Hasql.jsonb + , parser = fmap JSONBEncoded . eitherDecodeStrict + , delimiter = ',' + } , typeName = "jsonb" } diff --git a/src/Rel8/Type/Parser.hs b/src/Rel8/Type/Parser.hs new file mode 100644 index 00000000..3b6fbb12 --- /dev/null +++ b/src/Rel8/Type/Parser.hs @@ -0,0 +1,26 @@ +module Rel8.Type.Parser + ( parse + ) +where + +-- attoparsec +import qualified Data.Attoparsec.ByteString as A + +-- base +import Control.Applicative ((<|>)) +import Control.Monad (unless) +import Prelude + +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS + +-- utf8-string +import qualified Data.ByteString.UTF8 as UTF8 + + +parse :: A.Parser a -> ByteString -> Either String a +parse parser = do + A.parseOnly (parser <* A.endOfInput <|> debug) + where + debug = A.takeByteString >>= fail . 
UTF8.toString diff --git a/src/Rel8/Type/Parser/ByteString.hs b/src/Rel8/Type/Parser/ByteString.hs new file mode 100644 index 00000000..1f2b4f64 --- /dev/null +++ b/src/Rel8/Type/Parser/ByteString.hs @@ -0,0 +1,54 @@ +{-# language OverloadedStrings #-} +{-# language TypeApplications #-} + +module Rel8.Type.Parser.ByteString + ( bytestring + ) +where + +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + +-- base +import Control.Applicative ((<|>), many) +import Control.Monad (guard) +import Data.Bits ((.|.), shiftL) +import Data.Char (isOctDigit) +import Data.Foldable (fold) +import Prelude + +-- base16 +import Data.ByteString.Base16 (decodeBase16) + +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + +-- text +import qualified Data.Text as Text + + +bytestring :: A.Parser ByteString +bytestring = hex <|> escape + where + hex = do + digits <- "\\x" *> A.takeByteString + either (fail . Text.unpack) pure $ decodeBase16 digits + escape = fold <$> many (escaped <|> unescaped) + where + unescaped = A.takeWhile1 (/= '\\') + escaped = BS.singleton <$> (backslash <|> octal) + where + backslash = '\\' <$ "\\\\" + octal = do + a <- A.char '\\' *> digit + b <- digit + c <- digit + let + result = a `shiftL` 6 .|. b `shiftL` 3 .|. 
{-# language OverloadedStrings #-}
{-# language TypeApplications #-}

-- | Attoparsec parsers for textual date, time, timestamp and interval
-- values.
module Rel8.Type.Parser.Time
  ( calendarDiffTime
  , day
  , localTime
  , timeOfDay
  , utcTime
  )
where

-- attoparsec
import qualified Data.Attoparsec.ByteString.Char8 as A

-- base
import Control.Applicative ((<|>), optional)
import Data.Bits ((.&.))
import Data.Bool (bool)
import Data.Fixed (Fixed (MkFixed), Pico, divMod')
import Data.Functor (void)
import Data.Int (Int64)
import Prelude

-- bytestring
import qualified Data.ByteString as BS

-- time
import Data.Time.Calendar (Day, addDays, fromGregorianValid)
import Data.Time.Clock (DiffTime, UTCTime (UTCTime))
import Data.Time.Format.ISO8601 (iso8601ParseM)
import Data.Time.LocalTime
  ( CalendarDiffTime (CalendarDiffTime)
  , LocalTime (LocalTime)
  , TimeOfDay (TimeOfDay)
  , sinceMidnight
  )

-- utf8
import qualified Data.ByteString.UTF8 as UTF8


-- | Parse a calendar date written as @YYYY-MM-DD@: a decimal year followed
-- by a two-digit month and a two-digit day, separated by @-@.
--
-- Fails with @"Day: invalid date"@ when 'fromGregorianValid' rejects the
-- combination (for example month 13 or 30 February).
day :: A.Parser Day
day = do
  y <- A.decimal <* A.char '-'
  m <- twoDigits <* A.char '-'
  d <- twoDigits
  maybe (fail "Day: invalid date") pure $ fromGregorianValid y m d


-- | Parse a time of day written as @HH:MM:SS@ with an optional fractional
-- part on the seconds.
--
-- Fails with @"TimeOfDay: invalid time"@ unless the hour is below 24, the
-- minute is below 60 and the seconds do not exceed 60.
timeOfDay :: A.Parser TimeOfDay
timeOfDay = do
  h <- twoDigits
  m <- A.char ':' *> twoDigits
  s <- A.char ':' *> secondsParser
  if h < 24 && m < 60 && s <= 60
    then pure $ TimeOfDay h m s
    else fail "TimeOfDay: invalid time"


-- | Parse a 'day' and a 'timeOfDay' separated by a single space or a @T@.
localTime :: A.Parser LocalTime
localTime = LocalTime <$> day <* separator <*> timeOfDay
  where
    separator = A.char ' ' <|> A.char 'T'


-- | Parse a 'localTime' followed by a zone designator (@Z@ or a signed
-- @HH[:MM[:SS]]@ offset) and convert it to 'UTCTime'.
utcTime :: A.Parser UTCTime
utcTime = do
  LocalTime date time <- localTime
  offset <- timeZone
  let
    -- The offset says how far the written local time is ahead of UTC, so it
    -- has to be subtracted: @12:00:00+01@ is 11:00:00 UTC. 'divMod'' rounds
    -- towards negative infinity, so crossing midnight in either direction
    -- yields a day adjustment plus a time within the day.
    (days, time') = (sinceMidnight time - offset) `divMod'` oneDay
      where
        oneDay = 24 * 60 * 60
    date' = addDays days date
  pure $ UTCTime date' time'


-- | Parse an interval.
--
-- The ISO 8601 duration syntax is tried first on the whole input (via
-- 'iso8601ParseM'). Otherwise the input is read as an optional months part,
-- an optional days part and an optional time part, optionally followed by
-- @ago@, which negates both the months and the time components:
--
-- * months: either @Y-M@ or @[\@] N year[s] N mon[s]@;
-- * days: @N day[s]@, or a bare number followed by whitespace;
-- * time: either @H:MM:SS[.fff]@ or @N hour[s] N min[s] SS[.fff] sec[s]@.
--
-- Days are folded into the time component as 24 hours each.
calendarDiffTime :: A.Parser CalendarDiffTime
calendarDiffTime = iso8601 <|> postgres
  where
    iso8601 = A.takeByteString >>= iso8601ParseM . UTF8.toString
    at = optional (A.char '@') *> A.skipSpace
    -- Matches a unit name with an optional trailing "s", surrounded by
    -- optional whitespace. Callers pass the singular stem.
    plural unit = A.skipSpace <* (unit <* optional "s") <* A.skipSpace
    parseMonths = sql <|> postgresql
      where
        sql = A.signed $ do
          years <- A.decimal <* A.char '-'
          months <- A.decimal <* A.skipSpace
          pure $ years * 12 + months
        postgresql = do
          at
          years <- A.signed A.decimal <* plural "year" <|> pure 0
          months <- A.signed A.decimal <* plural "mon" <|> pure 0
          pure $ years * 12 + months
    parseTime = (+) <$> parseDays <*> time
      where
        time = realToFrac <$> (sql <|> postgresql)
          where
            sql = A.signed $ do
              h <- A.signed A.decimal <* A.char ':'
              m <- twoDigits <* A.char ':'
              s <- secondsParser
              pure $ fromIntegral (((h * 60) + m) * 60) + s
            -- NOTE(review): 'secondsParser' requires exactly two integral
            -- digits, so a value such as "3 secs" does not match here and
            -- falls back to 0 seconds - confirm whether that form needs
            -- to be supported.
            postgresql = do
              h <- A.signed A.decimal <* plural "hour" <|> pure 0
              m <- A.signed A.decimal <* plural "min" <|> pure 0
              s <- secondsParser <* plural "sec" <|> pure 0
              pure $ fromIntegral @Int (((h * 60) + m) * 60) + s
        parseDays = do
          -- The stem is "day" so that both "1 day" and "2 days" match; the
          -- 'skipSpace1' alternative accepts a bare day count.
          days <- A.signed A.decimal <* (plural "day" <|> skipSpace1) <|> pure 0
          pure $ fromIntegral @Int days * 24 * 60 * 60
    postgres = do
      months <- parseMonths
      time <- parseTime
      ago <- (True <$ (A.skipSpace *> "ago")) <|> pure False
      pure $ CalendarDiffTime (bool id negate ago months) (bool id negate ago time)


-- | Parse seconds as exactly two integral digits with an optional
-- fractional part introduced by @.@.
--
-- At most the first twelve fractional digits contribute to the result
-- (the resolution of 'Pico'); any further digits are consumed and ignored.
secondsParser :: A.Parser Pico
secondsParser = do
  integral <- twoDigits
  mfractional <- optional (A.char '.' *> A.takeWhile1 A.isDigit)
  pure $ case mfractional of
    Nothing -> fromIntegral integral
    Just fractional -> parseFraction (fromIntegral integral) fractional
  where
    -- Appends the fractional digits to the integral part as one decimal
    -- number, then scales by a power of ten so the total is in picoseconds.
    parseFraction integral digits = MkFixed (fromIntegral (n * 10 ^ e))
      where
        e = max 0 (12 - BS.length digits)
        n = BS.foldl' go (integral :: Int64) (BS.take 12 digits)
          where
            -- The low four bits of an ASCII digit are its numeric value.
            go acc digit = 10 * acc + fromIntegral (fromEnum digit .&. 0xf)


-- | Parse exactly two decimal digits as a number in the range 0-99.
twoDigits :: A.Parser Int
twoDigits = do
  u <- A.digit
  l <- A.digit
  -- Masking with 0xf turns an ASCII digit into its value. The parentheses
  -- only make explicit the grouping that operator precedence already gives.
  pure $ (fromEnum u .&. 0xf) * 10 + (fromEnum l .&. 0xf)


-- | Parse a zone designator: @Z@ for a zero offset, otherwise a signed
-- offset as accepted by 'diffTime'.
timeZone :: A.Parser DiffTime
timeZone = 0 <$ A.char 'Z' <|> diffTime


-- | Parse an optionally signed @HH[:MM[:SS[.fff]]]@ duration. Minutes and
-- seconds default to 0 when absent.
diffTime :: A.Parser DiffTime
diffTime = A.signed $ do
  h <- twoDigits
  m <- A.char ':' *> twoDigits <|> pure 0
  s <- A.char ':' *> secondsParser <|> pure 0
  pure $ sinceMidnight $ TimeOfDay h m s


-- | Skip one or more whitespace characters, failing if there are none.
skipSpace1 :: A.Parser ()
skipSpace1 = void $ A.takeWhile1 A.isSpace
Database.Postgres.Temp as TmpPostgres -- uuid import qualified Data.UUID --- ip -import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) -import Data.DoubleWord (Word128(..)) main :: IO () main = defaultMain tests @@ -143,6 +145,7 @@ tests = sql "CREATE TABLE test_table ( column1 text not null, column2 bool not null )" sql "CREATE TABLE unique_table ( \"key\" text not null unique, \"value\" text not null )" sql "CREATE SEQUENCE test_seq" + sql "CREATE TYPE composite AS (\"bool\" bool, \"char\" char, \"array\" int4[])" return db @@ -409,21 +412,37 @@ testAp = databasePropertyTest "Cartesian product (<*>)" \transaction -> do sort result === sort (liftA2 (,) rows1 rows2) +data Composite = Composite + { bool :: !Bool + , char :: !Char + , array :: ![Int32] + } + deriving (Eq, Show, Generic) + deriving (Rel8.DBType) via Rel8.Composite Composite + + +instance Rel8.DBComposite Composite where + compositeTypeName = "composite" + compositeFields = Rel8.namesFromLabels + + testDBType :: IO TmpPostgres.DB -> TestTree testDBType getTestDatabase = testGroup "DBType instances" [ dbTypeTest "Bool" Gen.bool , dbTypeTest "ByteString" $ Gen.bytes (Range.linear 0 128) + , dbTypeTest "CalendarDiffTime" genCalendarDiffTime , dbTypeTest "CI Lazy Text" $ mk . Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "CI Text" $ mk <$> Gen.text (Range.linear 0 10) Gen.unicode + , dbTypeTest "Composite" genComposite , dbTypeTest "Day" genDay - , dbTypeTest "Double" $ (/10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) - , dbTypeTest "Float" $ (/10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Double" $ (/ 10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Float" $ (/ 10) . 
fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Int32" $ Gen.integral @_ @Int32 Range.linearBounded , dbTypeTest "Int64" $ Gen.integral @_ @Int64 Range.linearBounded , dbTypeTest "Lazy ByteString" $ Data.ByteString.Lazy.fromStrict <$> Gen.bytes (Range.linear 0 128) , dbTypeTest "Lazy Text" $ Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "LocalTime" genLocalTime - , dbTypeTest "Scientific" $ (/10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Scientific" $ (/ 10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Text" $ Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "TimeOfDay" genTimeOfDay , dbTypeTest "UTCTime" $ UTCTime <$> genDay <*> genDiffTime @@ -432,24 +451,51 @@ testDBType getTestDatabase = testGroup "DBType instances" ] where - dbTypeTest :: (Eq a, Show a, Rel8.DBType a) => TestName -> Gen a -> TestTree + dbTypeTest :: (Eq a, Show a, Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => TestName -> Gen a -> TestTree dbTypeTest name generator = testGroup name [ databasePropertyTest name (t generator) getTestDatabase , databasePropertyTest ("Maybe " <> name) (t (Gen.maybe generator)) getTestDatabase ] - t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a) + t :: forall a b. 
(Eq a, Show a, Rel8.Sql Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => Gen a -> (TestT Transaction () -> PropertyT IO b) -> PropertyT IO b t generator transaction = do x <- forAll generator + y <- forAll generator transaction do [res] <- lift do statement () $ Rel8.select do pure (Rel8.litExpr x) diff res (==) x + [res'] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res' (==) [[x, y]] + [res3] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res3 (==) [[[x, y]]] + res'' <- lift do + statement () $ Rel8.select do + xs <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]) + Rel8.catListTable xs + diff res'' (==) [x, y] + res''' <- lift do + statement () $ Rel8.select do + xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) + xs <- Rel8.catListTable xss + Rel8.catListTable xs + diff res''' (==) [x, y] + + genComposite :: Gen Composite + genComposite = do + bool <- Gen.bool + char <- Gen.unicode + array <- Gen.list (Range.linear 0 10) (Gen.int32 (Range.linear (-10000) 10000)) + pure Composite {..} genDay :: Gen Day genDay = do @@ -458,6 +504,14 @@ testDBType getTestDatabase = testGroup "DBType instances" day <- Gen.integral (Range.linear 1 31) Gen.just $ pure $ fromGregorianValid year month day + genCalendarDiffTime :: Gen CalendarDiffTime + genCalendarDiffTime = do + -- hardcoded to 0 because Hasql's 'interval' decoder needs to return a + -- CalendarDiffTime for this to be properly round-trippable + months <- pure 0 -- Gen.integral (Range.linear 0 120) + diffTime <- secondsToNominalDiffTime . MkFixed . (* 1000000) <$> Gen.integral (Range.linear 0 2147483647999999) + pure $ CalendarDiffTime months diffTime + genDiffTime :: Gen DiffTime genDiffTime = secondsToDiffTime <$> Gen.integral (Range.linear 0 86401)