Skip to content

Commit

Permalink
Tests pass with conduits
Browse files Browse the repository at this point in the history
  • Loading branch information
snoyberg committed Dec 29, 2011
1 parent f2afff6 commit a56b1ee
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 167 deletions.
4 changes: 2 additions & 2 deletions package-list.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ pkgs=( ./pool
./persistent
./persistent-template
./persistent-sqlite
./persistent-postgresql
./persistent-mongoDB )
./persistent-postgresql )
#./persistent-mongoDB )

88 changes: 37 additions & 51 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,8 @@ import Data.Either (partitionEithers)
import Control.Arrow
import Data.List (sort, groupBy)
import Data.Function (on)
#if MIN_VERSION_monad_control(0, 3, 0)
import Control.Monad.Trans.Control (MonadBaseControl)
#define MBCIO MonadBaseControl IO
#else
import Control.Monad.IO.Control (MonadControlIO)
#define MBCIO MonadControlIO
#endif
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL

import Data.ByteString (ByteString)
import qualified Data.Text as T
Expand All @@ -43,16 +38,15 @@ import qualified Data.Text.Encoding.Error as T
import Data.Time.LocalTime (localTimeToUTC, utc)
import Data.Text (Text, pack, unpack)
import Data.Aeson
import Control.Monad (forM)
import Data.Neither (meither, MEither (..))
import Control.Monad (forM, mzero)

withPostgresqlPool :: (MBCIO m, MonadIO m)
withPostgresqlPool :: C.ResourceIO m
=> T.Text
-> Int -- ^ number of connections to open
-> (ConnectionPool -> m a) -> m a
withPostgresqlPool s = withSqlPool $ open' s

withPostgresqlConn :: (MBCIO m, MonadIO m) => T.Text -> (Connection -> m a) -> m a
withPostgresqlConn :: C.ResourceIO m => T.Text -> (Connection -> m a) -> m a
withPostgresqlConn = withSqlConn . open'

open' :: T.Text -> IO Connection
Expand Down Expand Up @@ -99,14 +93,18 @@ execute' stmt vals = do
return ()

withStmt'
:: (MBCIO m, MonadIO m)
:: C.ResourceIO m
=> H.Statement
-> [PersistValue]
-> (RowPopper m -> m a)
-> m a
withStmt' stmt vals f = do
_ <- liftIO $ H.execute stmt $ map pToSql vals
f $ liftIO $ (fmap . fmap) (map pFromSql) $ H.fetchRow stmt
-> C.Source m [PersistValue]
withStmt' stmt vals = C.sourceIO
(H.execute stmt (map pToSql vals) >> return ())
return
pull
where
pull () = do
x <- liftIO $ (fmap . fmap) (map pFromSql) $ H.fetchRow stmt
return $ maybe C.Closed C.Open x

pToSql :: PersistValue -> H.SqlValue
pToSql (PersistText t) = H.SqlString $ unpack t
Expand Down Expand Up @@ -196,34 +194,34 @@ getColumns getter def = do
[ PersistText $ unDBName $ entityDB def
, PersistText $ unDBName $ entityID def
]
cs <- withStmt stmt vals helper
cs <- C.runResourceT $ withStmt stmt vals C.$$ helper
stmt' <- getter
"SELECT constraint_name, column_name FROM information_schema.constraint_column_usage WHERE table_name=? AND column_name <> ? ORDER BY constraint_name, column_name"
us <- withStmt stmt' vals helperU
us <- C.runResourceT $ withStmt stmt' vals C.$$ helperU
return $ cs ++ us
where
getAll pop front = do
x <- pop
getAll front = do
x <- CL.head
case x of
Nothing -> return $ front []
Just [PersistByteString con, PersistByteString col] ->
getAll pop (front . (:) (bsToChars con, bsToChars col))
Just _ -> getAll pop front -- FIXME error message?
helperU pop = do
rows <- getAll pop id
getAll (front . (:) (bsToChars con, bsToChars col))
Just _ -> getAll front -- FIXME error message?
helperU = do
rows <- getAll id
return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
$ groupBy ((==) `on` fst)
$ map (T.pack *** T.pack) rows
helper pop = do
x <- pop
helper = do
x <- CL.head
case x of
Nothing -> return []
Just x' -> do
col <- getColumn getter (entityDB def) x'
col <- liftIO $ getColumn getter (entityDB def) x'
let col' = case col of
Left e -> Left e
Right c -> Right $ Left c
cols <- helper pop
cols <- helper
return $ col' : cols

getAlters :: ([Column], [(DBName, [DBName])])
Expand Down Expand Up @@ -279,11 +277,11 @@ getColumn getter tname
]
let ref = refName tname cname
stmt <- getter sql
withStmt stmt
C.runResourceT $ withStmt stmt
[ PersistText $ unDBName tname
, PersistText $ unDBName ref
] $ \pop -> do
Just [PersistInt64 i] <- pop
] C.$$ do
Just [PersistInt64 i] <- CL.head
return $ if i == 0 then Nothing else Just (DBName "", ref)
d' = case d of
PersistNull -> Right Nothing
Expand Down Expand Up @@ -491,31 +489,19 @@ instance PersistConfig PostgresConf where
type PersistConfigPool PostgresConf = ConnectionPool
withPool (PostgresConf cs size) = withPostgresqlPool cs size
runPool _ = runSqlPool
loadConfig e' = meither Left Right $ do
e <- go $ fromMapping e'
db <- go $ lookupScalar "database" e
pool' <- go $ lookupScalar "poolsize" e
pool <- safeRead "poolsize" pool'

-- TODO: default host/port?
loadConfig (Object o) = do
db <- o .: "database"
pool <- o .: "poolsize"
connparts <- forM ["user", "password", "host", "port"] $ \k -> do
v <- go $ lookupScalar k e
v <- o .: k
return $ T.concat [k, "=", v, " "]

-- TODO: default host/port?

let conn = T.concat connparts

return $ PostgresConf (T.concat [conn, " dbname=", db]) pool
where
go :: MEither ObjectExtractError a -> MEither String a
go (MLeft e) = MLeft $ show e
go (MRight a) = MRight a

safeRead :: String -> T.Text -> MEither String Int
safeRead name t = case reads s of
(i, _):_ -> MRight i
[] -> MLeft $ concat ["Invalid value for ", name, ": ", s]
where
s = T.unpack t
loadConfig _ = mzero

refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
Expand Down
5 changes: 2 additions & 3 deletions persistent-postgresql/persistent-postgresql.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ library
, transformers >= 0.2.1 && < 0.3
, HDBC-postgresql >= 2.2.3.1 && < 2.4
, persistent >= 0.7 && < 0.8
, containers >= 0.2 && < 0.5
, containers >= 0.2
, bytestring >= 0.9 && < 0.10
, text >= 0.7 && < 0.12
, monad-control >= 0.2 && < 0.4
, time >= 1.1
, data-object >= 0.3 && < 0.4
, neither >= 0.3 && < 0.4
, aeson >= 0.5 && < 0.6
exposed-modules: Database.Persist.Postgresql
ghc-options: -Wall

Expand Down
67 changes: 28 additions & 39 deletions persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ import Control.Monad.IO.Control (MonadControlIO)
import Control.Exception.Control (finally)
#define MBCIO MonadControlIO
#endif
import Data.Text (Text, pack, unpack)
import Data.Neither (MEither (..), meither)
import Data.Text (Text, pack)
import Control.Monad (mzero)
import Data.Aeson
import qualified Data.Text as T
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import Control.Applicative

withSqlitePool :: (MonadIO m, MBCIO m)
withSqlitePool :: C.ResourceIO m
=> Text
-> Int -- ^ number of connections to open
-> (ConnectionPool -> m a) -> m a
withSqlitePool s = withSqlPool $ open' s

withSqliteConn :: (MonadIO m, MBCIO m) => Text -> (Connection -> m a) -> m a
withSqliteConn :: C.ResourceIO m => Text -> (Connection -> m a) -> m a
withSqliteConn = withSqlConn . open'

open' :: Text -> IO Connection
Expand Down Expand Up @@ -100,23 +103,22 @@ execute' stmt vals = flip finally (liftIO $ Sqlite.reset stmt) $ do
return ()

withStmt'
:: (MBCIO m, MonadIO m)
:: C.ResourceIO m
=> Sqlite.Statement
-> [PersistValue]
-> (RowPopper m -> m a)
-> m a
withStmt' stmt vals f = flip finally (liftIO $ Sqlite.reset stmt) $ do
liftIO $ Sqlite.bind stmt vals
x <- f go
return x
-> C.Source m [PersistValue]
withStmt' stmt vals = C.sourceIO
(Sqlite.bind stmt vals >> return stmt)
Sqlite.reset
pull
where
go = liftIO $ do
pull _ = liftIO $ do
x <- Sqlite.step stmt
case x of
Sqlite.Done -> return Nothing
Sqlite.Done -> return C.Closed
Sqlite.Row -> do
cols <- liftIO $ Sqlite.columns stmt
return $ Just cols
return $ C.Open cols
showSqlType :: SqlType -> String
showSqlType SqlString = "VARCHAR"
showSqlType SqlInt32 = "INTEGER"
Expand All @@ -137,7 +139,8 @@ migrate' allDefs getter val = do
let (cols, uniqs) = mkColumns allDefs val
let newSql = mkCreateTable False def (cols, uniqs)
stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
oldSql' <- withStmt stmt [PersistText $ unDBName table] go
oldSql' <- C.runResourceT
$ withStmt stmt [PersistText $ unDBName table] C.$$ go
case oldSql' of
Nothing -> return $ Right [(False, newSql)]
Just oldSql ->
Expand All @@ -149,8 +152,8 @@ migrate' allDefs getter val = do
where
def = entityDef val
table = entityDB def
go pop = do
x <- pop
go = do
x <- CL.head
case x of
Nothing -> return Nothing
Just [PersistText y] -> return $ Just y
Expand All @@ -163,7 +166,7 @@ getCopyTable :: PersistEntity val
-> IO [(Bool, Sql)]
getCopyTable allDefs getter val = do
stmt <- getter $ pack $ "PRAGMA table_info(" ++ escape' table ++ ")"
oldCols' <- withStmt stmt [] getCols
oldCols' <- C.runResourceT $ withStmt stmt [] C.$$ getCols
let oldCols = map DBName $ filter (/= "id") oldCols' -- need to update for table id attribute ?
let newCols = map cName cols
let common = filter (`elem` oldCols) newCols
Expand All @@ -177,12 +180,12 @@ getCopyTable allDefs getter val = do
]
where
def = entityDef val
getCols pop = do
x <- pop
getCols = do
x <- CL.head
case x of
Nothing -> return []
Just (_:PersistText name:_) -> do
names <- getCols pop
names <- getCols
return $ name : names
Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
table = entityDB def
Expand Down Expand Up @@ -273,24 +276,10 @@ instance PersistConfig SqliteConf where
type PersistConfigPool SqliteConf = ConnectionPool
withPool (SqliteConf cs size) = withSqlitePool cs size
runPool _ = runSqlPool
loadConfig e' = meither Left Right $ do
e <- go $ fromMapping e'
db <- go $ lookupScalar "database" e
pool' <- go $ lookupScalar "poolsize" e
pool <- safeRead "poolsize" pool'

return $ SqliteConf db pool
where
go :: MEither ObjectExtractError a -> MEither String a
go (MLeft e) = MLeft $ show e
go (MRight a) = MRight a

safeRead :: String -> Text -> MEither String Int
safeRead name t = case reads s of
(i, _):_ -> MRight i
[] -> MLeft $ concat ["Invalid value for ", name, ": ", s]
where
s = unpack t
loadConfig (Object o) =
SqliteConf <$> o .: "database"
<*> o .: "poolsize"
loadConfig _ = mzero

#if MIN_VERSION_monad_control(0, 3, 0)
finally :: MonadBaseControl IO m
Expand Down
3 changes: 1 addition & 2 deletions persistent-sqlite/persistent-sqlite.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ library
, monad-control >= 0.2 && < 0.4
, containers >= 0.2 && < 0.5
, text >= 0.7 && < 1
, data-object >= 0.3 && < 0.4
, neither >= 0.3 && < 0.4
, aeson >= 0.5
exposed-modules: Database.Sqlite
Database.Persist.Sqlite
ghc-options: -Wall
Expand Down
10 changes: 3 additions & 7 deletions persistent-test/PersistentTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ import qualified Control.Exception.Control as Control
#endif
import System.Random

import Control.Monad.Trans.Resource (ResourceIO)

#if WITH_POSTGRESQL
import Database.Persist.Postgresql
#endif
Expand Down Expand Up @@ -211,13 +213,7 @@ instance Arbitrary PersistValue where
type BackendMonad = SqlPersist
sqlite_database :: Text
sqlite_database = "test/testdb.sqlite3"
runConn ::
#if MIN_VERSION_monad_control(0, 3, 0)
(Control.Monad.Trans.Control.MonadBaseControl IO m, MonadIO m)
#else
Control.Monad.IO.Control.MonadControlIO m
#endif
=> SqlPersist m t -> m ()
runConn :: ResourceIO m => SqlPersist m t -> m ()
runConn f = do
_<-withSqlitePool sqlite_database 1 $ runSqlPool f
#if WITH_POSTGRESQL
Expand Down
20 changes: 5 additions & 15 deletions persistent-test/RenameTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ import Database.Persist.GenericSql.Raw
#if WITH_POSTGRESQL
import Database.Persist.Postgresql
#endif
#if MIN_VERSION_monad_control(0, 3, 0)
import qualified Control.Monad.Trans.Control
#else
import qualified Control.Monad.IO.Control
#endif
import Control.Monad.IO.Class (MonadIO)
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL

-- Test lower case names
share [mkPersist sqlMkSettings, mkMigrate "lowerCase"] [persistLowerCase|
Expand All @@ -32,13 +28,7 @@ RefTable
UniqueRefTable someVal
|]

runConn2 ::
#if MIN_VERSION_monad_control(0, 3, 0)
(Control.Monad.Trans.Control.MonadBaseControl IO m, MonadIO m)
#else
Control.Monad.IO.Control.MonadControlIO m
#endif
=> SqlPersist m t -> m ()
runConn2 :: C.ResourceIO m => SqlPersist m t -> m ()
runConn2 f = do
_ <- withSqlitePool ":memory:" 1 $ runSqlPool f
#if WITH_POSTGRESQL
Expand All @@ -51,8 +41,8 @@ renameSpecs = describe "rename specs" $ do
it "handles lower casing" $ asIO $ do
runConn2 $ do
_ <- runMigrationSilent lowerCase
withStmt "SELECT full_name from lower_case_table WHERE my_id=5" [] $ const $ return ()
withStmt "SELECT something_else from ref_table WHERE id=4" [] $ const $ return ()
C.runResourceT $ withStmt "SELECT full_name from lower_case_table WHERE my_id=5" [] C.$$ CL.sinkNull
C.runResourceT $ withStmt "SELECT something_else from ref_table WHERE id=4" [] C.$$ CL.sinkNull

asIO :: IO a -> IO a
asIO = id
Loading

0 comments on commit a56b1ee

Please sign in to comment.