From f369819c41d7cc217434478b59c5ef2fbf61bb81 Mon Sep 17 00:00:00 2001 From: Richard Dominick <34370238+RichDom2185@users.noreply.github.com> Date: Mon, 12 Feb 2024 21:55:20 +0800 Subject: [PATCH 1/2] Refactor DB scripts (#106) * Access error directly via variable * Remove unnecesary 'if' clause The default value of a string variable is already an empty string, and there was no additional error handling done. * Move printing logic to main script Allows for better abstraction and improves reusability. * Refactor to use common DB "connect" helper Reduces code duplication. * Refactor to use common DB "close" helper * Rename `d` to `dbConn` Makes the name more descriptive. * Add comment * Replace string with constants Prevents bugs. * Rename `getConnectedDBName` to `getDBName` * Change createDB signature Hides implementation details of the config object as the function only needs to be aware of the database name to create, not how it is derived from. * Change dropDB signature Similar to createDB, this improves abstraction. * Refactor script setup to separate function Improves single responsibility. * Simplify dropDB implementation * Refactor createDB implementation * Move happy path to outside if condition * Simplify error handling * Update logging behavior * Update log colors * Remove unnecessary code --- scripts/create_db.go | 96 ++++++++++---------------------------------- scripts/db.go | 69 +++++++++++++++++++++---------- 2 files changed, 69 insertions(+), 96 deletions(-) diff --git a/scripts/create_db.go b/scripts/create_db.go index 4cb0612..6be35ef 100644 --- a/scripts/create_db.go +++ b/scripts/create_db.go @@ -4,101 +4,47 @@ import ( "errors" "fmt" - "github.com/source-academy/stories-backend/internal/config" - "gorm.io/driver/postgres" "gorm.io/gorm" ) -func connectAnonDB(conf config.DatabaseConfig) (*gorm.DB, error) { - conf.DatabaseName = "" - dsn := conf.ToDataSourceName() - return connectDBHelper(dsn) -} - -func connectDB(conf config.DatabaseConfig) (*gorm.DB, error) { - dsn := conf.ToDataSourceName() - return connectDBHelper(dsn) -} - -func connectDBHelper(dsn string) (*gorm.DB, error) { - driver := postgres.Open(dsn) - - db, err := gorm.Open(driver, &gorm.Config{}) - if err != nil { - return nil, err - } - - dbName, err := getConnectedDBName(db) - if err != nil { - panic(err) - } - fmt.Println(blueSandwich, "Connected to database", dbName+".") - - return db, nil -} - -func closeDBConnection(d *gorm.DB) { - db, err := d.DB() - if err != nil { - panic(err) - } - - dbName, err := getConnectedDBName(d) - if err != nil { - panic(err) - } - fmt.Println(blueSandwich, "Closing connection with database", dbName+".") - - if err := db.Close(); err != nil { - panic(err) - } -} - -func createDB(db *gorm.DB, dbconf *config.DatabaseConfig) error { - if dbconf.DatabaseName == "" { +func createDB(db *gorm.DB, dbName string) error { + if dbName == "" { return errors.New("Failed to create database: no database name provided.") } // check if db exists - fmt.Println(yellowChevron, "Checking if database", dbconf.DatabaseName, "exists.") - result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbconf.DatabaseName) + fmt.Println(yellowChevron, "Checking if database", dbName, "exists.") + result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbName) if result.Error != nil { return result.Error } // if not exists create it rec := make(map[string]interface{}) - if result.Find(rec); len(rec) == 0 { - fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "does not exist. Creating...") - - create_command := fmt.Sprintf("CREATE DATABASE %s", dbconf.DatabaseName) - result := db.Exec(create_command) - - if result.Error != nil { - return result.Error - } + if result.Find(rec); len(rec) != 0 { + fmt.Println(greenTick, "Database", dbName, "already exists.") + return nil } - fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "exists.") + fmt.Println(yellowChevron, "Database", dbName, "does not exist. Creating...") + create_command := fmt.Sprintf("CREATE DATABASE %s", dbName) + err := db.Exec(create_command).Error + if err != nil { + return err + } + fmt.Println(greenTick, "Created database:", dbName) return nil } -func dropDB(db *gorm.DB, dbconf *config.DatabaseConfig) error { - drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbconf.DatabaseName) - result := db.Exec(drop_command) - if result.Error != nil { - return result.Error - } - - return nil +func dropDB(db *gorm.DB, dbName string) error { + drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName) + err := db.Exec(drop_command).Error + return err } -func getConnectedDBName(db *gorm.DB) (string, error) { +func getDBName(db *gorm.DB) (string, error) { var dbName string - result := db.Raw("SELECT current_database();").Scan(&dbName) - if result.Error != nil { - return "", result.Error - } - return dbName, nil + err := db.Raw("SELECT current_database();").Scan(&dbName).Error + return dbName, err } diff --git a/scripts/db.go b/scripts/db.go index 0a85f72..087afa0 100644 --- a/scripts/db.go +++ b/scripts/db.go @@ -10,12 +10,18 @@ import ( migrate "github.com/rubenv/sql-migrate" "github.com/sirupsen/logrus" "github.com/source-academy/stories-backend/internal/config" - "gorm.io/gorm" + "github.com/source-academy/stories-backend/internal/database" ) const ( defaultMaxMigrateSteps = 0 // no limit defaultMaxRollbackSteps = 1 + + dropCmd = "drop" + createCmd = "create" + migrateCmd = "migrate" + rollbackCmd = "rollback" + statusCmd = "status" ) var ( @@ -28,7 +34,7 @@ var ( }) ) -func main() { +func setupScript() (string, *config.DatabaseConfig) { // Load configuration conf, err := config.LoadFromEnvironment() if err != nil { @@ -36,59 +42,80 @@ func main() { panic(err) } - var connector func(config.DatabaseConfig) (*gorm.DB, error) + targetDBName := conf.Database.DatabaseName // Check for command line arguments flag.Parse() switch flag.Arg(0) { - case "drop", "create": - connector = connectAnonDB - case "migrate", "rollback", "status": - connector = connectDB + case dropCmd, createCmd: + // We need to connect anonymously in order + // to drop or create the database. + conf.Database.DatabaseName = "" + case migrateCmd, rollbackCmd, statusCmd: + // Do nothing default: logrus.Errorln("Invalid command") + return targetDBName, nil + } + + return targetDBName, conf.Database +} + +func main() { + targetDBName, dbConfig := setupScript() + if dbConfig == nil { + // Invalid configuration return } // Connect to the database - d, err := connector(*conf.Database) + dbConn, err := database.Connect(dbConfig) if err != nil { logrus.Errorln(err) panic(err) } - defer closeDBConnection(d) + // Remember to close the connection + defer (func() { + fmt.Println(blueSandwich, "Closing connection...") + database.Close(dbConn) + })() + + dbName, err := getDBName(dbConn) + if err != nil { + panic(err) + } + fmt.Println(blueSandwich, "Connected to database", dbName+".") switch flag.Arg(0) { - case "drop": - err := dropDB(d, conf.Database) + case dropCmd: + err := dropDB(dbConn, targetDBName) if err != nil { logrus.Errorln(err) panic(err) } - fmt.Println(greenTick, "Dropped database:", conf.Database.DatabaseName) - case "create": - err := createDB(d, conf.Database) + fmt.Println(greenTick, "Dropped database:", targetDBName) + case createCmd: + err := createDB(dbConn, targetDBName) if err != nil { logrus.Errorln(err) panic(err) } - fmt.Println(greenTick, "Created database:", conf.Database.DatabaseName) - case "migrate": - db, err := d.DB() + case migrateCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err) } migrateDB(db) - case "rollback": - db, err := d.DB() + case rollbackCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err) } rollbackDB(db) - case "status": - db, err := d.DB() + case statusCmd: + db, err := dbConn.DB() if err != nil { logrus.Errorln(err) panic(err) From a603c405b70eef7cb7f677478d0a6de15c10e3b3 Mon Sep 17 00:00:00 2001 From: Richard Dominick <34370238+RichDom2185@users.noreply.github.com> Date: Mon, 12 Feb 2024 21:58:44 +0800 Subject: [PATCH 2/2] Refactor DB setup Part 1 (#105) * Separate CreateStory DB operation to own function Improves testability for unit testing. * Add sample create DB test * Separate DeleteStory DB operation to own function * Separate CreateUser DB operation to own function * Add TODO * Separate DeleteUser DB operation to own function * Fix pointer errors --- model/stories.go | 32 +++++++++++++++++++++----------- model/stories_test.go | 24 ++++++++++++++++++++++-- model/users.go | 29 +++++++++++++++++++++-------- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/model/stories.go b/model/stories.go index a3ccc07..1f1dea6 100644 --- a/model/stories.go +++ b/model/stories.go @@ -47,15 +47,20 @@ func GetStoryByID(db *gorm.DB, id int) (Story, error) { return story, nil } -func CreateStory(db *gorm.DB, story *Story) error { - err := db. +func (s *Story) create(tx *gorm.DB) *gorm.DB { + return tx. Preload(clause.Associations). - Create(story). + Create(s). // Get associated Author. See // https://github.com/go-gorm/gen/issues/618 on why // a separate .First() is needed. - First(story). - Error + First(s) +} + +func CreateStory(db *gorm.DB, story *Story) error { + err := db.Transaction(func(tx *gorm.DB) error { + return story.create(tx).Error + }) if err != nil { return database.HandleDBError(err, "story") } @@ -109,14 +114,19 @@ func UpdateStory(db *gorm.DB, storyID int, newStory *Story) error { return nil } -func DeleteStory(db *gorm.DB, storyID int) (Story, error) { - var story Story - err := db. +func (s *Story) delete(tx *gorm.DB, storyID uint) *gorm.DB { + return tx. Preload(clause.Associations). Where("id = ?", storyID). - First(&story). // store the value to be returned - Delete(&story). - Error + First(s). // store the value to be returned + Delete(s) +} + +func DeleteStory(db *gorm.DB, storyID int) (Story, error) { + var story Story + err := db.Transaction(func(tx *gorm.DB) error { + return story.delete(tx, uint(storyID)).Error + }) if err != nil { return story, database.HandleDBError(err, "story") } diff --git a/model/stories_test.go b/model/stories_test.go index 3847fe7..8c9a0f7 100644 --- a/model/stories_test.go +++ b/model/stories_test.go @@ -9,6 +9,7 @@ import ( userenums "github.com/source-academy/stories-backend/internal/enums/users" "github.com/source-academy/stories-backend/internal/testutils" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) // FIXME: Coupling with the other operations in the users database @@ -40,6 +41,25 @@ var ( } ) +func TestCreate(t *testing.T) { + t.Run("", func(t *testing.T) { + db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) + defer cleanUp(t) + + // Any number is fine because the statement is not executed, + // thus removing the coupling with an actual author having to be + // created prior. + story := &Story{ + AuthorID: 1, + Content: "The quick brown test content 5678.", + } + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return story.create(tx) + }) + assert.Contains(t, sql, "The quick brown test content 5678.", "Should contain the story content") + }) +} + func TestCreateStory(t *testing.T) { t.Run("should increase the total story count", func(t *testing.T) { db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath) @@ -221,8 +241,8 @@ func TestStoryDB(t *testing.T) { } _ = CreateGroup(db, &group) - err := db.Exec(`INSERT INTO "stories" - ("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order") + err := db.Exec(`INSERT INTO "stories" + ("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order") VALUES ('2023-08-08 22:17:28.085','2023-08-08 22:17:28.085',NULL,NULL,NULL,'','# Hi, This is a test story.',NULL)`). Error var pgerr *pgconn.PgError diff --git a/model/users.go b/model/users.go index e2986ba..9265303 100644 --- a/model/users.go +++ b/model/users.go @@ -31,9 +31,15 @@ func GetUserByID(db *gorm.DB, id int) (User, error) { return user, err } -func CreateUser(db *gorm.DB, user *User) error { +func (u *User) create(tx *gorm.DB) *gorm.DB { // TODO: If user already exists, but is soft-deleted, undelete the user - err := db.Create(user).Error + return tx.Create(u) +} + +func CreateUser(db *gorm.DB, user *User) error { + err := db.Transaction(func(tx *gorm.DB) error { + return user.create(tx).Error + }) if err != nil { return database.HandleDBError(err, "user") } @@ -41,6 +47,8 @@ func CreateUser(db *gorm.DB, user *User) error { } func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) { + // TODO: Use users.create() instead + // Blocked by `RowsAffected` not being accessible. tx := db.Create(users) rowCount := tx.RowsAffected if err := tx.Error; err != nil { @@ -49,14 +57,19 @@ func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) { return rowCount, nil } +func (u *User) delete(tx *gorm.DB, userID uint) *gorm.DB { + return tx. + Model(u). + Where("id = ?", userID). + First(u). // store the value to be returned + Delete(u) +} + func DeleteUser(db *gorm.DB, userID int) (User, error) { var user User - err := db. - Model(&user). - Where("id = ?", userID). - First(&user). // store the value to be returned - Delete(&user). - Error + err := db.Transaction(func(tx *gorm.DB) error { + return user.delete(tx, uint(userID)).Error + }) if err != nil { return user, database.HandleDBError(err, "user") }