Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #69 #71

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions postgresql-simple.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ test-suite test

build-depends:
aeson
, async
, base
, base16-bytestring
, bytestring
Expand Down
6 changes: 3 additions & 3 deletions src/Database/PostgreSQL/Simple/Copy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ doCopy funcName conn template q = do
#if MIN_VERSION_postgresql_libpq(0,9,2)
PQ.SingleTuple -> errMsg "single-row mode is not supported"
#endif
PQ.BadResponse -> throwResultError funcName result status
PQ.NonfatalError -> throwResultError funcName result status
PQ.FatalError -> throwResultError funcName result status
PQ.BadResponse -> throwResultError funcName conn result status
PQ.NonfatalError -> throwResultError funcName conn result status
PQ.FatalError -> throwResultError funcName conn result status

data CopyOutResult
= CopyOutRow !B.ByteString -- ^ Data representing either exactly
Expand Down
2 changes: 1 addition & 1 deletion src/Database/PostgreSQL/Simple/Cursor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ foldForwardWithParser (Cursor name conn) parser chunkSize f a0 = do
Right <$> foldM' inner a0 0 (nrows - 1)
else
return $ Left a0
_ -> throwResultError "foldForwardWithParser" result status
_ -> throwResultError "foldForwardWithParser" conn result status

-- | Fold over a chunk of rows, calling the supplied fold-like function
-- on each row as it is received. In case the cursor is exhausted,
Expand Down
109 changes: 71 additions & 38 deletions src/Database/PostgreSQL/Simple/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP, BangPatterns, DoAndIfThenElse, RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable, DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}

------------------------------------------------------------------------------
-- |
Expand All @@ -24,7 +25,7 @@ module Database.PostgreSQL.Simple.Internal where
import Control.Applicative
import Control.Exception
import Control.Concurrent.MVar
import Control.Monad(MonadPlus(..))
import Control.Monad(MonadPlus(..), when)
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
Expand Down Expand Up @@ -329,40 +330,70 @@ exec conn sql =
withConnection conn $ \h -> do
success <- PQ.sendQuery h sql
if success
then awaitResult h Nothing
then do
mfd <- PQ.socket h
case mfd of
Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
Just socket ->
-- Here we assume any exceptions are asynchronous, or that
-- they are not from libpq. If an error happens in libpq
-- (e.g. the query being canceled or session terminated),
-- libpq will not throw, but will instead return a Result
-- indicating an error
uninterruptibleMask $ \restore ->
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might get an async exception after sendQuery but before uninterruptibleMask. Probably the whole do block needs mask around it with restore around sendQuery. Sorry if I'm overlooking something here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point, thanks!

restore (consumeUntilNotBusy h socket >> getResult h Nothing)
`onException` cancelAndClear h socket
else throwLibPQError h "PQsendQuery failed"
where
awaitResult h mres = do
mfd <- PQ.socket h
case mfd of
Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
Just fd -> do
threadWaitRead fd
_ <- PQ.consumeInput h -- FIXME?
getResult h mres
cancelAndClear h socket = do
mcncl <- PQ.getCancel h
case mcncl of
Nothing -> pure ()
Just cncl -> do
cancelStatus <- PQ.cancel cncl
case cancelStatus of
Left _ -> PQ.errorMessage h >>= \mmsg -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.cancelAndClear: " <> fromMaybe "Unknown error" mmsg)
Right () -> do
consumeUntilNotBusy h socket
waitForNullResult h

waitForNullResult h = do
mres <- PQ.getResult h
case mres of
Nothing -> pure ()
Just _ -> waitForNullResult h

-- | Waits until results are ready to be fetched.
consumeUntilNotBusy h socket = do
-- According to https://www.postgresql.org/docs/current/libpq-async.html :
-- 1. The isBusy status only changes by calling PQConsumeInput
-- 2. In case of errors, "PQgetResult should be called until it returns a null pointer, to allow libpq to process the error information completely"
-- 3. Also, "A typical application using these functions will have a main loop that uses select() or poll() ... When the main loop detects input ready, it should call PQconsumeInput to read the input. It can then call PQisBusy, followed by PQgetResult if PQisBusy returns false (0)"
busy <- PQ.isBusy h
when busy $ do
threadWaitRead socket
someError <- not <$> PQ.consumeInput h
when someError $ PQ.errorMessage h >>= \mmsg -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.consumeUntilNotBusy: " <> fromMaybe "Unknown error" mmsg)
consumeUntilNotBusy h socket

getResult h mres = do
isBusy <- PQ.isBusy h
if isBusy
then awaitResult h mres
else do
mres' <- PQ.getResult h
case mres' of
Nothing -> case mres of
Nothing -> throwLibPQError h "PQgetResult returned no results"
Just res -> return res
Just res -> do
status <- PQ.resultStatus res
case status of
-- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
PQ.EmptyQuery -> getResult h mres'
PQ.CommandOk -> getResult h mres'
PQ.TuplesOk -> getResult h mres'
PQ.CopyOut -> return res
PQ.CopyIn -> return res
PQ.BadResponse -> getResult h mres'
PQ.NonfatalError -> getResult h mres'
PQ.FatalError -> getResult h mres'
mres' <- PQ.getResult h
case mres' of
Nothing -> case mres of
Nothing -> throwLibPQError h "PQgetResult returned no results"
Just res -> return res
Just res -> do
status <- PQ.resultStatus res
case status of
-- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
PQ.EmptyQuery -> getResult h mres'
PQ.CommandOk -> getResult h mres'
PQ.TuplesOk -> getResult h mres'
PQ.CopyOut -> return res
PQ.CopyIn -> return res
PQ.BadResponse -> getResult h mres'
PQ.NonfatalError -> getResult h mres'
PQ.FatalError -> getResult h mres'
#endif

-- | A version of 'execute' that does not perform query substitution.
Expand All @@ -372,7 +403,7 @@ execute_ conn q@(Query stmt) = do
finishExecute conn q result

finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
finishExecute _conn q result = do
finishExecute conn q result = do
status <- PQ.resultStatus result
case status of
-- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
Expand All @@ -395,9 +426,9 @@ finishExecute _conn q result = do
throwIO $ QueryError "execute: COPY TO is not supported" q
PQ.CopyIn ->
throwIO $ QueryError "execute: COPY FROM is not supported" q
PQ.BadResponse -> throwResultError "execute" result status
PQ.NonfatalError -> throwResultError "execute" result status
PQ.FatalError -> throwResultError "execute" result status
PQ.BadResponse -> throwResultError "execute" conn result status
PQ.NonfatalError -> throwResultError "execute" conn result status
PQ.FatalError -> throwResultError "execute" conn result status
where
mkInteger str = B8.foldl' delta 0 str
where
Expand All @@ -406,9 +437,11 @@ finishExecute _conn q result = do
then 10 * acc + fromIntegral (ord c - ord '0')
else error ("finishExecute: not an int: " ++ B8.unpack str)

throwResultError :: ByteString -> PQ.Result -> PQ.ExecStatus -> IO a
throwResultError _ result status = do
errormsg <- fromMaybe "" <$>
throwResultError :: ByteString -> Connection -> PQ.Result -> PQ.ExecStatus -> IO a
throwResultError _ conn result status = do
-- Some errors only exist in "errorMessage"
mConnectionError <- withConnection conn PQ.errorMessage
errormsg <- fromMaybe "" . (mConnectionError <|>) <$>
PQ.resultErrorField result PQ.DiagMessagePrimary
detail <- fromMaybe "" <$>
PQ.resultErrorField result PQ.DiagMessageDetail
Expand Down
16 changes: 8 additions & 8 deletions src/Database/PostgreSQL/Simple/Internal/PQResultUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict

finishQueryWith :: RowParser r -> Connection -> Query -> PQ.Result -> IO [r]
finishQueryWith parser conn q result = finishQueryWith' q result $ do
finishQueryWith parser conn q result = finishQueryWith' conn q result $ do
nrows <- PQ.ntuples result
ncols <- PQ.nfields result
forM' 0 (nrows-1) $ \row ->
getRowWith parser row ncols conn result

finishQueryWithV :: RowParser r -> Connection -> Query -> PQ.Result -> IO (V.Vector r)
finishQueryWithV parser conn q result = finishQueryWith' q result $ do
finishQueryWithV parser conn q result = finishQueryWith' conn q result $ do
nrows <- PQ.ntuples result
let PQ.Row nrows' = nrows
ncols <- PQ.nfields result
Expand All @@ -56,7 +56,7 @@ finishQueryWithV parser conn q result = finishQueryWith' q result $ do
V.unsafeFreeze mv

finishQueryWithVU :: VU.Unbox r => RowParser r -> Connection -> Query -> PQ.Result -> IO (VU.Vector r)
finishQueryWithVU parser conn q result = finishQueryWith' q result $ do
finishQueryWithVU parser conn q result = finishQueryWith' conn q result $ do
nrows <- PQ.ntuples result
let PQ.Row nrows' = nrows
ncols <- PQ.nfields result
Expand All @@ -67,8 +67,8 @@ finishQueryWithVU parser conn q result = finishQueryWith' q result $ do
MVU.unsafeWrite mv (fromIntegral row') value
VU.unsafeFreeze mv

finishQueryWith' :: Query -> PQ.Result -> IO a -> IO a
finishQueryWith' q result k = do
finishQueryWith' :: Connection -> Query -> PQ.Result -> IO a -> IO a
finishQueryWith' conn q result k = do
status <- PQ.resultStatus result
case status of
PQ.TuplesOk -> k
Expand All @@ -82,9 +82,9 @@ finishQueryWith' q result k = do
#if MIN_VERSION_postgresql_libpq(0,9,2)
PQ.SingleTuple -> queryErr "query: single-row mode is not supported"
#endif
PQ.BadResponse -> throwResultError "query" result status
PQ.NonfatalError -> throwResultError "query" result status
PQ.FatalError -> throwResultError "query" result status
PQ.BadResponse -> throwResultError "query" conn result status
PQ.NonfatalError -> throwResultError "query" conn result status
PQ.FatalError -> throwResultError "query" conn result status
where
queryErr msg = throwIO $ QueryError msg q

Expand Down
103 changes: 102 additions & 1 deletion test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import Database.PostgreSQL.Simple.Types(Query(..),Values(..), PGArray(..))
import qualified Database.PostgreSQL.Simple.Transaction as ST

import Control.Applicative
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (withAsync, wait)
import Control.Exception as E
import Control.Monad
import Data.Char
import Data.Foldable (toList)
import Data.List (concat, sort)
import Data.List (concat, sort, isInfixOf)
import Data.IORef
import Data.Monoid ((<>))
import Data.String (fromString)
Expand All @@ -48,6 +50,7 @@ import System.FilePath
import System.Timeout(timeout)
import Data.Time.Compat (getCurrentTime, diffUTCTime)
import System.Environment (getEnvironment)
import qualified System.IO as IO

import Test.Tasty
import Test.Tasty.Golden
Expand Down Expand Up @@ -84,6 +87,10 @@ tests env = testGroup "tests"
, testCase "2-ary generic" . testGeneric2
, testCase "3-ary generic" . testGeneric3
, testCase "Timeout" . testTimeout
, testCase "Expected user exceptions" . testExpectedExceptions
, testCase "Async exceptions" . testAsyncExceptionFailure
, testCase "Query canceled" . testCanceledQueryExceptions
, testCase "Connection terminated" . testConnectionTerminated
]

testBytea :: TestEnv -> TestTree
Expand Down Expand Up @@ -533,6 +540,98 @@ testDouble TestEnv{..} = do
[Only (x :: Double)] <- query_ conn "SELECT '-Infinity'::float8"
x @?= (-1 / 0)

-- | Specifies exceptions thrown by postgresql-simple for certain user errors.
testExpectedExceptions :: TestEnv -> Assertion
testExpectedExceptions TestEnv{..} = do
withConn $ \c -> do
execute_ c "SELECT 1,2" `shouldThrow` (\(e :: QueryError) -> "2-column result" `isInfixOf` show e)
execute_ c "SELECT 1/0" `shouldThrow` (\(e :: SqlError) -> sqlState e == "22012")
(query_ c "SELECT 1, 2, 3" :: IO [(String, Int)]) `shouldThrow` (\(e :: ResultError) -> errSQLType e == "int4" && errHaskellType e == "Text")

shouldThrow :: forall e a. Exception e => IO a -> (e -> Bool) -> IO ()
shouldThrow f pred = do
ea <- try f
assertBool "Exception is as expected" $ case ea of
Right _ -> False
Left (ex :: e) -> pred ex

-- | Ensures that asynchronous exceptions thrown while queries are executing
-- are handled properly.
testAsyncExceptionFailure :: TestEnv -> Assertion
testAsyncExceptionFailure TestEnv{..} = withConn $ \c -> do
-- We need to give it enough time to start executing the query
-- before timing out. One second should be more than enough
execute_ c "SET my.setting TO '42'"
testAsyncException c (1000 * 1000) (execute_ c "SELECT pg_sleep(60)")
testAsyncException c (1000 * 1000) $
bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
putCopyData c "1\n"
threadDelay (1000 * 1000 * 60)

where
testAsyncException c timeLimit f = do
tmt <- timeout timeLimit f
tmt @?= Nothing
-- Any other query should work now without errors.
number42 <- query_ c "SELECT current_setting('my.setting')"
number42 @?= [ Only ("42" :: String) ]

-- | Ensures that canceled queries don't invalidate the Connection and specifies how
-- they can be detected.
testCanceledQueryExceptions :: TestEnv -> Assertion
testCanceledQueryExceptions TestEnv{..} = do
withConn $ \c1 -> withConn $ \c2 -> do
[ Only (c1Pid :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
execute_ c1 "SET my.setting TO '42'"

testCancelation c1 c2 c1Pid execPgSleep $ \(ex :: SqlError) -> sqlState ex == "57014"

-- What should we expect when COPY is canceled and putCopyEnd runs? The same SqlError as above, perhaps? Right now,
-- detecting if a query was canceled involves detecting two distinct types of exception.
testCancelation c1 c2 c1Pid execCopy $ \(ex :: IOException) -> "Database.PostgreSQL.Simple.Copy.putCopyEnd: failed to parse command status" `isInfixOf` show ex
&& "ERROR: canceling statement due to user request" `isInfixOf` show ex

-- Any other query should work now without errors.
number42 <- query_ c1 "SELECT current_setting('my.setting')"
number42 @?= [ Only ("42" :: String) ]

where
execPgSleep c = execute_ c "SELECT pg_sleep(60)"
execCopy c =
bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
putCopyData c "1\n"
threadDelay (1000 * 1000 * 2)
-- putCopyEnd will run after pg_cancel_backend due to threadDelays
testCancelation c1 c2 cPid f exPred = withAsync (f c1) $ \longRunningAction -> do
-- We need to give it enough time to start executing the query
-- before canceling it. One second should be more than enough
threadDelay (1000 * 1000)
cancelResult <- query c2 "SELECT pg_cancel_backend(?)" (Only cPid)
cancelResult @?= [ Only True ]
wait longRunningAction `shouldThrow` exPred
-- Connection is still usable after query canceled
[ Only (cPidAgain :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
cPid @?= cPidAgain

-- | Ensures that a specific type of exception is thrown when
-- the connection is terminated abruptly.
testConnectionTerminated :: TestEnv -> Assertion
testConnectionTerminated TestEnv{..} = do
withConn $ \c1 -> withConn $ \c2 -> do
[ Only (c1Pid :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
withAsync (execute_ c1 "SELECT pg_sleep(60)") $ \pgSleep -> do
-- We need to give it enough time to start executing the query
-- before terminating it. One second should be more than enough
threadDelay (1000 * 1000)
cancelResult <- query c2 "SELECT pg_terminate_backend(?)" (Only c1Pid)
cancelResult @?= [ Only True ]
killedQuery <- try $ wait pgSleep
assertBool "Connection was terminated" $ case killedQuery of
Right _ -> False
Left (ex :: SqlError) -> ("server closed the connection unexpectedly" `isInfixOf` show (sqlErrorMsg ex))
&& sqlExecStatus ex == FatalError

testGeneric1 :: TestEnv -> Assertion
testGeneric1 TestEnv{..} = do
Expand Down Expand Up @@ -620,6 +719,8 @@ withTestEnv connstr cb =

main :: IO ()
main = withConnstring $ \connstring -> do
IO.hSetBuffering IO.stdout IO.NoBuffering
IO.hSetBuffering IO.stderr IO.NoBuffering
withTestEnv connstring (defaultMain . tests)

withConnstring :: (BS8.ByteString -> IO ()) -> IO ()
Expand Down