diff --git a/scripts/create_db.go b/scripts/create_db.go index 4cb0612..6be35ef 100644 --- a/scripts/create_db.go +++ b/scripts/create_db.go @@ -4,101 +4,47 @@ import ( "errors" "fmt" - "github.com/source-academy/stories-backend/internal/config" - "gorm.io/driver/postgres" "gorm.io/gorm" ) -func connectAnonDB(conf config.DatabaseConfig) (*gorm.DB, error) { - conf.DatabaseName = "" - dsn := conf.ToDataSourceName() - return connectDBHelper(dsn) -} - -func connectDB(conf config.DatabaseConfig) (*gorm.DB, error) { - dsn := conf.ToDataSourceName() - return connectDBHelper(dsn) -} - -func connectDBHelper(dsn string) (*gorm.DB, error) { - driver := postgres.Open(dsn) - - db, err := gorm.Open(driver, &gorm.Config{}) - if err != nil { - return nil, err - } - - dbName, err := getConnectedDBName(db) - if err != nil { - panic(err) - } - fmt.Println(blueSandwich, "Connected to database", dbName+".") - - return db, nil -} - -func closeDBConnection(d *gorm.DB) { - db, err := d.DB() - if err != nil { - panic(err) - } - - dbName, err := getConnectedDBName(d) - if err != nil { - panic(err) - } - fmt.Println(blueSandwich, "Closing connection with database", dbName+".") - - if err := db.Close(); err != nil { - panic(err) - } -} - -func createDB(db *gorm.DB, dbconf *config.DatabaseConfig) error { - if dbconf.DatabaseName == "" { +func createDB(db *gorm.DB, dbName string) error { + if dbName == "" { return errors.New("Failed to create database: no database name provided.") } // check if db exists - fmt.Println(yellowChevron, "Checking if database", dbconf.DatabaseName, "exists.") - result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbconf.DatabaseName) + fmt.Println(yellowChevron, "Checking if database", dbName, "exists.") + result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbName) if result.Error != nil { return result.Error } // if not exists create it rec := make(map[string]interface{}) - if result.Find(rec); len(rec) == 0 { - fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "does not exist. Creating...") - - create_command := fmt.Sprintf("CREATE DATABASE %s", dbconf.DatabaseName) - result := db.Exec(create_command) - - if result.Error != nil { - return result.Error - } + if result.Find(rec); len(rec) != 0 { + fmt.Println(greenTick, "Database", dbName, "already exists.") + return nil } - fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "exists.") + fmt.Println(yellowChevron, "Database", dbName, "does not exist. Creating...") + create_command := fmt.Sprintf("CREATE DATABASE %s", dbName) + err := db.Exec(create_command).Error + if err != nil { + return err + } + fmt.Println(greenTick, "Created database:", dbName) return nil } -func dropDB(db *gorm.DB, dbconf *config.DatabaseConfig) error { - drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbconf.DatabaseName) - result := db.Exec(drop_command) - if result.Error != nil { - return result.Error - } - - return nil +func dropDB(db *gorm.DB, dbName string) error { + drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName) + err := db.Exec(drop_command).Error + return err } -func getConnectedDBName(db *gorm.DB) (string, error) { +func getDBName(db *gorm.DB) (string, error) { var dbName string - result := db.Raw("SELECT current_database();").Scan(&dbName) - if result.Error != nil { - return "", result.Error - } - return dbName, nil + err := db.Raw("SELECT current_database();").Scan(&dbName).Error + return dbName, err } diff --git a/scripts/db.go b/scripts/db.go index 0a85f72..087afa0 100644 --- a/scripts/db.go +++ b/scripts/db.go @@ -10,12 +10,18 @@ import ( migrate "github.com/rubenv/sql-migrate" "github.com/sirupsen/logrus" "github.com/source-academy/stories-backend/internal/config" - "gorm.io/gorm" + "github.com/source-academy/stories-backend/internal/database" ) const ( defaultMaxMigrateSteps = 0 // no limit defaultMaxRollbackSteps = 1 + + dropCmd = "drop" + createCmd = "create" + migrateCmd = "migrate" + rollbackCmd = "rollback" + statusCmd = "status" ) var ( @@ -28,7 +34,7 @@ var ( }) ) -func main() { +func setupScript() (string, *config.DatabaseConfig) { // Load configuration conf, err := config.LoadFromEnvironment() if err != nil { @@ -36,59 +42,80 @@ func main() { panic(err) } - var connector func(config.DatabaseConfig) (*gorm.DB, error) + targetDBName := conf.Database.DatabaseName // Check for command line arguments flag.Parse() switch flag.Arg(0) { - case "drop", "create": - connector = connectAnonDB - case "migrate", "rollback", "status": - connector = connectDB + case dropCmd, createCmd: + // We need to connect anonymously in order + // to drop or create the database. + conf.Database.DatabaseName = "" + case migrateCmd, rollbackCmd, statusCmd: + // Do nothing default: logrus.Errorln("Invalid command") + return targetDBName, nil + } + + return targetDBName, conf.Database +} + +func main() { + targetDBName, dbConfig := setupScript() + if dbConfig == nil { + // Invalid configuration return } // Connect to the database - d, err := connector(*conf.Database) + dbConn, err := database.Connect(dbConfig) if err != nil { logrus.Errorln(err) panic(err) } - defer closeDBConnection(d) + // Remember to close the connection + defer (func() { + fmt.Println(blueSandwich, "Closing connection...") + database.Close(dbConn) + })() + + dbName, err := getDBName(dbConn) + if err != nil { + panic(err) + } + fmt.Println(blueSandwich, "Connected to database", dbName+".") switch flag.Arg(0) { - case "drop": - err := dropDB(d, conf.Database) + case dropCmd: + err := dropDB(dbConn, targetDBName) if err != nil { logrus.Errorln(err) panic(err) } - fmt.Println(greenTick, "Dropped database:", conf.Database.DatabaseName) - case "create": - err := createDB(d, conf.Database) + fmt.Println(greenTick, "Dropped database:", targetDBName) + case createCmd: + err := createDB(dbConn, targetDBName) if err != nil { logrus.Errorln(err) panic(err) } - fmt.Println(greenTick, "Created database:", conf.Database.DatabaseName) - case "migrate": - db, err := d.DB() + case migrateCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err) } migrateDB(db) - case "rollback": - db, err := d.DB() + case rollbackCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err) } rollbackDB(db) - case "status": - db, err := d.DB() + case statusCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err)