diff --git a/rel8.cabal b/rel8.cabal index f1c1f14e..7c126f3e 100644 --- a/rel8.cabal +++ b/rel8.cabal @@ -20,7 +20,9 @@ source-repository head library build-depends: aeson + , attoparsec , base ^>= 4.14 || ^>=4.15 || ^>=4.16 || ^>=4.17 + , base16 , bifunctors , bytestring , case-insensitive @@ -38,6 +40,8 @@ library , text , these , time + , transformers + , utf8-string , uuid default-language: Haskell2010 @@ -190,6 +194,7 @@ library Rel8.Type Rel8.Type.Array Rel8.Type.Composite + Rel8.Type.Decoder Rel8.Type.Eq Rel8.Type.Enum Rel8.Type.Information @@ -198,6 +203,9 @@ library Rel8.Type.Monoid Rel8.Type.Num Rel8.Type.Ord + Rel8.Type.Parser + Rel8.Type.Parser.ByteString + Rel8.Type.Parser.Time Rel8.Type.ReadShow Rel8.Type.Semigroup Rel8.Type.String diff --git a/src/Rel8/Expr/Serialize.hs b/src/Rel8/Expr/Serialize.hs index a2c66578..5812ad32 100644 --- a/src/Rel8/Expr/Serialize.hs +++ b/src/Rel8/Expr/Serialize.hs @@ -23,6 +23,7 @@ import {-# SOURCE #-} Rel8.Expr ( Expr( Expr ) ) import Rel8.Expr.Opaleye ( scastExpr ) import Rel8.Schema.Null ( Unnullify, Nullity( Null, NotNull ), Sql, nullable ) import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Information ( TypeInformation(..) ) @@ -44,6 +45,6 @@ slitExpr nullity info@TypeInformation {encode} = sparseValue :: Nullity a -> TypeInformation (Unnullify a) -> Hasql.Row a -sparseValue nullity TypeInformation {decode} = case nullity of - Null -> Hasql.column $ Hasql.nullable decode - NotNull -> Hasql.column $ Hasql.nonNullable decode +sparseValue nullity TypeInformation {decode = Decoder {binary}} = case nullity of + Null -> Hasql.column $ Hasql.nullable binary + NotNull -> Hasql.column $ Hasql.nonNullable binary diff --git a/src/Rel8/Type.hs b/src/Rel8/Type.hs index 43fd8398..3d23a614 100644 --- a/src/Rel8/Type.hs +++ b/src/Rel8/Type.hs @@ -1,7 +1,9 @@ +{-# language LambdaCase #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} {-# language MonoLocalBinds #-} {-# language MultiWayIf #-} +{-# language OverloadedStrings #-} {-# language StandaloneKindSignatures #-} {-# language UndecidableInstances #-} @@ -13,15 +15,21 @@ where -- aeson import Data.Aeson ( Value ) import qualified Data.Aeson as Aeson +import qualified Data.Aeson.Parser as Aeson + +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A -- base +import Control.Applicative ((<|>)) import Data.Int ( Int16, Int32, Int64 ) import Data.List.NonEmpty ( NonEmpty ) import Data.Kind ( Constraint, Type ) import Prelude -- bytestring -import Data.ByteString ( ByteString ) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as Lazy ( ByteString ) import qualified Data.ByteString.Lazy as ByteString ( fromStrict, toStrict ) @@ -32,6 +40,9 @@ import qualified Data.CaseInsensitive as CI -- hasql import qualified Hasql.Decoders as Hasql +-- network-ip +import qualified Network.IP.Addr as IP + -- opaleye import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye import qualified Opaleye.Internal.HaskellDB.Sql.Default as Opaleye ( quote ) @@ -39,7 +50,11 @@ import qualified Opaleye.Internal.HaskellDB.Sql.Default as Opaleye ( quote ) -- rel8 import Rel8.Schema.Null ( NotNull, Sql, nullable ) import Rel8.Type.Array ( listTypeInformation, nonEmptyTypeInformation ) +import Rel8.Type.Decoder ( Decoder(..) ) import Rel8.Type.Information ( TypeInformation(..), mapTypeInformation ) +import Rel8.Type.Parser (parse) +import Rel8.Type.Parser.ByteString (bytestring) +import qualified Rel8.Type.Parser.Time as Time -- scientific import Data.Scientific ( Scientific ) @@ -47,26 +62,28 @@ import Data.Scientific ( Scientific ) -- text import Data.Text ( Text ) import qualified Data.Text as Text +import qualified Data.Text.Encoding as Text (decodeUtf8) import qualified Data.Text.Lazy as Lazy ( Text, unpack ) import qualified Data.Text.Lazy as Text ( fromStrict, toStrict ) import qualified Data.Text.Lazy.Encoding as Lazy ( decodeUtf8 ) -- time -import Data.Time.Calendar ( Day ) -import Data.Time.Clock ( UTCTime ) +import Data.Time.Calendar (Day) +import Data.Time.Clock (UTCTime) import Data.Time.LocalTime - ( CalendarDiffTime( CalendarDiffTime ) + ( CalendarDiffTime (CalendarDiffTime) , LocalTime , TimeOfDay ) -import Data.Time.Format ( formatTime, defaultTimeLocale ) +import Data.Time.Format (formatTime, defaultTimeLocale) + +-- utf8 +import qualified Data.ByteString.UTF8 as UTF8 -- uuid import Data.UUID ( UUID ) import qualified Data.UUID as UUID --- ip -import Network.IP.Addr (NetAddr, IP, printNetAddr) -- | Haskell types that can be represented as expressions in a database. There -- should be an instance of @DBType@ for all column types in your database @@ -85,7 +102,15 @@ class NotNull a => DBType a where instance DBType Bool where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.BoolLit - , decode = Hasql.bool + , decode = + Decoder + { binary = Hasql.bool + , parser = \case + "t" -> pure True + "f" -> pure False + input -> Left $ "bool: bad bool " <> show input + , delimiter = ',' + } , typeName = "bool" } @@ -94,7 +119,14 @@ instance DBType Bool where instance DBType Char where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . pure - , decode = Hasql.char + , decode = + Decoder + { binary = Hasql.char + , parser = \input -> case UTF8.uncons input of + Just (char, rest) | BS.null rest -> pure char + _ -> Left $ "char: bad char " <> show input + , delimiter = ',' + } , typeName = "char" } @@ -103,7 +135,12 @@ instance DBType Char where instance DBType Int16 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger - , decode = Hasql.int2 + , decode = + Decoder + { binary = Hasql.int2 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int2" } @@ -112,7 +149,12 @@ instance DBType Int16 where instance DBType Int32 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger - , decode = Hasql.int4 + , decode = + Decoder + { binary = Hasql.int4 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int4" } @@ -121,7 +163,12 @@ instance DBType Int32 where instance DBType Int64 where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.IntegerLit . toInteger - , decode = Hasql.int8 + , decode = + Decoder + { binary = Hasql.int8 + , parser = parse (A.signed A.decimal) + , delimiter = ',' + } , typeName = "int8" } @@ -134,7 +181,12 @@ instance DBType Float where | isNaN x -> Opaleye.OtherLit "'NaN'" | x == (-1 / 0) -> Opaleye.OtherLit "'-Infinity'" | otherwise -> Opaleye.NumericLit $ realToFrac x - , decode = Hasql.float4 + , decode = + Decoder + { binary = Hasql.float4 + , parser = parse (floating (realToFrac <$> A.double)) + , delimiter = ',' + } , typeName = "float4" } @@ -147,7 +199,12 @@ instance DBType Double where | isNaN x -> Opaleye.OtherLit "'NaN'" | x == (-1 / 0) -> Opaleye.OtherLit "'-Infinity'" | otherwise -> Opaleye.NumericLit $ realToFrac x - , decode = Hasql.float8 + , decode = + Decoder + { binary = Hasql.float8 + , parser = parse (floating A.double) + , delimiter = ',' + } , typeName = "float8" } @@ -156,7 +213,12 @@ instance DBType Double where instance DBType Scientific where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.NumericLit - , decode = Hasql.numeric + , decode = + Decoder + { binary = Hasql.numeric + , parser = parse A.scientific + , delimiter = ',' + } , typeName = "numeric" } @@ -167,7 +229,12 @@ instance DBType UTCTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%FT%T%QZ'" - , decode = Hasql.timestamptz + , decode = + Decoder + { binary = Hasql.timestamptz + , parser = parse Time.utcTime + , delimiter = ',' + } , typeName = "timestamptz" } @@ -178,7 +245,12 @@ instance DBType Day where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%F'" - , decode = Hasql.date + , decode = + Decoder + { binary = Hasql.date + , parser = parse Time.day + , delimiter = ',' + } , typeName = "date" } @@ -189,7 +261,12 @@ instance DBType LocalTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%FT%T%Q'" - , decode = Hasql.timestamp + , decode = + Decoder + { binary = Hasql.timestamp + , parser = parse Time.localTime + , delimiter = ',' + } , typeName = "timestamp" } @@ -200,7 +277,12 @@ instance DBType TimeOfDay where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%T%Q'" - , decode = Hasql.time + , decode = + Decoder + { binary = Hasql.time + , parser = parse Time.timeOfDay + , delimiter = ',' + } , typeName = "time" } @@ -211,7 +293,12 @@ instance DBType CalendarDiffTime where { encode = Opaleye.ConstExpr . Opaleye.OtherLit . formatTime defaultTimeLocale "'%bmon %0Es'" - , decode = CalendarDiffTime 0 . realToFrac <$> Hasql.interval + , decode = + Decoder + { binary = CalendarDiffTime 0 . realToFrac <$> Hasql.interval + , parser = parse Time.calendarDiffTime + , delimiter = ',' + } , typeName = "interval" } @@ -220,7 +307,12 @@ instance DBType CalendarDiffTime where instance DBType Text where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . Text.unpack - , decode = Hasql.text + , decode = + Decoder + { binary = Hasql.text + , parser = pure . Text.decodeUtf8 + , delimiter = ',' + } , typeName = "text" } @@ -249,7 +341,12 @@ instance DBType (CI Lazy.Text) where instance DBType ByteString where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.ByteStringLit - , decode = Hasql.bytea + , decode = + Decoder + { binary = Hasql.bytea + , parser = parse bytestring + , delimiter = ',' + } , typeName = "bytea" } @@ -265,7 +362,14 @@ instance DBType Lazy.ByteString where instance DBType UUID where typeInformation = TypeInformation { encode = Opaleye.ConstExpr . Opaleye.StringLit . UUID.toString - , decode = Hasql.uuid + , decode = + Decoder + { binary = Hasql.uuid + , parser = \input -> case UUID.fromASCIIBytes input of + Just a -> pure a + Nothing -> Left $ "uuid: bad UUID " <> show input + , delimiter = ',' + } , typeName = "uuid" } @@ -277,16 +381,27 @@ instance DBType Value where Opaleye.ConstExpr . Opaleye.OtherLit . Opaleye.quote . Lazy.unpack . Lazy.decodeUtf8 . Aeson.encode - , decode = Hasql.jsonb + , decode = + Decoder + { binary = Hasql.jsonb + , parser = parse Aeson.value + , delimiter = ',' + } , typeName = "jsonb" } + -- | Corresponds to @inet@ -instance DBType (NetAddr IP) where +instance DBType (IP.NetAddr IP.IP) where typeInformation = TypeInformation { encode = - Opaleye.ConstExpr . Opaleye.StringLit . printNetAddr - , decode = Hasql.inet + Opaleye.ConstExpr . Opaleye.StringLit . IP.printNetAddr + , decode = + Decoder + { binary = Hasql.inet + , parser = parse IP.netParser + , delimiter = ',' + } , typeName = "inet" } @@ -297,3 +412,7 @@ instance Sql DBType a => DBType [a] where instance Sql DBType a => DBType (NonEmpty a) where typeInformation = nonEmptyTypeInformation nullable typeInformation + + +floating :: Floating a => A.Parser a -> A.Parser a +floating p = p <|> A.signed (1.0 / 0 <$ "Infinity") <|> 0.0 / 0 <$ "NaN" diff --git a/src/Rel8/Type/Array.hs b/src/Rel8/Type/Array.hs index 0b67d66a..f1254f74 100644 --- a/src/Rel8/Type/Array.hs +++ b/src/Rel8/Type/Array.hs @@ -11,11 +11,20 @@ module Rel8.Type.Array ) where +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + -- base -import Data.Foldable ( toList ) +import Control.Applicative ((<|>), many) +import Data.Bifunctor (first) +import Data.Foldable (fold, toList) import Data.List.NonEmpty ( NonEmpty, nonEmpty ) import Prelude hiding ( null, repeat, zipWith ) +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + -- hasql import qualified Hasql.Decoders as Hasql @@ -24,7 +33,12 @@ import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye -- rel8 import Rel8.Schema.Null ( Unnullify, Nullity( Null, NotNull ) ) +import Rel8.Type.Decoder (Decoder (..), NullableOrNot (..), Parser) import Rel8.Type.Information ( TypeInformation(..), parseTypeInformation ) +import Rel8.Type.Parser (parse) + +-- text +import qualified Data.Text as Text array :: Foldable f @@ -41,11 +55,16 @@ listTypeInformation :: () -> TypeInformation [a] listTypeInformation nullity info@TypeInformation {encode, decode} = TypeInformation - { decode = case nullity of - Null -> - Hasql.listArray (decodeArrayElement info (Hasql.nullable decode)) - NotNull -> - Hasql.listArray (decodeArrayElement info (Hasql.nonNullable decode)) + { decode = + Decoder + { binary = Hasql.listArray $ case nullity of + Null -> Hasql.nullable (decodeArrayElement info decode) + NotNull -> Hasql.nonNullable (decodeArrayElement info decode) + , parser = case nullity of + Null -> arrayParser (Nullable decode) + NotNull -> arrayParser (NonNullable decode) + , delimiter = ',' + } , encode = case nullity of Null -> Opaleye.ArrayExpr . @@ -64,9 +83,9 @@ nonEmptyTypeInformation :: () -> TypeInformation (Unnullify a) -> TypeInformation (NonEmpty a) nonEmptyTypeInformation nullity = - parseTypeInformation parse toList . listTypeInformation nullity + parseTypeInformation fromList toList . listTypeInformation nullity where - parse = maybe (Left message) Right . nonEmpty + fromList = maybe (Left message) Right . nonEmpty message = "failed to decode NonEmptyList: got empty list" @@ -78,23 +97,54 @@ isArray = \case arrayType :: TypeInformation a -> String arrayType info - | isArray info = "record" + | isArray info = "text" | otherwise = typeName info -decodeArrayElement :: TypeInformation a -> Hasql.NullableOrNot Hasql.Value x -> Hasql.NullableOrNot Hasql.Value x +decodeArrayElement :: TypeInformation a -> Decoder x -> Hasql.Value x decodeArrayElement info - | isArray info = Hasql.nonNullable . Hasql.composite . Hasql.field - | otherwise = id + | isArray info = \decoder -> + Hasql.refine (first Text.pack . parser decoder) Hasql.bytea + | otherwise = binary encodeArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr encodeArrayElement info - | isArray info = Opaleye.UnExpr (Opaleye.UnOpOther "ROW") + | isArray info = Opaleye.CastExpr "text" | otherwise = id extractArrayElement :: TypeInformation a -> Opaleye.PrimExpr -> Opaleye.PrimExpr extractArrayElement info - | isArray info = flip Opaleye.CompositeExpr "f1" + | isArray info = Opaleye.CastExpr (typeName info <> "[]") | otherwise = id + + +parseArray :: Char -> ByteString -> Either String [Maybe ByteString] +parseArray delimiter = parse $ do + A.char '{' *> A.sepBy element (A.char delimiter) <* A.char '}' + where + element = null <|> nonNull + where + null = Nothing <$ A.string "NULL" + nonNull = Just <$> (quoted <|> unquoted) + where + unquoted = A.takeWhile1 (A.notInClass (delimiter : "\"{}")) + quoted = A.char '"' *> contents <* A.char '"' + where + contents = fold <$> many (unquote <|> unescape) + where + unquote = A.takeWhile1 (A.notInClass "\"\\") + unescape = A.char '\\' *> do + BS.singleton <$> do + A.char '\\' <|> A.char '"' + + +arrayParser :: NullableOrNot Decoder a -> Parser [a] +arrayParser = \case + Nullable Decoder {parser, delimiter} -> \input -> do + elements <- parseArray delimiter input + traverse (traverse parser) elements + NonNullable Decoder {parser, delimiter} -> \input -> do + elements <- parseArray delimiter input + traverse (maybe (Left "array: unexpected null") parser) elements diff --git a/src/Rel8/Type/Composite.hs b/src/Rel8/Type/Composite.hs index 3482e0b1..eaecb82c 100644 --- a/src/Rel8/Type/Composite.hs +++ b/src/Rel8/Type/Composite.hs @@ -1,9 +1,11 @@ {-# language AllowAmbiguousTypes #-} {-# language BlockArguments #-} {-# language DataKinds #-} +{-# language DisambiguateRecordFields #-} {-# language FlexibleContexts #-} {-# language GADTs #-} {-# language NamedFieldPuns #-} +{-# language OverloadedStrings #-} {-# language ScopedTypeVariables #-} {-# language StandaloneKindSignatures #-} {-# language TypeApplications #-} @@ -18,12 +20,22 @@ module Rel8.Type.Composite ) where +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + -- base -import Data.Functor.Const ( Const( Const ), getConst ) -import Data.Functor.Identity ( Identity( Identity ) ) +import Control.Applicative ((<|>), many, optional) +import Data.Foldable (fold) +import Data.Functor.Const (Const (Const), getConst) +import Data.Functor.Identity (Identity (Identity)) import Data.Kind ( Constraint, Type ) +import Data.List (uncons) import Prelude +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + -- hasql import qualified Hasql.Decoders as Hasql @@ -45,13 +57,20 @@ import Rel8.Table.Ord ( OrdTable ) import Rel8.Table.Rel8able () import Rel8.Table.Serialize ( litHTable ) import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (Decoder), Parser) +import qualified Rel8.Type.Decoder as Decoder import Rel8.Type.Eq ( DBEq ) import Rel8.Type.Information ( TypeInformation(..) ) import Rel8.Type.Ord ( DBOrd, DBMax, DBMin ) +import Rel8.Type.Parser (parse) -- semigroupoids import Data.Functor.Apply ( WrappedApplicative(..) ) +-- transformers +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (StateT (StateT), runStateT) + -- | A deriving-via helper type for column types that store a Haskell product -- type in a single Postgres column using a Postgres composite type. @@ -68,7 +87,12 @@ newtype Composite a = Composite instance DBComposite a => DBType (Composite a) where typeInformation = TypeInformation - { decode = Hasql.composite (Composite . fromResult @_ @(HKD a Expr) <$> decoder) + { decode = + Decoder + { binary = Hasql.composite (Composite . fromResult @_ @(HKD a Expr) <$> decoder) + , parser = fmap (Composite . fromResult @_ @(HKD a Expr)) . parser + , delimiter = ',' + } , encode = encoder . litHTable . toResult @_ @(HKD a Expr) . unComposite , typeName = compositeTypeName @a } @@ -124,8 +148,8 @@ decoder = unwrapApplicative $ htabulateA \field -> case hfield hspecs field of Spec {nullity, info} -> WrapApplicative $ Identity <$> case nullity of - Null -> Hasql.field $ Hasql.nullable $ decode info - NotNull -> Hasql.field $ Hasql.nonNullable $ decode info + Null -> Hasql.field $ Hasql.nullable $ Decoder.binary $ decode info + NotNull -> Hasql.field $ Hasql.nonNullable $ Decoder.binary $ decode info encoder :: HTable t => t Expr -> Opaleye.PrimExpr @@ -133,3 +157,39 @@ encoder a = Opaleye.FunExpr "ROW" exprs where exprs = getConst $ htabulateA \field -> case hfield a field of expr -> Const [toPrimExpr expr] + + +parser :: HTable t => Parser (t Result) +parser input = do + fields <- parseRow input + (a, rest) <- runStateT go fields + case rest of + [] -> pure a + _ -> Left "composite: too many fields" + where + go = htabulateA \field -> do + mbytes <- StateT $ maybe missing pure . uncons + lift $ Identity <$> case hfield hspecs field of + Spec {nullity, info} -> case nullity of + Null -> traverse (Decoder.parser (decode info)) mbytes + NotNull -> case mbytes of + Nothing -> Left "composite: unexpected null" + Just bytes -> Decoder.parser (decode info) bytes + missing = Left "composite: missing fields" + + +parseRow :: ByteString -> Either String [Maybe ByteString] +parseRow = parse $ do + A.char '(' *> A.sepBy element (A.char ',') <* A.char ')' + where + element = optional (quoted <|> unquoted) + where + unquoted = A.takeWhile1 (A.notInClass ",\"()") + quoted = A.char '"' *> contents <* A.char '"' + where + contents = fold <$> many (unquote <|> unescape) + where + unquote = A.takeWhile1 (A.notInClass "\"\\") + unescape = A.char '\\' *> do + BS.singleton <$> do + A.char '\\' <|> A.char '"' diff --git a/src/Rel8/Type/Decoder.hs b/src/Rel8/Type/Decoder.hs new file mode 100644 index 00000000..5322e7c5 --- /dev/null +++ b/src/Rel8/Type/Decoder.hs @@ -0,0 +1,64 @@ +{-# language DerivingStrategies #-} +{-# language DeriveFunctor #-} +{-# language GADTs #-} +{-# language NamedFieldPuns #-} +{-# language StandaloneKindSignatures #-} + +module Rel8.Type.Decoder ( + Decoder (..), + NullableOrNot (..), + Parser, + parseDecoder, +) where + +-- base +import Control.Monad ((>=>)) +import Data.Bifunctor (first) +import Data.Kind (Type) +import Prelude + +-- bytestring +import Data.ByteString (ByteString) + +-- hasql +import qualified Hasql.Decoders as Hasql + +-- text +import qualified Data.Text as Text + + +type Parser :: Type -> Type +type Parser a = ByteString -> Either String a + + +type Decoder :: Type -> Type +data Decoder a = Decoder + { binary :: Hasql.Value a + -- ^ How to deserialize from PostgreSQL's binary format. + , parser :: Parser a + -- ^ How to deserialize from PostgreSQL's text format. + , delimiter :: Char + -- ^ The delimiter that is used in PostgreSQL's text format in arrays of + -- this type (this is almost always ','). + } + deriving stock (Functor) + + +-- | Apply a parser to 'Decoder'. +-- +-- This can be used if the data stored in the database should only be subset of +-- a given 'Decoder'. The parser is applied when deserializing rows +-- returned. +parseDecoder :: (a -> Either String b) -> Decoder a -> Decoder b +parseDecoder f Decoder {binary, parser, delimiter} = + Decoder + { binary = Hasql.refine (first Text.pack . f) binary + , parser = parser >=> f + , delimiter + } + + +type NullableOrNot :: (Type -> Type) -> Type -> Type +data NullableOrNot decoder a where + NonNullable :: decoder a -> NullableOrNot decoder a + Nullable :: decoder a -> NullableOrNot decoder (Maybe a) diff --git a/src/Rel8/Type/Enum.hs b/src/Rel8/Type/Enum.hs index 3324079e..27a67f10 100644 --- a/src/Rel8/Type/Enum.hs +++ b/src/Rel8/Type/Enum.hs @@ -39,12 +39,14 @@ import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye -- rel8 import Rel8.Type ( DBType, typeInformation ) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Eq ( DBEq ) import Rel8.Type.Information ( TypeInformation(..) ) import Rel8.Type.Ord ( DBOrd, DBMax, DBMin ) -- text -import Data.Text ( pack ) +import Data.Text (pack) +import Data.Text.Encoding (decodeUtf8) -- | A deriving-via helper type for column types that store an \"enum\" type @@ -66,10 +68,15 @@ newtype Enum a = Enum instance DBEnum a => DBType (Enum a) where typeInformation = TypeInformation { decode = - Hasql.enum $ - flip lookup $ - map ((pack . enumValue &&& Enum) . to) $ - genumerate @(Rep a) + let + mapping = (pack . enumValue &&& Enum) . to <$> genumerate @(Rep a) + unrecognised = Left "enum: unrecognised value" + in + Decoder + { binary = Hasql.enum (`lookup` mapping) + , parser = maybe unrecognised pure . (`lookup` mapping) . decodeUtf8 + , delimiter = ',' + } , encode = Opaleye.ConstExpr . Opaleye.StringLit . diff --git a/src/Rel8/Type/Information.hs b/src/Rel8/Type/Information.hs index 47651677..e378d1b3 100644 --- a/src/Rel8/Type/Information.hs +++ b/src/Rel8/Type/Information.hs @@ -10,18 +10,14 @@ module Rel8.Type.Information where -- base -import Data.Bifunctor ( first ) import Data.Kind ( Type ) import Prelude --- hasql -import qualified Hasql.Decoders as Hasql - -- opaleye import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye --- text -import qualified Data.Text as Text +-- rel8 +import Rel8.Type.Decoder (Decoder, parseDecoder) -- | @TypeInformation@ describes how to encode and decode a Haskell type to and @@ -31,7 +27,7 @@ type TypeInformation :: Type -> Type data TypeInformation a = TypeInformation { encode :: a -> Opaleye.PrimExpr -- ^ How to encode a single Haskell value as a SQL expression. - , decode :: Hasql.Value a + , decode :: Decoder a -- ^ How to deserialize a single result back to Haskell. , typeName :: String -- ^ The name of the SQL type. @@ -62,6 +58,6 @@ parseTypeInformation :: () parseTypeInformation to from TypeInformation {encode, decode, typeName} = TypeInformation { encode = encode . from - , decode = Hasql.refine (first Text.pack . to) decode + , decode = parseDecoder to decode , typeName } diff --git a/src/Rel8/Type/JSONBEncoded.hs b/src/Rel8/Type/JSONBEncoded.hs index 7530f0d5..dd776a3e 100644 --- a/src/Rel8/Type/JSONBEncoded.hs +++ b/src/Rel8/Type/JSONBEncoded.hs @@ -1,10 +1,13 @@ {-# language StandaloneKindSignatures #-} -module Rel8.Type.JSONBEncoded ( JSONBEncoded(..) ) where +module Rel8.Type.JSONBEncoded + ( JSONBEncoded(..) + ) +where -- aeson -import Data.Aeson ( FromJSON, ToJSON, parseJSON, toJSON ) -import Data.Aeson.Types ( parseEither ) +import Data.Aeson (FromJSON, ToJSON, eitherDecodeStrict, parseJSON, toJSON) +import Data.Aeson.Types (parseEither) -- base import Data.Bifunctor ( first ) @@ -16,6 +19,7 @@ import qualified Hasql.Decoders as Hasql -- rel8 import Rel8.Type ( DBType(..) ) +import Rel8.Type.Decoder (Decoder (..)) import Rel8.Type.Information ( TypeInformation(..) ) -- text @@ -30,6 +34,11 @@ newtype JSONBEncoded a = JSONBEncoded { fromJSONBEncoded :: a } instance (FromJSON a, ToJSON a) => DBType (JSONBEncoded a) where typeInformation = TypeInformation { encode = encode typeInformation . toJSON . fromJSONBEncoded - , decode = Hasql.refine (first pack . fmap JSONBEncoded . parseEither parseJSON) Hasql.jsonb + , decode = + Decoder + { binary = Hasql.refine (first pack . fmap JSONBEncoded . parseEither parseJSON) Hasql.jsonb + , parser = fmap JSONBEncoded . eitherDecodeStrict + , delimiter = ',' + } , typeName = "jsonb" } diff --git a/src/Rel8/Type/Parser.hs b/src/Rel8/Type/Parser.hs new file mode 100644 index 00000000..3b6fbb12 --- /dev/null +++ b/src/Rel8/Type/Parser.hs @@ -0,0 +1,26 @@ +module Rel8.Type.Parser + ( parse + ) +where + +-- attoparsec +import qualified Data.Attoparsec.ByteString as A + +-- base +import Control.Applicative ((<|>)) +import Control.Monad (unless) +import Prelude + +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS + +-- utf8-string +import qualified Data.ByteString.UTF8 as UTF8 + + +parse :: A.Parser a -> ByteString -> Either String a +parse parser = do + A.parseOnly (parser <* A.endOfInput <|> debug) + where + debug = A.takeByteString >>= fail . UTF8.toString diff --git a/src/Rel8/Type/Parser/ByteString.hs b/src/Rel8/Type/Parser/ByteString.hs new file mode 100644 index 00000000..1f2b4f64 --- /dev/null +++ b/src/Rel8/Type/Parser/ByteString.hs @@ -0,0 +1,54 @@ +{-# language OverloadedStrings #-} +{-# language TypeApplications #-} + +module Rel8.Type.Parser.ByteString + ( bytestring + ) +where + +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + +-- base +import Control.Applicative ((<|>), many) +import Control.Monad (guard) +import Data.Bits ((.|.), shiftL) +import Data.Char (isOctDigit) +import Data.Foldable (fold) +import Prelude + +-- base16 +import Data.ByteString.Base16 (decodeBase16) + +-- bytestring +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS + +-- text +import qualified Data.Text as Text + + +bytestring :: A.Parser ByteString +bytestring = hex <|> escape + where + hex = do + digits <- "\\x" *> A.takeByteString + either (fail . Text.unpack) pure $ decodeBase16 digits + escape = fold <$> many (escaped <|> unescaped) + where + unescaped = A.takeWhile1 (/= '\\') + escaped = BS.singleton <$> (backslash <|> octal) + where + backslash = '\\' <$ "\\\\" + octal = do + a <- A.char '\\' *> digit + b <- digit + c <- digit + let + result = a `shiftL` 6 .|. b `shiftL` 3 .|. c + guard $ result < 0o400 + pure $ toEnum result + where + digit = do + c <- A.satisfy isOctDigit + pure $ fromEnum c - fromEnum '0' diff --git a/src/Rel8/Type/Parser/Time.hs b/src/Rel8/Type/Parser/Time.hs new file mode 100644 index 00000000..abc2cb67 --- /dev/null +++ b/src/Rel8/Type/Parser/Time.hs @@ -0,0 +1,156 @@ +{-# language OverloadedStrings #-} +{-# language TypeApplications #-} + +module Rel8.Type.Parser.Time + ( calendarDiffTime + , day + , localTime + , timeOfDay + , utcTime + ) +where + +-- attoparsec +import qualified Data.Attoparsec.ByteString.Char8 as A + +-- base +import Control.Applicative ((<|>), optional) +import Data.Bits ((.&.)) +import Data.Bool (bool) +import Data.Fixed (Fixed (MkFixed), Pico, divMod') +import Data.Functor (void) +import Data.Int (Int64) +import Prelude + +-- bytestring +import qualified Data.ByteString as BS + +-- time +import Data.Time.Calendar (Day, addDays, fromGregorianValid) +import Data.Time.Clock (DiffTime, UTCTime (UTCTime)) +import Data.Time.Format.ISO8601 (iso8601ParseM) +import Data.Time.LocalTime + ( CalendarDiffTime (CalendarDiffTime) + , LocalTime (LocalTime) + , TimeOfDay (TimeOfDay) + , sinceMidnight + ) + +-- utf8 +import qualified Data.ByteString.UTF8 as UTF8 + + +day :: A.Parser Day +day = do + y <- A.decimal <* A.char '-' + m <- twoDigits <* A.char '-' + d <- twoDigits + maybe (fail "Day: invalid date") pure $ fromGregorianValid y m d + + +timeOfDay :: A.Parser TimeOfDay +timeOfDay = do + h <- twoDigits + m <- A.char ':' *> twoDigits + s <- A.char ':' *> secondsParser + if h < 24 && m < 60 && s <= 60 + then pure $ TimeOfDay h m s + else fail "TimeOfDay: invalid time" + + +localTime :: A.Parser LocalTime +localTime = LocalTime <$> day <* separator <*> timeOfDay + where + separator = A.char ' ' <|> A.char 'T' + + +utcTime :: A.Parser UTCTime +utcTime = do + LocalTime date time <- localTime + tz <- timeZone + let + (days, time') = (sinceMidnight time + tz) `divMod'` oneDay + where + oneDay = 24 * 60 * 60 + date' = addDays days date + pure $ UTCTime date' time' + + +calendarDiffTime :: A.Parser CalendarDiffTime +calendarDiffTime = iso8601 <|> postgres + where + iso8601 = A.takeByteString >>= iso8601ParseM . UTF8.toString + at = optional (A.char '@') *> A.skipSpace + plural unit = A.skipSpace <* (unit <* optional "s") <* A.skipSpace + parseMonths = sql <|> postgresql + where + sql = A.signed $ do + years <- A.decimal <* A.char '-' + months <- A.decimal <* A.skipSpace + pure $ years * 12 + months + postgresql = do + at + years <- A.signed A.decimal <* plural "year" <|> pure 0 + months <- A.signed A.decimal <* plural "mon" <|> pure 0 + pure $ years * 12 + months + parseTime = (+) <$> parseDays <*> time + where + time = realToFrac <$> (sql <|> postgresql) + where + sql = A.signed $ do + h <- A.signed A.decimal <* A.char ':' + m <- twoDigits <* A.char ':' + s <- secondsParser + pure $ fromIntegral (((h * 60) + m) * 60) + s + postgresql = do + h <- A.signed A.decimal <* plural "hour" <|> pure 0 + m <- A.signed A.decimal <* plural "min" <|> pure 0 + s <- secondsParser <* plural "sec" <|> pure 0 + pure $ fromIntegral @Int (((h * 60) + m) * 60) + s + parseDays = do + days <- A.signed A.decimal <* (plural "days" <|> skipSpace1) <|> pure 0 + pure $ fromIntegral @Int days * 24 * 60 * 60 + postgres = do + months <- parseMonths + time <- parseTime + ago <- (True <$ (A.skipSpace *> "ago")) <|> pure False + pure $ CalendarDiffTime (bool id negate ago months) (bool id negate ago time) + + +secondsParser :: A.Parser Pico +secondsParser = do + integral <- twoDigits + mfractional <- optional (A.char '.' *> A.takeWhile1 A.isDigit) + pure $ case mfractional of + Nothing -> fromIntegral integral + Just fractional -> parseFraction (fromIntegral integral) fractional + where + parseFraction integral digits = MkFixed (fromIntegral (n * 10 ^ e)) + where + e = max 0 (12 - BS.length digits) + n = BS.foldl' go (integral :: Int64) (BS.take 12 digits) + where + go acc digit = 10 * acc + fromIntegral (fromEnum digit .&. 0xf) + + +twoDigits :: A.Parser Int +twoDigits = do + u <- A.digit + l <- A.digit + pure $ fromEnum u .&. 0xf * 10 + fromEnum l .&. 0xf + + +timeZone :: A.Parser DiffTime +timeZone = 0 <$ A.char 'Z' <|> diffTime + + +diffTime :: A.Parser DiffTime +diffTime = A.signed $ do + h <- twoDigits + m <- A.char ':' *> twoDigits <|> pure 0 + s <- A.char ':' *> secondsParser <|> pure 0 + pure $ sinceMidnight $ TimeOfDay h m s + + +skipSpace1 :: A.Parser () +skipSpace1 = void $ A.takeWhile1 A.isSpace diff --git a/tests/Main.hs b/tests/Main.hs index e4b43491..85a3e994 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -2,7 +2,7 @@ {-# language BlockArguments #-} {-# language DeriveAnyClass #-} {-# language DeriveGeneric #-} -{-# language DerivingStrategies #-} +{-# language DerivingVia #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} {-# language MonoLocalBinds #-} @@ -23,6 +23,7 @@ import Control.Applicative ( empty, liftA2, liftA3 ) import Control.Exception ( bracket, throwIO ) import Control.Monad ( (>=>), void ) import Data.Bifunctor ( bimap ) +import Data.Fixed (Fixed (MkFixed)) import Data.Foldable ( for_ ) import Data.Int ( Int32, Int64 ) import Data.List ( nub, sort ) @@ -57,6 +58,10 @@ import qualified Hedgehog.Range as Range -- mmorph import Control.Monad.Morph ( hoist ) +-- network-ip +import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) +import Data.DoubleWord (Word128(..)) + -- rel8 import Rel8 ( Result ) import qualified Rel8 @@ -87,9 +92,6 @@ import qualified Database.Postgres.Temp as TmpPostgres -- uuid import qualified Data.UUID --- ip -import Network.IP.Addr (NetAddr, IP, IP4(..), IP6(..), IP46(..), net4Addr, net6Addr, fromNetAddr46, Net4Addr, Net6Addr) -import Data.DoubleWord (Word128(..)) main :: IO () main = defaultMain tests @@ -143,6 +145,7 @@ tests = sql "CREATE TABLE test_table ( column1 text not null, column2 bool not null )" sql "CREATE TABLE unique_table ( \"key\" text not null unique, \"value\" text not null )" sql "CREATE SEQUENCE test_seq" + sql "CREATE TYPE composite AS (\"bool\" bool, \"char\" char, \"array\" int4[])" return db @@ -409,21 +412,37 @@ testAp = databasePropertyTest "Cartesian product (<*>)" \transaction -> do sort result === sort (liftA2 (,) rows1 rows2) +data Composite = Composite + { bool :: !Bool + , char :: !Char + , array :: ![Int32] + } + deriving (Eq, Show, Generic) + deriving (Rel8.DBType) via Rel8.Composite Composite + + +instance Rel8.DBComposite Composite where + compositeTypeName = "composite" + compositeFields = Rel8.namesFromLabels + + testDBType :: IO TmpPostgres.DB -> TestTree testDBType getTestDatabase = testGroup "DBType instances" [ dbTypeTest "Bool" Gen.bool , dbTypeTest "ByteString" $ Gen.bytes (Range.linear 0 128) + , dbTypeTest "CalendarDiffTime" genCalendarDiffTime , dbTypeTest "CI Lazy Text" $ mk . Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "CI Text" $ mk <$> Gen.text (Range.linear 0 10) Gen.unicode + , dbTypeTest "Composite" genComposite , dbTypeTest "Day" genDay - , dbTypeTest "Double" $ (/10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) - , dbTypeTest "Float" $ (/10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Double" $ (/ 10) . fromIntegral @Int @Double <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Float" $ (/ 10) . fromIntegral @Int @Float <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Int32" $ Gen.integral @_ @Int32 Range.linearBounded , dbTypeTest "Int64" $ Gen.integral @_ @Int64 Range.linearBounded , dbTypeTest "Lazy ByteString" $ Data.ByteString.Lazy.fromStrict <$> Gen.bytes (Range.linear 0 128) , dbTypeTest "Lazy Text" $ Data.Text.Lazy.fromStrict <$> Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "LocalTime" genLocalTime - , dbTypeTest "Scientific" $ (/10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) + , dbTypeTest "Scientific" $ (/ 10) . fromIntegral @Int @Scientific <$> Gen.integral (Range.linear (-100) 100) , dbTypeTest "Text" $ Gen.text (Range.linear 0 10) Gen.unicode , dbTypeTest "TimeOfDay" genTimeOfDay , dbTypeTest "UTCTime" $ UTCTime <$> genDay <*> genDiffTime @@ -432,24 +451,51 @@ testDBType getTestDatabase = testGroup "DBType instances" ] where - dbTypeTest :: (Eq a, Show a, Rel8.DBType a) => TestName -> Gen a -> TestTree + dbTypeTest :: (Eq a, Show a, Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => TestName -> Gen a -> TestTree dbTypeTest name generator = testGroup name [ databasePropertyTest name (t generator) getTestDatabase , databasePropertyTest ("Maybe " <> name) (t (Gen.maybe generator)) getTestDatabase ] - t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a) + t :: forall a b. (Eq a, Show a, Rel8.Sql Rel8.DBType a, Rel8.ToExprs (Rel8.Expr a) a) => Gen a -> (TestT Transaction () -> PropertyT IO b) -> PropertyT IO b t generator transaction = do x <- forAll generator + y <- forAll generator transaction do [res] <- lift do statement () $ Rel8.select do pure (Rel8.litExpr x) diff res (==) x + [res'] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res' (==) [[x, y]] + [res3] <- lift do + statement () $ Rel8.select $ Rel8.many $ Rel8.many $ Rel8.many do + Rel8.values [Rel8.litExpr x, Rel8.litExpr y] + diff res3 (==) [[[x, y]]] + res'' <- lift do + statement () $ Rel8.select do + xs <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]) + Rel8.catListTable xs + diff res'' (==) [x, y] + res''' <- lift do + statement () $ Rel8.select do + xss <- Rel8.catListTable (Rel8.listTable [Rel8.listTable [Rel8.listTable [Rel8.litExpr x, Rel8.litExpr y]]]) + xs <- Rel8.catListTable xss + Rel8.catListTable xs + diff res''' (==) [x, y] + + genComposite :: Gen Composite + genComposite = do + bool <- Gen.bool + char <- Gen.unicode + array <- Gen.list (Range.linear 0 10) (Gen.int32 (Range.linear (-10000) 10000)) + pure Composite {..} genDay :: Gen Day genDay = do @@ -458,6 +504,14 @@ testDBType getTestDatabase = testGroup "DBType instances" day <- Gen.integral (Range.linear 1 31) Gen.just $ pure $ fromGregorianValid year month day + genCalendarDiffTime :: Gen CalendarDiffTime + genCalendarDiffTime = do + -- hardcoded to 0 because Hasql's 'interval' decoder needs to return a + -- CalendarDiffTime for this to be properly round-trippable + months <- pure 0 -- Gen.integral (Range.linear 0 120) + diffTime <- secondsToNominalDiffTime . MkFixed . (* 1000000) <$> Gen.integral (Range.linear 0 2147483647999999) + pure $ CalendarDiffTime months diffTime + genDiffTime :: Gen DiffTime genDiffTime = secondsToDiffTime <$> Gen.integral (Range.linear 0 86401)