diff --git a/src/Database/PostgreSQL/Simple.hs b/src/Database/PostgreSQL/Simple.hs index 6ecc86d..9c066ae 100644 --- a/src/Database/PostgreSQL/Simple.hs +++ b/src/Database/PostgreSQL/Simple.hs @@ -103,6 +103,9 @@ module Database.PostgreSQL.Simple , forEachWith , forEachWith_ , returningWith + -- ** Streaming with single row mode + , foldSingleRowModeWith + , foldSingleRowModeWith_ -- * Statements that do not return results , execute , execute_ @@ -123,6 +126,7 @@ module Database.PostgreSQL.Simple import Data.ByteString.Builder (Builder, byteString, char8) import Control.Applicative ((<$>)) import Control.Exception as E +import Control.Monad (unless) import Data.ByteString (ByteString) import Data.Int (Int64) import Data.List (intersperse) @@ -508,6 +512,57 @@ foldWithOptionsAndParser opts parser conn template qs a f = do q <- formatQuery conn template qs doFold opts parser conn template (Query q) a f +-- | Perform a @SELECT@ or other SQL query that is expected to return +-- results. Results are streamed incrementally from the server, and +-- consumed via a left fold. +-- +-- This fold is /not/ strict. The stream consumer is responsible for +-- forcing the evaluation of its result to avoid space leaks. +-- +-- Unlike 'fold' and friends, this is implemented using +-- +-- instead of a cursor. +-- You cannot execute other queries while streaming is in progress. +foldSingleRowModeWith :: (ToRow params) => RowParser row -> Connection -> Query -> params -> a -> (a -> row -> IO a) -> IO a +foldSingleRowModeWith parser conn template qs a0 f = do + q <- formatQuery conn template qs + doFoldSingleRow parser conn q a0 f + +-- | A version of 'foldSingleRowModeWith' that does not perform query substitution. +foldSingleRowModeWith_ :: RowParser row -> Connection -> Query -> a -> (a -> row -> IO a) -> IO a +foldSingleRowModeWith_ parser conn (Query q) a0 f = + doFoldSingleRow parser conn q a0 f + +doFoldSingleRow :: RowParser row -> Connection -> ByteString -> a -> (a -> row -> IO a) -> IO a +doFoldSingleRow parser conn q a0 f = do + queryOk <- withConnection conn $ \h -> PQ.sendQuery h q + unless queryOk $ do + mmsg <- withConnection conn PQ.errorMessage + throwIO $ QueryError (maybe "" B.unpack mmsg) (Query q) + srmOk <- withConnection conn PQ.setSingleRowMode + unless srmOk $ + throwIO $ fatalError "could not activate single row mode" + loop a0 `finally` withConnection conn consumeResults + where + loop a = do + mresult <- withConnection conn PQ.getResult + case mresult of + Nothing -> pure a + Just result -> do + status <- PQ.resultStatus result + case status of + PQ.SingleTuple -> do + ncols <- PQ.nfields result + row <- getRowWith parser 0 ncols conn result + a' <- f a row + loop a' + PQ.TuplesOk -> do + nrows <- PQ.ntuples result + if nrows == 0 + then pure a + else throwResultError "doFoldSingleRow" result status + _ -> throwResultError "doFoldSingleRow" result status + -- | A version of 'fold' that does not perform query substitution. fold_ :: (FromRow r) => Connection diff --git a/src/Database/PostgreSQL/Simple/Copy.hs b/src/Database/PostgreSQL/Simple/Copy.hs index 4662b8e..1f852ad 100644 --- a/src/Database/PostgreSQL/Simple/Copy.hs +++ b/src/Database/PostgreSQL/Simple/Copy.hs @@ -261,9 +261,3 @@ getCopyCommandTag funcName pqconn = do errCmdStatusFmt = B.unpack funcName ++ ": failed to parse command status" -consumeResults :: PQ.Connection -> IO () -consumeResults pqconn = do - mres <- PQ.getResult pqconn - case mres of - Nothing -> return () - Just _ -> consumeResults pqconn diff --git a/src/Database/PostgreSQL/Simple/FromField.hs b/src/Database/PostgreSQL/Simple/FromField.hs index feab6c7..f7387f7 100644 --- a/src/Database/PostgreSQL/Simple/FromField.hs +++ b/src/Database/PostgreSQL/Simple/FromField.hs @@ -221,7 +221,11 @@ class FromField a where -- finally query the database's meta-schema. typename :: Field -> Conversion ByteString -typename field = typname <$> typeInfo field +typename field = Conversion $ \conn -> do + status <- PQ.resultStatus (result field) + case status of + PQ.SingleTuple -> pure (Ok "unknown type") + _ -> runConversion (typname <$> typeInfo field) conn typeInfo :: Field -> Conversion TypeInfo typeInfo Field{..} = Conversion $ \conn -> do diff --git a/src/Database/PostgreSQL/Simple/Internal.hs b/src/Database/PostgreSQL/Simple/Internal.hs index b7adad4..91095b8 100644 --- a/src/Database/PostgreSQL/Simple/Internal.hs +++ b/src/Database/PostgreSQL/Simple/Internal.hs @@ -635,3 +635,10 @@ breakOnSingleQuestionMark b = go (B8.empty, b) go2 ('?', t2) = go (noQ `B8.snoc` '?',t2) -- Anything else means go2 _ = tup + +consumeResults :: PQ.Connection -> IO () +consumeResults pqconn = do + mres <- PQ.getResult pqconn + case mres of + Nothing -> return () + Just _ -> consumeResults pqconn diff --git a/test/Main.hs b/test/Main.hs index 32eb230..4f92f20 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -14,6 +14,7 @@ import Common import Database.PostgreSQL.Simple.Copy import Database.PostgreSQL.Simple.ToField (ToField) import Database.PostgreSQL.Simple.FromField (FromField) +import Database.PostgreSQL.Simple.FromRow (FromRow(..)) import Database.PostgreSQL.Simple.HStore import Database.PostgreSQL.Simple.Newtypes import Database.PostgreSQL.Simple.Internal (breakOnSingleQuestionMark) @@ -63,6 +64,7 @@ tests env = testGroup "tests" [ testBytea , testCase "ExecuteMany" . testExecuteMany , testCase "Fold" . testFold + , testCase "FoldSingleRow" . testFoldSingleRow , testCase "Notify" . testNotify , testCase "Serializable" . testSerializable , testCase "Time" . testTime @@ -185,6 +187,53 @@ testFold TestEnv{..} = do return () +testFoldSingleRow :: TestEnv -> Assertion +testFoldSingleRow TestEnv{..} = do + xs <- foldSingleRowModeWith_ fromRow conn "SELECT 1 WHERE FALSE" + [] $ \xs (Only x) -> return (x:xs) + xs @?= ([] :: [Int]) + + xs <- foldSingleRowModeWith_ fromRow conn "SELECT generate_series(1,10000)" + [] $ \xs (Only x) -> return (x:xs) + reverse xs @?= ([1..10000] :: [Int]) + + ref <- newIORef [] + foldSingleRowModeWith fromRow conn "SELECT * FROM generate_series(1,?) a, generate_series(1,?) b" + (100 :: Int, 50 :: Int) () $ \() (a :: Int, b :: Int)-> do + xs <- readIORef ref + writeIORef ref $! (a,b):xs + xs <- readIORef ref + reverse xs @?= [(a,b) | b <- [1..50], a <- [1..100]] + + -- Make sure it propagates our exception + ref <- newIORef [] + True <- expectError (== TestException) $ + foldSingleRowModeWith_ fromRow conn "SELECT generate_series(1,10)" () $ \() (Only a) -> + if a == 5 then do + throwIO TestException + else do + xs <- readIORef ref + writeIORef ref $! (a :: Int) : xs + xs <- readIORef ref + reverse xs @?= [1..4] + -- and didn't leave the connection in a bad state. + xs <- foldSingleRowModeWith_ fromRow conn "SELECT 1" + [] $ \xs (Only x) -> return (x:xs) + xs @?= ([1] :: [Int]) + + -- When in single row mode, we cannot make additional queries while + -- handling errors. We still want to emit vaguely sensible errors when + -- given the wrong parser, not "another command is already in progress". + execute_ conn "DROP TYPE IF EXISTS foo; CREATE TYPE foo AS ENUM ('foo', 'bar');" + expectError (\e -> case e of + Incompatible { errSQLField = "foo" + , errMessage = "types incompatible" } -> True + _ -> False) $ + foldSingleRowModeWith_ fromRow conn "SELECT 'foo'::foo" () $ + \() (Only x) -> print (x :: Int) + + return () + queryFailure :: forall a. (FromField a, Typeable a, Show a) => Connection -> Query -> a -> Assertion queryFailure conn q resultType = do