Skip to content

Commit

Permalink
Merge branch 'main' into change-/groups/{groupId}/users-route
Browse files Browse the repository at this point in the history
  • Loading branch information
zsiggg committed Feb 18, 2024
2 parents f708b88 + a603c40 commit 1eee2bc
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 124 deletions.
32 changes: 21 additions & 11 deletions model/stories.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,20 @@ func GetStoryByID(db *gorm.DB, id int) (Story, error) {
return story, nil
}

func CreateStory(db *gorm.DB, story *Story) error {
err := db.
func (s *Story) create(tx *gorm.DB) *gorm.DB {
return tx.
Preload(clause.Associations).
Create(story).
Create(s).
// Get associated Author. See
// https://github.com/go-gorm/gen/issues/618 on why
// a separate .First() is needed.
First(story).
Error
First(s)
}

func CreateStory(db *gorm.DB, story *Story) error {
err := db.Transaction(func(tx *gorm.DB) error {
return story.create(tx).Error
})
if err != nil {
return database.HandleDBError(err, "story")
}
Expand Down Expand Up @@ -109,14 +114,19 @@ func UpdateStory(db *gorm.DB, storyID int, newStory *Story) error {
return nil
}

func DeleteStory(db *gorm.DB, storyID int) (Story, error) {
var story Story
err := db.
func (s *Story) delete(tx *gorm.DB, storyID uint) *gorm.DB {
return tx.
Preload(clause.Associations).
Where("id = ?", storyID).
First(&story). // store the value to be returned
Delete(&story).
Error
First(s). // store the value to be returned
Delete(s)
}

func DeleteStory(db *gorm.DB, storyID int) (Story, error) {
var story Story
err := db.Transaction(func(tx *gorm.DB) error {
return story.delete(tx, uint(storyID)).Error
})
if err != nil {
return story, database.HandleDBError(err, "story")
}
Expand Down
24 changes: 22 additions & 2 deletions model/stories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
userenums "github.com/source-academy/stories-backend/internal/enums/users"
"github.com/source-academy/stories-backend/internal/testutils"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)

// FIXME: Coupling with the other operations in the users database
Expand Down Expand Up @@ -40,6 +41,25 @@ var (
}
)

func TestCreate(t *testing.T) {
t.Run("", func(t *testing.T) {
db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath)
defer cleanUp(t)

// Any number is fine because the statement is not executed,
// thus removing the coupling with an actual author having to be
// created prior.
story := &Story{
AuthorID: 1,
Content: "The quick brown test content 5678.",
}
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return story.create(tx)
})
assert.Contains(t, sql, "The quick brown test content 5678.", "Should contain the story content")
})
}

func TestCreateStory(t *testing.T) {
t.Run("should increase the total story count", func(t *testing.T) {
db, cleanUp := testutils.SetupDBConnection(t, dbConfig, migrationPath)
Expand Down Expand Up @@ -221,8 +241,8 @@ func TestStoryDB(t *testing.T) {
}
_ = CreateGroup(db, &group)

err := db.Exec(`INSERT INTO "stories"
("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order")
err := db.Exec(`INSERT INTO "stories"
("created_at","updated_at","deleted_at","author_id","group_id","title","content","pin_order")
VALUES ('2023-08-08 22:17:28.085','2023-08-08 22:17:28.085',NULL,NULL,NULL,'','# Hi, This is a test story.',NULL)`).
Error
var pgerr *pgconn.PgError
Expand Down
36 changes: 21 additions & 15 deletions model/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,24 @@ func GetUserByID(db *gorm.DB, id int) (User, error) {
return user, err
}

func CreateUser(db *gorm.DB, user *User) error {
func (u *User) create(tx *gorm.DB) *gorm.DB {
// TODO: If user already exists, but is soft-deleted, undelete the user
err := db.Create(user).Error
return tx.Create(u)
}

func CreateUser(db *gorm.DB, user *User) error {
err := db.Transaction(func(tx *gorm.DB) error {
return user.create(tx).Error
})
if err != nil {
return database.HandleDBError(err, "user")
}
return nil
}

func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) {
// TODO: Use users.create() instead
// Blocked by `RowsAffected` not being accessible.
tx := db.Create(users)
rowCount := tx.RowsAffected
if err := tx.Error; err != nil {
Expand All @@ -49,21 +57,19 @@ func CreateUsers(db *gorm.DB, users *[]*User) (int64, error) {
return rowCount, nil
}

func DeleteUser(db *gorm.DB, userID int) (User, error) {
var user User
err := db.
Model(&user).
func (u *User) delete(tx *gorm.DB, userID uint) *gorm.DB {
return tx.
Model(u).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return user, database.HandleDBError(err, "user")
}
First(u). // store the value to be returned
Delete(u)
}

err = db.
Model(&user).
Delete(&user).
Error
func DeleteUser(db *gorm.DB, userID int) (User, error) {
var user User
err := db.Transaction(func(tx *gorm.DB) error {
return user.delete(tx, uint(userID)).Error
})
if err != nil {
return user, database.HandleDBError(err, "user")
}
Expand Down
96 changes: 21 additions & 75 deletions scripts/create_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,47 @@ import (
"errors"
"fmt"

"github.com/source-academy/stories-backend/internal/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

func connectAnonDB(conf config.DatabaseConfig) (*gorm.DB, error) {
conf.DatabaseName = ""
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDB(conf config.DatabaseConfig) (*gorm.DB, error) {
dsn := conf.ToDataSourceName()
return connectDBHelper(dsn)
}

func connectDBHelper(dsn string) (*gorm.DB, error) {
driver := postgres.Open(dsn)

db, err := gorm.Open(driver, &gorm.Config{})
if err != nil {
return nil, err
}

dbName, err := getConnectedDBName(db)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Connected to database", dbName+".")

return db, nil
}

func closeDBConnection(d *gorm.DB) {
db, err := d.DB()
if err != nil {
panic(err)
}

dbName, err := getConnectedDBName(d)
if err != nil {
panic(err)
}
fmt.Println(blueSandwich, "Closing connection with database", dbName+".")

if err := db.Close(); err != nil {
panic(err)
}
}

func createDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
if dbconf.DatabaseName == "" {
func createDB(db *gorm.DB, dbName string) error {
if dbName == "" {
return errors.New("Failed to create database: no database name provided.")
}

// check if db exists
fmt.Println(yellowChevron, "Checking if database", dbconf.DatabaseName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbconf.DatabaseName)
fmt.Println(yellowChevron, "Checking if database", dbName, "exists.")
result := db.Raw("SELECT * FROM pg_database WHERE datname = ?", dbName)
if result.Error != nil {
return result.Error
}

// if not exists create it
rec := make(map[string]interface{})
if result.Find(rec); len(rec) == 0 {
fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "does not exist. Creating...")

create_command := fmt.Sprintf("CREATE DATABASE %s", dbconf.DatabaseName)
result := db.Exec(create_command)

if result.Error != nil {
return result.Error
}
if result.Find(rec); len(rec) != 0 {
fmt.Println(greenTick, "Database", dbName, "already exists.")
return nil
}

fmt.Println(yellowChevron, "Database", dbconf.DatabaseName, "exists.")
fmt.Println(yellowChevron, "Database", dbName, "does not exist. Creating...")
create_command := fmt.Sprintf("CREATE DATABASE %s", dbName)
err := db.Exec(create_command).Error
if err != nil {
return err
}

fmt.Println(greenTick, "Created database:", dbName)
return nil
}

func dropDB(db *gorm.DB, dbconf *config.DatabaseConfig) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbconf.DatabaseName)
result := db.Exec(drop_command)
if result.Error != nil {
return result.Error
}

return nil
func dropDB(db *gorm.DB, dbName string) error {
drop_command := fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)
err := db.Exec(drop_command).Error
return err
}

func getConnectedDBName(db *gorm.DB) (string, error) {
func getDBName(db *gorm.DB) (string, error) {
var dbName string
result := db.Raw("SELECT current_database();").Scan(&dbName)
if result.Error != nil {
return "", result.Error
}
return dbName, nil
err := db.Raw("SELECT current_database();").Scan(&dbName).Error
return dbName, err
}
Loading

0 comments on commit 1eee2bc

Please sign in to comment.