From 32e01a1ff81ef0502305aea78f86b5334f52db42 Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Fri, 27 Jan 2012 17:37:55 -0200 Subject: [PATCH 1/7] Avoid field name "int" on the tests. --- persistent-test/DataTypeTest.hs | 4 ++-- persistent-test/PersistentTest.hs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/persistent-test/DataTypeTest.hs b/persistent-test/DataTypeTest.hs index 9ba840ad3..8e0b6bed7 100644 --- a/persistent-test/DataTypeTest.hs +++ b/persistent-test/DataTypeTest.hs @@ -33,7 +33,7 @@ share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| DataTypeTable text Text bytes ByteString - int Int + intx Int double Double bool Bool day Day @@ -58,7 +58,7 @@ dataTypeSpecs = describe "data type specs" $ do -- Check individual fields for better error messages check "text" dataTypeTableText check "bytes" dataTypeTableBytes - check "int" dataTypeTableInt + check "int" dataTypeTableIntx check "bool" dataTypeTableBool check "day" dataTypeTableDay check "time" dataTypeTableTime diff --git a/persistent-test/PersistentTest.hs b/persistent-test/PersistentTest.hs index 600607ce4..783f2bd14 100644 --- a/persistent-test/PersistentTest.hs +++ b/persistent-test/PersistentTest.hs @@ -146,7 +146,7 @@ share [mkPersist sqlSettings, mkMigrate "testMigrate", mkDeleteCascade] [persis NeedsPet petKey PetId Number - int Int + intx Int int32 Int32 word32 Word32 int64 Int64 From cf25b86b4ee6387d51a769aea0eac83a3a6ad9ce Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Mon, 30 Jan 2012 09:30:55 -0200 Subject: [PATCH 2/7] Avoid field name "double" on the tests. --- persistent-test/DataTypeTest.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/persistent-test/DataTypeTest.hs b/persistent-test/DataTypeTest.hs index 8e0b6bed7..12e81ad6a 100644 --- a/persistent-test/DataTypeTest.hs +++ b/persistent-test/DataTypeTest.hs @@ -34,7 +34,7 @@ DataTypeTable text Text bytes ByteString intx Int - double Double + doublex Double bool Bool day Day time TimeOfDay @@ -66,8 +66,8 @@ dataTypeSpecs = describe "data type specs" $ do -- Do a special check for Double since it may -- lose precision when serialized. - when (abs (dataTypeTableDouble x - dataTypeTableDouble y) > 1e-14) $ - check "double" dataTypeTableDouble + when (abs (dataTypeTableDoublex x - dataTypeTableDoublex y) > 1e-14) $ + check "double" dataTypeTableDoublex randomValue :: IO DataTypeTable randomValue = DataTypeTable From 12a5eb027b98f4c5677415b8ab09264026e48e9e Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Mon, 30 Jan 2012 10:29:54 -0200 Subject: [PATCH 3/7] Avoid Text with chars outside BMP on the tests. --- persistent-test/DataTypeTest.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/persistent-test/DataTypeTest.hs b/persistent-test/DataTypeTest.hs index 12e81ad6a..dc4497abc 100644 --- a/persistent-test/DataTypeTest.hs +++ b/persistent-test/DataTypeTest.hs @@ -16,6 +16,7 @@ import Database.Persist.TH #if WITH_POSTGRESQL import Database.Persist.Postgresql #endif +import Data.Char (generalCategory, GeneralCategory(..)) import Data.Text (Text) import qualified Data.Text as T import Data.ByteString (ByteString) @@ -71,7 +72,11 @@ dataTypeSpecs = describe "data type specs" $ do randomValue :: IO DataTypeTable randomValue = DataTypeTable - <$> (T.pack . filter (/= '\0') <$> randomIOs) + <$> (T.pack + . filter ((`notElem` forbidden) . generalCategory) + . filter (<= '\xFFFF') -- only BMP + . filter (/= '\0') -- no nulls + <$> randomIOs) <*> (S.pack . map intToWord8 <$> randomIOs) <*> randomIO <*> randomIO @@ -79,6 +84,7 @@ randomValue = DataTypeTable <*> randomDay <*> randomTime <*> randomUTC + where forbidden = [NotAssigned, PrivateUse] asIO :: IO a -> IO a asIO = id From ad3e6ed3e52d71056ad8b08fe314fb540bc9945b Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Mon, 30 Jan 2012 09:08:17 -0200 Subject: [PATCH 4/7] Add Ord instance to SqlType. --- persistent/Database/Persist/Store.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persistent/Database/Persist/Store.hs b/persistent/Database/Persist/Store.hs index fd34bda41..9687e1683 100644 --- a/persistent/Database/Persist/Store.hs +++ b/persistent/Database/Persist/Store.hs @@ -134,7 +134,7 @@ data SqlType = SqlString | SqlTime | SqlDayTime | SqlBlob - deriving (Show, Read, Eq, Typeable) + deriving (Show, Read, Eq, Typeable, Ord) -- | A value which can be marshalled to and from a 'PersistValue'. class PersistField a where From bf1b4e4599667c5813c89113f758b7ab4ac2d1f5 Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Mon, 30 Jan 2012 09:08:34 -0200 Subject: [PATCH 5/7] Add Eq, Ord and Show instances to Column. --- persistent/Database/Persist/GenericSql/Internal.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/persistent/Database/Persist/GenericSql/Internal.hs b/persistent/Database/Persist/GenericSql/Internal.hs index ce8529a3e..b1392f3ec 100644 --- a/persistent/Database/Persist/GenericSql/Internal.hs +++ b/persistent/Database/Persist/GenericSql/Internal.hs @@ -140,6 +140,7 @@ data Column = Column , cDefault :: Maybe Text , cReference :: (Maybe (DBName, DBName)) -- table name, constraint name } + deriving (Eq, Ord, Show) {- FIXME getSqlValue :: [String] -> Maybe String From 1bde145ca27cfa6ab71cd02f394985dbae41e0bf Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Mon, 30 Jan 2012 10:11:03 -0200 Subject: [PATCH 6/7] Correctly escape rawSql queries in PersistentTest.hs. MySQL does not accept table names with double quotes. --- persistent-test/PersistentTest.hs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/persistent-test/PersistentTest.hs b/persistent-test/PersistentTest.hs index 783f2bd14..cb0469d1c 100644 --- a/persistent-test/PersistentTest.hs +++ b/persistent-test/PersistentTest.hs @@ -42,12 +42,15 @@ import Control.Monad (replicateM) import qualified Data.ByteString as BS #else -import Database.Persist.EntityDef (EntityDef(..)) +import Database.Persist.EntityDef (EntityDef(..), DBName(..)) import Database.Persist.Store ( DeleteCascade (..) ) import Database.Persist.GenericSql +import Database.Persist.GenericSql.Internal (escapeName) import qualified Database.Persist.Query.Join.Sql import Database.Persist.Sqlite import Control.Exception (SomeException) +import Control.Monad.Trans.Reader (ask) +import qualified Data.Text as T #if MIN_VERSION_monad_control(0, 3, 0) import qualified Control.Exception as E #define CATCH catch' @@ -746,7 +749,16 @@ specs = describe "persistent" $ do (a2k, a2) <- insert' $ Pet p1k "Zeno" Cat (a3k, a3) <- insert' $ Pet p2k "Lhama" Dog (_ , _ ) <- insert' $ Pet p3k "Abacate" Cat - ret <- rawSql "SELECT ??, ?? FROM \"Person\", \"Pet\" WHERE \"Person\".age >= ? AND \"Pet\".\"ownerId\" = \"Person\".id ORDER BY \"Person\".name, \"Pet\".name" [PersistInt64 20] + escape <- ((. DBName) . escapeName) `fmap` SqlPersist ask + let query = T.concat [ "SELECT ??, ?? " + , "FROM ", escape "Person" + , ", ", escape "Pet" + , " WHERE ", escape "Person", ".", escape "age", " >= ? " + , "AND ", escape "Pet", ".", escape "ownerId", " = " + , escape "Person", ".", escape "id" + , " ORDER BY ", escape "Person", ".", escape "name" + ] + ret <- rawSql query [PersistInt64 20] liftIO $ ret @?= [ (Entity p1k p1, Entity a1k a1) , (Entity p1k p1, Entity a2k a2) , (Entity p2k p2, Entity a3k a3) ] @@ -754,8 +766,12 @@ specs = describe "persistent" $ do it "rawSql/order-proof" $ db $ do let p1 = Person "Zacarias" 93 Nothing p1k <- insert p1 - ret1 <- rawSql "SELECT ?? FROM \"Person\"" [] - ret2 <- rawSql "SELECT ?? FROM \"Person\"" [] + escape <- ((. DBName) . escapeName) `fmap` SqlPersist ask + let query = T.concat [ "SELECT ?? " + , "FROM ", escape "Person" + ] + ret1 <- rawSql query [] + ret2 <- rawSql query [] liftIO $ ret1 @?= [Entity p1k p1] liftIO $ ret2 @?= [Entity (Key $ unKey p1k) (RFO p1)] From f978314c2e388d1b7c35949846b314243005f529 Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Fri, 27 Jan 2012 16:08:16 -0200 Subject: [PATCH 7/7] New glorious package persistent-mysql. --- persistent-mysql/Database/Persist/MySQL.hs | 759 +++++++++++++++++++++ persistent-mysql/LICENSE | 25 + persistent-mysql/Setup.lhs | 7 + persistent-mysql/persistent-mysql.cabal | 44 ++ persistent-test/DataTypeTest.hs | 3 + persistent-test/PersistentTest.hs | 13 +- persistent-test/RenameTest.hs | 11 + persistent-test/persistent-mysql | 1 + persistent-test/persistent-test.cabal | 19 +- 9 files changed, 879 insertions(+), 3 deletions(-) create mode 100644 persistent-mysql/Database/Persist/MySQL.hs create mode 100644 persistent-mysql/LICENSE create mode 100755 persistent-mysql/Setup.lhs create mode 100644 persistent-mysql/persistent-mysql.cabal create mode 120000 persistent-test/persistent-mysql diff --git a/persistent-mysql/Database/Persist/MySQL.hs b/persistent-mysql/Database/Persist/MySQL.hs new file mode 100644 index 000000000..e9c37b20c --- /dev/null +++ b/persistent-mysql/Database/Persist/MySQL.hs @@ -0,0 +1,759 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE FlexibleContexts #-} +-- | A MySQL backend for @persistent@. +module Database.Persist.MySQL + ( withMySQLPool + , withMySQLConn + , createMySQLPool + , module Database.Persist + , module Database.Persist.GenericSql + , MySQL.ConnectInfo(..) + , MySQLBase.SSLInfo(..) + , MySQL.defaultConnectInfo + , MySQLBase.defaultSSLInfo + , MySQLConf(..) + ) where + +import Control.Arrow +import Control.Monad (mzero) +import Control.Monad.IO.Class (MonadIO (..)) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Error (ErrorT(..)) +import Data.Aeson +import Data.ByteString (ByteString) +import Data.Either (partitionEithers) +import Data.Function (on) +import Data.IORef +import Data.List (find, intercalate, sort, groupBy) +import Data.Text (Text, pack) +-- import Data.Time.LocalTime (localTimeToUTC, utc) +import System.Environment (getEnvironment) + +import qualified Data.Conduit as C +import qualified Data.Conduit.List as CL +import qualified Data.Map as Map +import qualified Data.Text as T +import qualified Data.Text.Encoding as T + +import Database.Persist hiding (Entity (..)) +import Database.Persist.Store +import Database.Persist.GenericSql hiding (Key(..)) +import Database.Persist.GenericSql.Internal +import Database.Persist.EntityDef + +import qualified Database.MySQL.Simple as MySQL +import qualified Database.MySQL.Simple.Param as MySQL +import qualified Database.MySQL.Simple.Result as MySQL +import qualified Database.MySQL.Simple.Types as MySQL + +import qualified Database.MySQL.Base as MySQLBase +import qualified Database.MySQL.Base.Types as MySQLBase + + + +-- | Create a MySQL connection pool and run the given action. +-- The pool is properly released after the action finishes using +-- it. Note that you should not use the given 'ConnectionPool' +-- outside the action since it may be already been released. +withMySQLPool :: MonadIO m => + MySQL.ConnectInfo + -- ^ Connection information. + -> Int + -- ^ Number of connections to be kept open in the pool. + -> (ConnectionPool -> m a) + -- ^ Action to be executed that uses the connection pool. + -> m a +withMySQLPool ci = withSqlPool $ open' ci + + +-- | Create a MySQL connection pool. Note that it's your +-- responsability to properly close the connection pool when +-- unneeded. Use 'withMySQLPool' for automatic resource control. +createMySQLPool :: MonadIO m => + MySQL.ConnectInfo + -- ^ Connection information. + -> Int + -- ^ Number of connections to be kept open in the pool. + -> m ConnectionPool +createMySQLPool ci = createSqlPool $ open' ci + + +-- | Same as 'withMySQLPool', but instead of opening a pool +-- of connections, only one connection is opened. +withMySQLConn :: C.ResourceIO m => + MySQL.ConnectInfo + -- ^ Connection information. + -> (Connection -> m a) + -- ^ Action to be executed that uses the connection. + -> m a +withMySQLConn = withSqlConn . open' + + +-- | Internal function that opens a connection to the MySQL +-- server. +open' :: MySQL.ConnectInfo -> IO Connection +open' ci = do + conn <- MySQL.connect ci + MySQLBase.autocommit conn False -- disable autocommit! + smap <- newIORef $ Map.empty + return Connection + { prepare = prepare' conn + , stmtMap = smap + , insertSql = insertSql' + , close = MySQL.close conn + , migrateSql = migrate' ci + , begin = const $ MySQL.execute_ conn "start transaction" >> return () + , commitC = const $ MySQL.commit conn + , rollbackC = const $ MySQL.rollback conn + , escapeName = pack . escapeDBName + , noLimit = "LIMIT 18446744073709551615" + -- This noLimit is suggested by MySQL's own docs, see + -- + } + +-- | Prepare a query. We don't support prepared statements, but +-- we'll do some client-side preprocessing here. +prepare' :: MySQL.Connection -> Text -> IO Statement +prepare' conn sql = do + let query = MySQL.Query (T.encodeUtf8 sql) + return Statement + { finalize = return () + , reset = return () + , execute = execute' conn query + , withStmt = withStmt' conn query + } + + +-- | SQL code to be executed when inserting an entity. +insertSql' :: DBName -> [DBName] -> Either Text (Text, Text) +insertSql' t cols = Right (doInsert, "SELECT LAST_INSERT_ID()") + where + doInsert = pack $ concat + [ "INSERT INTO " + , escapeDBName t + , "(" + , intercalate "," $ map escapeDBName cols + , ") VALUES(" + , intercalate "," (map (const "?") cols) + , ")" + ] + + +-- | Execute an statement that doesn't return any results. +execute' :: MySQL.Connection -> MySQL.Query -> [PersistValue] -> IO () +execute' conn query vals = MySQL.execute conn query (map P vals) >> return () + + +-- | Execute an statement that does return results. The results +-- are fetched all at once and stored into memory. +withStmt' :: C.ResourceIO m + => MySQL.Connection + -> MySQL.Query + -> [PersistValue] + -> C.Source m [PersistValue] +withStmt' conn query vals = C.sourceIO (liftIO openS ) + (liftIO . closeS) + (liftIO . pullS ) + where + openS = do + -- Execute the query + MySQLBase.query conn =<< MySQL.formatQuery conn query (map P vals) + result <- MySQLBase.storeResult conn + + -- Find out the type of the columns + fields <- MySQLBase.fetchFields result + let getters = [ maybe PersistNull (getGetter (MySQLBase.fieldType f) f . Just) | f <- fields] + + -- Ready to go! + return (result, getters) + + closeS (result, _) = MySQLBase.freeResult result + + pullS (result, getters) = do + row <- MySQLBase.fetchRow result + case row of + [] -> MySQLBase.freeResult result >> return C.IOClosed + _ -> return $ C.IOOpen $ zipWith ($) getters row + + +-- | @newtype@ around 'PersistValue' that supports the +-- 'MySQL.Param' type class. +newtype P = P PersistValue + +instance MySQL.Param P where + render (P (PersistText t)) = MySQL.render t + render (P (PersistByteString bs)) = MySQL.render bs + render (P (PersistInt64 i)) = MySQL.render i + render (P (PersistDouble d)) = MySQL.render d + render (P (PersistBool b)) = MySQL.render b + render (P (PersistDay d)) = MySQL.render d + render (P (PersistTimeOfDay t)) = MySQL.render t + render (P (PersistUTCTime t)) = MySQL.render t + render (P PersistNull) = MySQL.render MySQL.Null + render (P (PersistList _)) = + error "Refusing to serialize a PersistList to a MySQL value" + render (P (PersistMap _)) = + error "Refusing to serialize a PersistMap to a MySQL value" + render (P (PersistObjectId _)) = + error "Refusing to serialize a PersistObjectId to a MySQL value" + + +-- | @Getter a@ is a function that converts an incoming value +-- into a data type @a@. +type Getter a = MySQLBase.Field -> Maybe ByteString -> a + +-- | Helper to construct 'Getter'@s@ using 'MySQL.Result'. +convertPV :: MySQL.Result a => (a -> b) -> Getter b +convertPV f = (f .) . MySQL.convert + +-- | Get the corresponding @'Getter' 'PersistValue'@ depending on +-- the type of the column. +getGetter :: MySQLBase.Type -> Getter PersistValue +-- Bool +getGetter MySQLBase.Tiny = convertPV PersistBool +-- Int64 +getGetter MySQLBase.Int24 = convertPV PersistInt64 +getGetter MySQLBase.Short = convertPV PersistInt64 +getGetter MySQLBase.Long = convertPV PersistInt64 +getGetter MySQLBase.LongLong = convertPV PersistInt64 +-- Double +getGetter MySQLBase.Float = convertPV PersistDouble +getGetter MySQLBase.Double = convertPV PersistDouble +getGetter MySQLBase.Decimal = convertPV PersistDouble +getGetter MySQLBase.NewDecimal = convertPV PersistDouble +-- Text +getGetter MySQLBase.VarChar = convertPV PersistText +getGetter MySQLBase.VarString = convertPV PersistText +getGetter MySQLBase.String = convertPV PersistText +-- ByteString +getGetter MySQLBase.Blob = convertPV PersistByteString +getGetter MySQLBase.TinyBlob = convertPV PersistByteString +getGetter MySQLBase.MediumBlob = convertPV PersistByteString +getGetter MySQLBase.LongBlob = convertPV PersistByteString +-- Time-related +getGetter MySQLBase.Time = convertPV PersistTimeOfDay +getGetter MySQLBase.DateTime = convertPV PersistUTCTime +getGetter MySQLBase.Timestamp = convertPV PersistUTCTime +getGetter MySQLBase.Date = convertPV PersistDay +getGetter MySQLBase.NewDate = convertPV PersistDay +getGetter MySQLBase.Year = convertPV PersistDay +-- Null +getGetter MySQLBase.Null = \_ _ -> PersistNull +-- Controversial conversions +getGetter MySQLBase.Set = convertPV PersistText +getGetter MySQLBase.Enum = convertPV PersistText +-- Unsupported +getGetter other = error $ "MySQL.getGetter: type " ++ + show other ++ " not supported." + + +---------------------------------------------------------------------- + + +-- | Create the migration plan for the given 'PersistEntity' +-- @val@. +migrate' :: PersistEntity val + => MySQL.ConnectInfo + -> [EntityDef] + -> (Text -> IO Statement) + -> val + -> IO (Either [Text] [(Bool, Text)]) +migrate' connectInfo allDefs getter val = do + let name = entityDB $ entityDef val + old <- getColumns connectInfo getter $ entityDef val + let new = second (map udToPair) $ mkColumns allDefs val + case (old, partitionEithers old) of + -- Nothing found, create everything + ([], _) -> do + let addTable = AddTable $ concat + [ "CREATE TABLE " + , escapeDBName name + , "(" + , escapeDBName $ entityID $ entityDef val + , " BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY" + , concatMap (\x -> ',' : showColumn x) $ fst new + , ")" + ] + let uniques = flip concatMap (snd new) $ \(uname, ucols) -> + [ AlterTable name $ + AddUniqueConstraint uname $ + map (findTypeOfColumn allDefs name) ucols ] + let foreigns = do + Column cname _ _ _ (Just (refTblName, _)) <- fst new + return $ AlterColumn name (cname, addReference allDefs refTblName) + return $ Right $ map showAlterDb $ addTable : uniques ++ foreigns + -- No errors and something found, migrate + (_, ([], old')) -> do + let (acs, ats) = getAlters allDefs name new $ partitionEithers old' + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + return $ Right $ map showAlterDb $ acs' ++ ats' + -- Errors + (_, (errs, _)) -> return $ Left errs + + +-- | Find out the type of a column. +findTypeOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, FieldType) +findTypeOfColumn allDefs name col = + maybe (error $ "Could not find type of column " ++ + show col ++ " on table " ++ show name ++ + " (allDefs = " ++ show allDefs ++ ")") + ((,) col) $ do + entDef <- find ((== name) . entityDB) allDefs + fieldDef <- find ((== col) . fieldDB) (entityFields entDef) + return (fieldType fieldDef) + + +-- | Helper for 'AddRefence' that finds out the 'entityID'. +addReference :: [EntityDef] -> DBName -> AlterColumn +addReference allDefs name = AddReference name id_ + where + id_ = maybe (error $ "Could not find ID of entity " ++ show name + ++ " (allDefs = " ++ show allDefs ++ ")") + id $ do + entDef <- find ((== name) . entityDB) allDefs + return (entityID entDef) + +data AlterColumn = Change Column + | Add Column + | Drop + | Default String + | NoDefault + | Update String + | AddReference DBName DBName + | DropReference DBName + +type AlterColumn' = (DBName, AlterColumn) + +data AlterTable = AddUniqueConstraint DBName [(DBName, FieldType)] + | DropUniqueConstraint DBName + +data AlterDB = AddTable String + | AlterColumn DBName AlterColumn' + | AlterTable DBName AlterTable + + +udToPair :: UniqueDef -> (DBName, [DBName]) +udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud) + + +---------------------------------------------------------------------- + + +-- | Returns all of the 'Column'@s@ in the given table currently +-- in the database. +getColumns :: MySQL.ConnectInfo + -> (Text -> IO Statement) + -> EntityDef + -> IO [Either Text (Either Column (DBName, [DBName]))] +getColumns connectInfo getter def = do + -- Find out all columns. + stmtClmns <- getter "SELECT COLUMN_NAME, \ + \IS_NULLABLE, \ + \DATA_TYPE, \ + \COLUMN_DEFAULT \ + \FROM INFORMATION_SCHEMA.COLUMNS \ + \WHERE TABLE_SCHEMA = ? \ + \AND TABLE_NAME = ? \ + \AND COLUMN_NAME <> ?" + inter <- C.runResourceT $ withStmt stmtClmns vals C.$$ CL.consume + cs <- C.runResourceT $ CL.sourceList inter C.$$ helperClmns -- avoid nested queries + + -- Find out the constraints. + stmtCntrs <- getter "SELECT CONSTRAINT_NAME, \ + \COLUMN_NAME \ + \FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE \ + \WHERE TABLE_SCHEMA = ? \ + \AND TABLE_NAME = ? \ + \AND COLUMN_NAME <> ? \ + \AND REFERENCED_TABLE_SCHEMA IS NULL \ + \ORDER BY CONSTRAINT_NAME, \ + \COLUMN_NAME" + us <- C.runResourceT $ withStmt stmtCntrs vals C.$$ helperCntrs + + -- Return both + return $ cs ++ us + where + vals = [ PersistText $ pack $ MySQL.connectDatabase connectInfo + , PersistText $ unDBName $ entityDB def + , PersistText $ unDBName $ entityID def ] + + helperClmns = CL.mapM getIt C.=$ CL.consume + where + getIt = fmap (either Left (Right . Left)) . + liftIO . + getColumn connectInfo getter (entityDB def) + + helperCntrs = do + let check [PersistText cntrName, PersistText clmnName] = return (cntrName, clmnName) + check other = fail $ "helperCntrs: unexpected " ++ show other + rows <- mapM check =<< CL.consume + return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd))) + $ groupBy ((==) `on` fst) rows + + +-- | Get the information about a column in a table. +getColumn :: MySQL.ConnectInfo + -> (Text -> IO Statement) + -> DBName + -> [PersistValue] + -> IO (Either Text Column) +getColumn connectInfo getter tname [ PersistText cname + , PersistText null_ + , PersistText type' + , default'] = + fmap (either (Left . pack) Right) $ + runErrorT $ do + -- Default value + default_ <- case default' of + PersistNull -> return Nothing + PersistText t -> return (Just t) + _ -> fail $ "Invalid default column: " ++ show default' + + -- Column type + type_ <- parseType type' + + -- Foreign key (if any) + stmt <- lift $ getter "SELECT REFERENCED_TABLE_NAME, \ + \CONSTRAINT_NAME \ + \FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE \ + \WHERE TABLE_SCHEMA = ? \ + \AND TABLE_NAME = ? \ + \AND COLUMN_NAME = ? \ + \AND REFERENCED_TABLE_SCHEMA = ? \ + \ORDER BY CONSTRAINT_NAME, \ + \COLUMN_NAME" + let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo + , PersistText $ unDBName $ tname + , PersistText $ cname + , PersistText $ pack $ MySQL.connectDatabase connectInfo ] + cntrs <- C.runResourceT $ withStmt stmt vars C.$$ CL.consume + ref <- case cntrs of + [] -> return Nothing + [[PersistText tab, PersistText ref]] -> + return $ Just (DBName tab, DBName ref) + _ -> fail "MySQL.getColumn/getRef: never here" + + -- Okay! + return $ Column (DBName cname) (null_ == "YES") type_ default_ ref + +getColumn _ _ _ x = + return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x + + +-- | Parse the type of column as returned by MySQL's +-- @INFORMATION_SCHEMA@ tables. +parseType :: Monad m => Text -> m SqlType +parseType "tinyint" = return SqlBool +-- Ints +parseType "int" = return SqlInt32 +parseType "short" = return SqlInt32 +parseType "long" = return SqlInteger +parseType "longlong" = return SqlInteger +parseType "mediumint" = return SqlInt32 +parseType "bigint" = return SqlInteger +-- Double +parseType "float" = return SqlReal +parseType "double" = return SqlReal +parseType "decimal" = return SqlReal +parseType "newdecimal" = return SqlReal +-- Text +parseType "varchar" = return SqlString +parseType "varstring" = return SqlString +parseType "string" = return SqlString +parseType "text" = return SqlString +parseType "tinytext" = return SqlString +parseType "mediumtext" = return SqlString +parseType "longtext" = return SqlString +-- ByteString +parseType "blob" = return SqlBlob +parseType "tinyblob" = return SqlBlob +parseType "mediumblob" = return SqlBlob +parseType "longblob" = return SqlBlob +-- Time-related +parseType "time" = return SqlTime +parseType "datetime" = return SqlDayTime +parseType "timestamp" = return SqlDayTime +parseType "date" = return SqlDay +parseType "newdate" = return SqlDay +parseType "year" = return SqlDay +-- Unsupported +parseType other = fail $ "MySQL.parseType: type " ++ + show other ++ " not supported." + + +---------------------------------------------------------------------- + + +-- | @getAlters allDefs tblName new old@ finds out what needs to +-- be changed from @old@ to become @new@. +getAlters :: [EntityDef] + -> DBName + -> ([Column], [(DBName, [DBName])]) + -> ([Column], [(DBName, [DBName])]) + -> ([AlterColumn'], [AlterTable]) +getAlters allDefs tblName (c1, u1) (c2, u2) = + (getAltersC c1 c2, getAltersU u1 u2) + where + getAltersC [] old = map (\x -> (cName x, Drop)) old + getAltersC (new:news) old = + let (alters, old') = findAlters allDefs new old + in alters ++ getAltersC news old' + + getAltersU [] old = map (DropUniqueConstraint . fst) old + getAltersU ((name, cols):news) old = + case lookup name old of + Nothing -> + AddUniqueConstraint name (map findType cols) : getAltersU news old + Just ocols -> + let old' = filter (\(x, _) -> x /= name) old + in if sort cols == ocols + then getAltersU news old' + else DropUniqueConstraint name + : AddUniqueConstraint name (map findType cols) + : getAltersU news old' + where + findType = findTypeOfColumn allDefs tblName + + +-- | @findAlters newColumn oldColumns@ finds out what needs to be +-- changed in the columns @oldColumns@ for @newColumn@ to be +-- supported. +findAlters :: [EntityDef] -> Column -> [Column] -> ([AlterColumn'], [Column]) +findAlters allDefs col@(Column name isNull type_ def ref) cols = + case filter ((name ==) . cName) cols of + [] -> ( let cnstr = [addReference allDefs tname | Just (tname, _) <- [ref]] + in map ((,) name) (Add col : cnstr) + , cols ) + Column _ isNull' type_' def' ref':_ -> + let -- Foreign key + refDrop = case (ref == ref', ref') of + (False, Just (_, cname)) -> [(name, DropReference cname)] + _ -> [] + refAdd = case (ref == ref', ref) of + (False, Just (tname, _)) -> [(name, addReference allDefs tname)] + _ -> [] + -- Type and nullability + modType | type_ == type_' && isNull == isNull' = [] + | otherwise = [(name, Change col)] + -- Default value + modDef | def == def' = [] + | otherwise = case def of + Nothing -> [(name, NoDefault)] + Just s -> [(name, Default $ T.unpack s)] + in ( refDrop ++ modType ++ modDef ++ refAdd + , filter ((name /=) . cName) cols ) + + +---------------------------------------------------------------------- + + +-- | Prints the part of a @CREATE TABLE@ statement about a given +-- column. +showColumn :: Column -> String +showColumn (Column n nu t def ref) = concat + [ escapeDBName n + , " " + , showSqlType t + , " " + , if nu then "NULL" else "NOT NULL" + , case def of + Nothing -> "" + Just s -> " DEFAULT " ++ T.unpack s + , case ref of + Nothing -> "" + Just (s, _) -> " REFERENCES " ++ escapeDBName s + ] + + +-- | Renders an 'SqlType' in MySQL's format. +showSqlType :: SqlType -> String +showSqlType SqlBlob = "BLOB" +showSqlType SqlBool = "TINYINT(1)" +showSqlType SqlDay = "DATE" +showSqlType SqlDayTime = "DATETIME" +showSqlType SqlInt32 = "INT" +showSqlType SqlInteger = "BIGINT" +showSqlType SqlReal = "DOUBLE PRECISION" +showSqlType SqlString = "VARCHAR(65535)" +showSqlType SqlTime = "TIME" + + +-- | Render an action that must be done on the database. +showAlterDb :: AlterDB -> (Bool, Text) +showAlterDb (AddTable s) = (False, pack s) +showAlterDb (AlterColumn t (c, ac)) = + (isUnsafe ac, pack $ showAlter t (c, ac)) + where + isUnsafe Drop = True + isUnsafe _ = False +showAlterDb (AlterTable t at) = (False, pack $ showAlterTable t at) + + +-- | Render an action that must be done on a table. +showAlterTable :: DBName -> AlterTable -> String +showAlterTable table (AddUniqueConstraint cname cols) = concat + [ "ALTER TABLE " + , escapeDBName table + , " ADD CONSTRAINT " + , escapeDBName cname + , " UNIQUE(" + , intercalate "," $ map escapeDBName' cols + , ")" + ] + where + escapeDBName' (name, (FieldType "String")) = escapeDBName name ++ "(200)" + escapeDBName' (name, _ ) = escapeDBName name +showAlterTable table (DropUniqueConstraint cname) = concat + [ "ALTER TABLE " + , escapeDBName table + , " DROP INDEX " + , escapeDBName cname + ] + + +-- | Render an action that must be done on a column. +showAlter :: DBName -> AlterColumn' -> String +showAlter table (n, Change col) = + concat + [ "ALTER TABLE " + , escapeDBName table + , " CHANGE " + , escapeDBName n + , showColumn col + ] +showAlter table (_, Add col) = + concat + [ "ALTER TABLE " + , escapeDBName table + , " ADD COLUMN " + , showColumn col + ] +showAlter table (n, Drop) = + concat + [ "ALTER TABLE " + , escapeDBName table + , " DROP COLUMN " + , escapeDBName n + ] +showAlter table (n, Default s) = + concat + [ "ALTER TABLE " + , escapeDBName table + , " ALTER COLUMN " + , escapeDBName n + , " SET DEFAULT " + , s + ] +showAlter table (n, NoDefault) = + concat + [ "ALTER TABLE " + , escapeDBName table + , " ALTER COLUMN " + , escapeDBName n + , " DROP DEFAULT" + ] +showAlter table (n, Update s) = + concat + [ "UPDATE " + , escapeDBName table + , " SET " + , escapeDBName n + , "=" + , s + , " WHERE " + , escapeDBName n + , " IS NULL" + ] +showAlter table (n, AddReference t2 id2) = concat + [ "ALTER TABLE " + , escapeDBName table + , " ADD CONSTRAINT " + , escapeDBName $ refName table n + , " FOREIGN KEY(" + , escapeDBName n + , ") REFERENCES " + , escapeDBName t2 + , "(" + , escapeDBName id2 + , ")" + ] +showAlter table (_, DropReference cname) = concat + [ "ALTER TABLE " + , escapeDBName table + , " DROP CONSTRAINT " + , escapeDBName cname + ] + +refName :: DBName -> DBName -> DBName +refName (DBName table) (DBName column) = + DBName $ T.concat [table, "_", column, "_fkey"] + + +---------------------------------------------------------------------- + + +-- | Escape a database name to be included on a query. +-- +-- FIXME: Can we do better here? +escapeDBName :: DBName -> String +escapeDBName (DBName s) = T.unpack s + +-- | Information required to connect to a MySQL database +-- using @persistent@'s generic facilities. These values are the +-- same that are given to 'withMySQLPool'. +data MySQLConf = MySQLConf + { myConnInfo :: MySQL.ConnectInfo + -- ^ The connection information. + , myPoolSize :: Int + -- ^ How many connections should be held on the connection pool. + } + + +instance PersistConfig MySQLConf where + type PersistConfigBackend MySQLConf = SqlPersist + + type PersistConfigPool MySQLConf = ConnectionPool + + createPoolConfig (MySQLConf cs size) = createMySQLPool cs size + + runPool _ = runSqlPool + + loadConfig (Object o) = do + database <- o .: "database" + host <- o .: "host" + port <- o .: "port" + user <- o .: "user" + password <- o .: "password" + pool <- o .: "poolsize" + let ci = MySQL.defaultConnectInfo + { MySQL.connectHost = host + , MySQL.connectPort = port + , MySQL.connectUser = user + , MySQL.connectPassword = password + , MySQL.connectDatabase = database + } + return $ MySQLConf ci pool + loadConfig _ = mzero + + applyEnv conf = do + env <- getEnvironment + let maybeEnv old var = maybe old id $ lookup ("MYSQL_" ++ var) env + return conf + { myConnInfo = + case myConnInfo conf of + MySQL.ConnectInfo + { MySQL.connectHost = host + , MySQL.connectPort = port + , MySQL.connectUser = user + , MySQL.connectPassword = password + , MySQL.connectDatabase = database + } -> (myConnInfo conf) + { MySQL.connectHost = maybeEnv host "HOST" + , MySQL.connectPort = read $ maybeEnv (show port) "PORT" + , MySQL.connectUser = maybeEnv user "USER" + , MySQL.connectPassword = maybeEnv password "PASSWORD" + , MySQL.connectDatabase = maybeEnv database "DATABASE" + } + } diff --git a/persistent-mysql/LICENSE b/persistent-mysql/LICENSE new file mode 100644 index 000000000..62e285c7b --- /dev/null +++ b/persistent-mysql/LICENSE @@ -0,0 +1,25 @@ +The following license covers this documentation, and the source code, except +where otherwise indicated. + +Copyright 2012, Felipe Lessa. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO +EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/persistent-mysql/Setup.lhs b/persistent-mysql/Setup.lhs new file mode 100755 index 000000000..06e2708f2 --- /dev/null +++ b/persistent-mysql/Setup.lhs @@ -0,0 +1,7 @@ +#!/usr/bin/env runhaskell + +> module Main where +> import Distribution.Simple + +> main :: IO () +> main = defaultMain diff --git a/persistent-mysql/persistent-mysql.cabal b/persistent-mysql/persistent-mysql.cabal new file mode 100644 index 000000000..678e5dd98 --- /dev/null +++ b/persistent-mysql/persistent-mysql.cabal @@ -0,0 +1,44 @@ +name: persistent-mysql +version: 0.0 +license: BSD3 +license-file: LICENSE +author: Felipe Lessa , Michael Snoyman +maintainer: Felipe Lessa +synopsis: Backend for the persistent library using MySQL database server. +category: Database, Yesod +stability: Stable +cabal-version: >= 1.6 +build-type: Simple +homepage: http://www.yesodweb.com/book/persistent +description: + This package contains a backend for persistent using the + MySQL database server. Internally it uses the @mysql-simple@ + and @mysql@ packages in order to access the database. + . + This package supports only MySQL 5.1 and above. However, it + has been tested only on MySQL 5.5. + . + Known problems: + . + * This package does not support statements inside other + statements. + +library + build-depends: base >= 4 && < 5 + , transformers >= 0.2.1 && < 0.3 + , mysql-simple >= 0.2.2.3 && < 0.3 + , mysql >= 0.1.1.3 && < 0.2 + , persistent >= 0.7 && < 0.8 + , containers >= 0.2 + , bytestring >= 0.9 && < 0.10 + , text >= 0.7 && < 0.12 + , monad-control >= 0.2 && < 0.4 + , time >= 1.1 + , aeson >= 0.5 + , conduit >= 0.2 + exposed-modules: Database.Persist.MySQL + ghc-options: -Wall + +source-repository head + type: git + location: git://github.com/yesodweb/persistent.git diff --git a/persistent-test/DataTypeTest.hs b/persistent-test/DataTypeTest.hs index dc4497abc..0face4163 100644 --- a/persistent-test/DataTypeTest.hs +++ b/persistent-test/DataTypeTest.hs @@ -16,6 +16,9 @@ import Database.Persist.TH #if WITH_POSTGRESQL import Database.Persist.Postgresql #endif +#if WITH_MYSQL +import Database.Persist.MySQL +#endif import Data.Char (generalCategory, GeneralCategory(..)) import Data.Text (Text) import qualified Data.Text as T diff --git a/persistent-test/PersistentTest.hs b/persistent-test/PersistentTest.hs index cb0469d1c..c9f7248ca 100644 --- a/persistent-test/PersistentTest.hs +++ b/persistent-test/PersistentTest.hs @@ -65,6 +65,9 @@ import Control.Monad.Trans.Resource (ResourceIO) #if WITH_POSTGRESQL import Database.Persist.Postgresql #endif +#if WITH_MYSQL +import Database.Persist.MySQL +#endif #endif @@ -222,6 +225,14 @@ runConn f = do _<-withSqlitePool sqlite_database 1 $ runSqlPool f #if WITH_POSTGRESQL _<-withPostgresqlPool "host=localhost port=5432 user=test dbname=test password=test" 1 $ runSqlPool f +#endif +#if WITH_MYSQL + _ <- withMySQLPool defaultConnectInfo + { connectHost = "localhost" + , connectUser = "test" + , connectPassword = "test" + , connectDatabase = "test" + } 1 $ runSqlPool f #endif return () @@ -611,7 +622,7 @@ specs = describe "persistent" $ do -- limit ps2 <- selectList [] [LimitTo 1] ps2 @== [(Entity key25 p25)] - -- offset -- FAILS! + -- offset ps3 <- selectList [] [OffsetBy 1] ps3 @== [(Entity key26 p26)] -- limit & offset diff --git a/persistent-test/RenameTest.hs b/persistent-test/RenameTest.hs index 8921aac07..cafca36d2 100644 --- a/persistent-test/RenameTest.hs +++ b/persistent-test/RenameTest.hs @@ -18,6 +18,9 @@ import Database.Persist.GenericSql.Raw #if WITH_POSTGRESQL import Database.Persist.Postgresql #endif +#if WITH_MYSQL +import Database.Persist.MySQL +#endif import qualified Data.Conduit as C import qualified Data.Conduit.List as CL import qualified Data.Map as Map @@ -44,6 +47,14 @@ runConn2 f = do _ <- withSqlitePool ":memory:" 1 $ runSqlPool f #if WITH_POSTGRESQL _ <- withPostgresqlPool "host=localhost port=5432 user=test dbname=test password=test" 1 $ runSqlPool f +#endif +#if WITH_MYSQL + _ <- withMySQLPool defaultConnectInfo + { connectHost = "localhost" + , connectUser = "test" + , connectPassword = "test" + , connectDatabase = "test" + } 1 $ runSqlPool f #endif return () diff --git a/persistent-test/persistent-mysql b/persistent-test/persistent-mysql new file mode 120000 index 000000000..1bb2c8d21 --- /dev/null +++ b/persistent-test/persistent-mysql @@ -0,0 +1 @@ +../persistent-mysql/ \ No newline at end of file diff --git a/persistent-test/persistent-test.cabal b/persistent-test/persistent-test.cabal index eee7b831f..0bcb975b3 100644 --- a/persistent-test/persistent-test.cabal +++ b/persistent-test/persistent-test.cabal @@ -24,6 +24,10 @@ Flag postgresql Description: test postgresql. default is to test just sqlite. Default: False +Flag mysql + Description: test MySQL + Default: False + library extra-libraries: sqlite3 @@ -50,6 +54,7 @@ library Database.Persist.Sqlite Database.Sqlite Database.Persist.Postgresql + Database.Persist.MySQL Database.Persist.MongoDB @@ -67,7 +72,7 @@ library , monad-control , containers , bytestring - , conduit + , conduit >= 0.2 , time >= 1.2 , random == 1.* , QuickCheck == 2.4.* @@ -79,12 +84,16 @@ library , postgresql-simple >= 0.0 && < 1.0 , postgresql-libpq >= 0.6 + -- MySQL dependencies + , mysql-simple >= 0.2.2.3 && < 0.3 + , mysql >= 0.1.1.3 && < 0.2 + -- mongoDB dependencies , mongoDB == 1.2.* , cereal , compact-string-fix , bson - hs-source-dirs: ., persistent, persistent-template, persistent-sqlite, persistent-postgresql, persistent-mongoDB + hs-source-dirs: ., persistent, persistent-template, persistent-sqlite, persistent-postgresql, persistent-mysql, persistent-mongoDB ghc-options: -Wall @@ -95,6 +104,9 @@ library -- else -- if flag(postgresql) -- cpp-options: -DWITH_POSTGRESQL +-- else +-- if flag(mysql) +-- cpp-options: -DWITH_MYSQL test-suite test type: exitcode-stdio-1.0 @@ -112,6 +124,9 @@ test-suite test -- else -- if flag(postgresql) -- cpp-options: -DWITH_POSTGRESQL +-- else +-- if flag(mysql) +-- cpp-options: -DWITH_MYSQL source-repository head