Skip to content

Commit

Permalink
Move from enumerator to conduit (incomplete)
Browse files Browse the repository at this point in the history
  • Loading branch information
snoyberg committed Dec 28, 2011
1 parent dfda533 commit f2afff6
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 114 deletions.
2 changes: 1 addition & 1 deletion persistent-mongoDB/persistent-mongoDB.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ library
, transformers >= 0.2.1 && < 0.3
, containers >= 0.2 && < 0.5
, bytestring >= 0.9 && < 0.10
, enumerator >= 0.4 && < 0.5
, conduit >= 0.0 && < 0.1
, mongoDB >= 1.1 && < 1.2
, bson >= 0.1.6
, network >= 2.2.1.7 && < 3
Expand Down
2 changes: 1 addition & 1 deletion persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import Data.Time.LocalTime (localTimeToUTC, utc)
import Data.Text (Text, pack, unpack)
import Data.Object
import Data.Aeson
import Control.Monad (forM)
import Data.Neither (meither, MEither (..))

Expand Down
2 changes: 1 addition & 1 deletion persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import Control.Exception.Control (finally)
#endif
import Data.Text (Text, pack, unpack)
import Data.Neither (MEither (..), meither)
import Data.Object
import Data.Aeson
import qualified Data.Text as T

withSqlitePool :: (MonadIO m, MBCIO m)
Expand Down
5 changes: 3 additions & 2 deletions persistent-test/persistent-test.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ library
, template-haskell >= 2.4 && < 2.7
, HDBC-postgresql
, HDBC
, data-object
, aeson
, lifted-base
, neither
-- mongoDB dependencies
--, mongoDB == 1.1.*
Expand All @@ -65,7 +66,7 @@ library
, monad-control
, containers
, bytestring
, enumerator
, conduit
, time >= 1.2
, random == 1.*
, QuickCheck == 2.4.*
Expand Down
28 changes: 12 additions & 16 deletions persistent/Database/Persist/GenericSql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ import Web.PathPieces (PathPiece (..))
import qualified Data.Text.Read
import Data.Monoid (Monoid, mappend)
import Database.Persist.EntityDef
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL

type ConnectionPool = Pool Connection

Expand All @@ -71,12 +73,6 @@ instance PathPiece (Key SqlPersist entity) where
Right (i, "") -> Just $ Key $ PersistInt64 i
_ -> Nothing

withStmt'
:: (MBCIO m, MonadIO m)
=> Text -> [PersistValue]
-> (RowPopper (SqlPersist m) -> SqlPersist m a) -> SqlPersist m a
withStmt' = R.withStmt

execute' :: MonadIO m => Text -> [PersistValue] -> SqlPersist m ()
execute' = R.execute

Expand All @@ -93,19 +89,19 @@ runSqlConn (SqlPersist r) conn = do
liftIO $ commitC conn getter
return x

instance (MonadIO m, MBCIO m) => PersistStore SqlPersist m where
instance C.ResourceIO m => PersistStore SqlPersist m where
insert val = do
conn <- SqlPersist ask
let esql = insertSql conn (entityDB t) (map fieldDB $ entityFields t)
i <-
case esql of
Left sql -> withStmt' sql vals $ \pop -> do
Just [PersistInt64 i] <- pop
Left sql -> C.runResourceT $ R.withStmt sql vals C.$$ do
Just [PersistInt64 i] <- CL.head
return i
Right (sql1, sql2) -> do
execute' sql1 vals
withStmt' sql2 [] $ \pop -> do
Just [PersistInt64 i] <- pop
C.runResourceT $ R.withStmt sql2 [] C.$$ do
Just [PersistInt64 i] <- CL.head
return i
return $ Key $ PersistInt64 i
where
Expand Down Expand Up @@ -139,8 +135,8 @@ instance (MonadIO m, MBCIO m) => PersistStore SqlPersist m where
, escapeName conn $ entityDB t
, " WHERE id=?"
]
withStmt' sql [unKey k] $ \pop -> do
res <- pop
C.runResourceT $ R.withStmt sql [unKey k] C.$$ do
res <- CL.head
case res of
Nothing -> return Nothing
Just vals ->
Expand All @@ -159,7 +155,7 @@ instance (MonadIO m, MBCIO m) => PersistStore SqlPersist m where
, " WHERE id=?"
]

instance (MonadIO m, MBCIO m) => PersistUnique SqlPersist m where
instance C.ResourceIO m => PersistUnique SqlPersist m where
deleteBy uniq = do
conn <- SqlPersist ask
execute' (sql conn) $ persistUniqueToValues uniq
Expand All @@ -186,8 +182,8 @@ instance (MonadIO m, MBCIO m) => PersistUnique SqlPersist m where
, " WHERE "
, sqlClause conn
]
withStmt' sql (persistUniqueToValues uniq) $ \pop -> do
row <- pop
C.runResourceT $ R.withStmt sql (persistUniqueToValues uniq) C.$$ do
row <- CL.head
case row of
Nothing -> return Nothing
Just (PersistInt64 k:vals) ->
Expand Down
37 changes: 8 additions & 29 deletions persistent/Database/Persist/GenericSql/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ module Database.Persist.GenericSql.Internal
, Statement (..)
, withSqlConn
, withSqlPool
, RowPopper
, mkColumns
, Column (..)
, dummyFromFilts
Expand All @@ -24,23 +23,13 @@ import Control.Monad.IO.Class
import Data.Pool
import Database.Persist.Store
import Database.Persist.Query
#if MIN_VERSION_monad_control(0, 3, 0)
import Control.Monad.Trans.Control (MonadBaseControl, control, restoreM)
import qualified Control.Exception as E
#define MBCIO MonadBaseControl IO
#else
import Control.Monad.IO.Control (MonadControlIO)
import Control.Exception.Control (bracket)

#define MBCIO MonadControlIO
#endif
import Control.Exception.Lifted (bracket)
import Database.Persist.Util (nullable)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Monoid (Monoid, mappend, mconcat)
import Database.Persist.EntityDef

type RowPopper m = m (Maybe [PersistValue])
import qualified Data.Conduit as C

data Connection = Connection
{ prepare :: Text -> IO Statement
Expand All @@ -63,15 +52,17 @@ data Statement = Statement
{ finalize :: IO ()
, reset :: IO ()
, execute :: [PersistValue] -> IO ()
, withStmt :: forall a m. (MBCIO m, MonadIO m)
=> [PersistValue] -> (RowPopper m -> m a) -> m a
, withStmt :: forall m. C.ResourceIO m
=> [PersistValue]
-> C.Source m [PersistValue]
}

withSqlPool :: (MonadIO m, MBCIO m)
withSqlPool :: C.ResourceIO m
=> IO Connection -> Int -> (Pool Connection -> m a) -> m a
withSqlPool mkConn = createPool mkConn close'

withSqlConn :: (MonadIO m, MBCIO m) => IO Connection -> (Connection -> m a) -> m a
withSqlConn :: C.ResourceIO m
=> IO Connection -> (Connection -> m a) -> m a
withSqlConn open = bracket (liftIO open) (liftIO . close')

close' :: Connection -> IO ()
Expand Down Expand Up @@ -199,18 +190,6 @@ orderClause includeTable conn o =
else id)
$ escapeName conn $ fieldDB $ persistFieldDef x

#if MIN_VERSION_monad_control(0, 3, 0)
bracket :: MonadBaseControl IO m
=> m a -- ^ computation to run first (\"acquire resource\")
-> (a -> m b) -- ^ computation to run last (\"release resource\")
-> (a -> m c) -- ^ computation to run in-between
-> m c
bracket before after thing = control $ \runInIO ->
E.bracket (runInIO before)
(\st -> runInIO $ restoreM st >>= after)
(\st -> runInIO $ restoreM st >>= thing)
#endif

infixr 5 ++
(++) :: Text -> Text -> Text
(++) = mappend
30 changes: 23 additions & 7 deletions persistent/Database/Persist/GenericSql/Raw.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import Control.Monad.IO.Control (MonadControlIO)
#endif
import Data.Text (Text)
import Control.Monad (MonadPlus)
import Control.Monad.Trans.Resource (ResourceThrow (..), ResourceIO)
import qualified Data.Conduit as C

newtype SqlPersist m a = SqlPersist { unSqlPersist :: ReaderT Connection m a }
deriving (Monad, MonadIO, MonadTrans, Functor, Applicative, MonadPlus
Expand All @@ -42,6 +44,9 @@ newtype SqlPersist m a = SqlPersist { unSqlPersist :: ReaderT Connection m a }
#endif
)

instance ResourceThrow m => ResourceThrow (SqlPersist m) where
resourceThrow = lift . resourceThrow

instance MonadBase b m => MonadBase b (SqlPersist m) where
liftBase = lift . liftBase

Expand All @@ -56,13 +61,24 @@ instance MonadTransControl SqlPersist where
restoreT = SqlPersist . ReaderT . const . liftM unStReader
#endif

withStmt :: (MonadIO m, MBCIO m) => Text -> [PersistValue]
-> (RowPopper (SqlPersist m) -> SqlPersist m a) -> SqlPersist m a
withStmt sql vals pop = do
stmt <- getStmt sql
ret <- I.withStmt stmt vals pop
liftIO $ reset stmt
return ret
withStmt :: ResourceIO m
=> Text
-> [PersistValue]
-> C.Source (SqlPersist m) [PersistValue]
withStmt sql vals = C.Source $ do
stmt <- lift $ getStmt sql
src <- C.prepareSource $ I.withStmt stmt vals
return C.PreparedSource
{ C.sourcePull = do
res <- C.sourcePull src
case res of
C.Closed -> liftIO $ I.reset stmt
_ -> return ()
return res
, C.sourceClose = do
liftIO $ I.reset stmt
C.sourceClose src
}

execute :: MonadIO m => Text -> [PersistValue] -> SqlPersist m ()
execute sql vals = do
Expand Down
37 changes: 10 additions & 27 deletions persistent/Database/Persist/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@ module Database.Persist.Query
import Database.Persist.Store
import Database.Persist.EntityDef

import Data.Enumerator hiding (consume, map)
import Data.Enumerator.List (consume)
import qualified Data.Enumerator.List as EL

import qualified Control.Monad.IO.Class as Trans
import qualified Control.Exception as E
import Data.Text (pack)

import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL

infixr 3 =., +=., -=., *=., /=.
(=.), (+=.), (-=.), (*=.), (/=.) :: forall v typ. PersistField typ => EntityField v typ -> typ -> Update v
Expand Down Expand Up @@ -68,7 +62,7 @@ infixl 3 ||.
a ||. b = [FilterOr [FilterAnd a, FilterAnd b]]


class (PersistStore b m) => PersistQuery b m where
class PersistStore b m => PersistQuery b m where
-- | Update individual fields on a specific record.
update :: PersistEntity val => Key b val -> [Update val] -> b m ()

Expand All @@ -80,24 +74,25 @@ class (PersistStore b m) => PersistQuery b m where

-- | Get all records matching the given criterion in the specified order.
-- Returns also the identifiers.
selectEnum
selectSource
:: PersistEntity val
=> [Filter val]
-> [SelectOpt val]
-> Enumerator (Key b val, val) (b m) a
-> C.Source (b m) (Key b val, val)

-- | get just the first record for the criterion
selectFirst :: PersistEntity val
=> [Filter val]
-> [SelectOpt val]
-> b m (Maybe (Key b val, val))
selectFirst filts opts = run_ $ selectEnum filts ((LimitTo 1):opts) ==<< EL.head
selectFirst filts opts = C.runResourceT
$ selectSource filts ((LimitTo 1):opts) C.$$ CL.head


-- | Get the 'Key's of all records matching the given criterion.
selectKeys :: PersistEntity val
=> [Filter val]
-> Enumerator (Key b val) (b m) a
-> C.Source (b m) (Key b val)

-- | The total number of records fulfilling the given criterion.
count :: PersistEntity val => [Filter val] -> b m Int
Expand Down Expand Up @@ -128,11 +123,7 @@ selectList :: (PersistEntity val, PersistQuery b m)
=> [Filter val]
-> [SelectOpt val]
-> b m [(Key b val, val)]
selectList a b = do
res <- run $ selectEnum a b ==<< consume
case res of
Left e -> Trans.liftIO . E.throwIO $ PersistError $ pack $ show e
Right x -> return x
selectList a b = C.runResourceT $ selectSource a b C.$$ CL.consume

data SelectOpt v = forall typ. Asc (EntityField v typ)
| forall typ. Desc (EntityField v typ)
Expand All @@ -143,15 +134,7 @@ data SelectOpt v = forall typ. Asc (EntityField v typ)
deleteCascadeWhere :: (DeleteCascade a b m, PersistQuery b m)
=> [Filter a] -> b m ()
deleteCascadeWhere filts = do
res <- run $ selectKeys filts $ Continue iter
case res of
Left e -> Trans.liftIO . E.throwIO $ PersistError $ pack $ show e
Right () -> return ()
where
iter EOF = Iteratee $ return $ Yield () EOF
iter (Chunks keys) = Iteratee $ do
mapM_ deleteCascade keys
return $ Continue iter
C.runResourceT $ selectKeys filts C.$$ CL.mapM_ deleteCascade

data PersistUpdate = Assign | Add | Subtract | Multiply | Divide -- FIXME need something else here
deriving (Read, Show, Enum, Bounded)
Expand Down
Loading

0 comments on commit f2afff6

Please sign in to comment.