Skip to content

Commit

Permalink
feat: support config multiple databases (#123)
Browse files Browse the repository at this point in the history
* feat: support config multiple databases
  • Loading branch information
qloog authored Dec 24, 2023
1 parent 4ad255e commit aa3b835
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## v1.8.2
- feat: support PostgreSQL
- feat: support config multiple databases

## v1.8.1
- fix: GitHub workflow badge URL
Expand Down
21 changes: 11 additions & 10 deletions config/docker/database.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
Driver: mysql # 驱动名称,目前支持 mysql,postgres,默认: mysql
Name: eagle # 数据库名称
Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306
UserName: root
Password: root
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms
default:
Driver: mysql # 驱动名称,目前支持 mysql,postgres,默认: mysql
Name: eagle # 数据库名称
Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306
UserName: root
Password: root
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms
32 changes: 22 additions & 10 deletions config/local/database.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms
default:
Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms
user:
Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms
3 changes: 1 addition & 2 deletions internal/handler/v1/user/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package user

import (
"errors"
"time"

"github.com/gin-gonic/gin"
"github.com/spf13/cast"
Expand Down Expand Up @@ -44,7 +43,7 @@ func Get(c *gin.Context) {
return
}

time.Sleep(5 * time.Second)
//time.Sleep(5 * time.Second)

response.Success(c, u)
}
36 changes: 17 additions & 19 deletions internal/model/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,33 @@ import (

"gorm.io/gorm"

"github.com/go-eagle/eagle/pkg/config"
"github.com/go-eagle/eagle/pkg/storage/orm"
)

// DB 数据库全局变量
var DB *gorm.DB
const (
// DefaultDatabase default database
DefaultDatabase = "default"
// UserDatabase user database
UserDatabase = "user"
)

// Init 初始化数据库
func Init() *gorm.DB {
cfg, err := loadConf()
func Init() {
err := orm.New(
DefaultDatabase,
UserDatabase,
)
if err != nil {
panic(fmt.Sprintf("load orm conf err: %v", err))
panic(fmt.Sprintf("new orm database err: %v", err))
}

DB = orm.New(cfg)
return DB
}

// GetDB 返回默认的数据库
func GetDB() *gorm.DB {
return DB
func GetDB() (*gorm.DB, error) {
return orm.GetDB(DefaultDatabase)
}

// loadConf load database config
func loadConf() (ret *orm.Config, err error) {
var cfg orm.Config
if err := config.Load("database", &cfg); err != nil {
return nil, err
}

return &cfg, nil
// GetUserDB 获取用户数据库实例
func GetUserDB() (*gorm.DB, error) {
return orm.GetDB(UserDatabase)
}
14 changes: 7 additions & 7 deletions internal/repository/user_follow_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (d *repository) UpdateUserFansStatus(ctx context.Context, db *gorm.DB, user
// GetFollowingUserList .
func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFollowModel, error) {
userFollowList := make([]*model.UserFollowModel, 0)
db := model.GetDB()
db, _ := model.GetDB()
result := db.Where("user_id=? AND id<=? and status=1", userID, lastID).
Order("id desc").
Limit(limit).Find(&userFollowList)
Expand All @@ -56,7 +56,7 @@ func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID ui
// GetFollowerUserList get follower user list
func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFansModel, error) {
userFollowerList := make([]*model.UserFansModel, 0)
db := model.GetDB()
db, _ := model.GetDB()
result := db.Where("user_id=? AND id<=? and status=1", userID, lastID).
Order("id desc").
Limit(limit).Find(&userFollowerList)
Expand All @@ -73,8 +73,8 @@ func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uin
func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followingUID []uint64) (map[uint64]*model.UserFollowModel, error) {
userFollowModel := make([]*model.UserFollowModel, 0)
retMap := make(map[uint64]*model.UserFollowModel)

err := model.GetDB().
db, _ := model.GetDB()
err := db.
Where("user_id=? AND followed_uid in (?) ", userID, followingUID).
Find(&userFollowModel).Error

Expand All @@ -93,12 +93,12 @@ func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followi
func (d *repository) GetFansByUIds(ctx context.Context, userID uint64, followerUID []uint64) (map[uint64]*model.UserFansModel, error) {
userFansModel := make([]*model.UserFansModel, 0)
retMap := make(map[uint64]*model.UserFansModel)

err := model.GetDB().
db, _ := model.GetDB()
err := db.
Where("user_id=? AND follower_uid in (?) ", userID, followerUID).
Find(&userFansModel).Error

if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return retMap, errors.Wrap(err, "[user_follow] get user fans err")
}

Expand Down
7 changes: 4 additions & 3 deletions internal/service/relation_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func newRelations(svc *service) *relationService {
// IsFollowing 是否正在关注某用户
func (s *relationService) IsFollowing(ctx context.Context, userID uint64, followedUID uint64) bool {
userFollowModel := &model.UserFollowModel{}
result := model.GetDB().
db, _ := model.GetDB()
result := db.
Where("user_id=? AND followed_uid=? ", userID, followedUID).
Find(userFollowModel)

Expand All @@ -57,7 +58,7 @@ func (s *relationService) IsFollowing(ctx context.Context, userID uint64, follow

// Follow 关注目标用户
func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID uint64) error {
db := model.GetDB()
db, _ := model.GetDB()
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
Expand Down Expand Up @@ -103,7 +104,7 @@ func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID

// Unfollow 取消用户关注
func (s *relationService) Unfollow(ctx context.Context, userID uint64, followedUID uint64) error {
db := model.GetDB()
db, _ := model.GetDB()
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
Expand Down
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ func main() {
// redis.Init()

// init service
service.Svc = service.New(repository.New(model.GetDB()))
db, _ := model.GetDB()
service.Svc = service.New(repository.New(db))

gin.SetMode(cfg.Mode)

Expand Down
107 changes: 105 additions & 2 deletions pkg/storage/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"fmt"
"log"
"os"
"sync"
"time"

"github.com/go-eagle/eagle/pkg/config"

otelgorm "github.com/1024casts/gorm-opentelemetry"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
Expand All @@ -21,6 +24,16 @@ const (
DriverMySQL = "mysql"
// DriverPostgres postgresSQL driver
DriverPostgres = "postgres"

// DefaultDatabase default db name
DefaultDatabase = "default"
)

var (
// DBMap store database instance
DBMap = make(map[string]*gorm.DB)
// DBLock database locker
DBLock sync.Mutex
)

// Config database config
Expand All @@ -37,8 +50,82 @@ type Config struct {
SlowThreshold time.Duration // 慢查询时长,默认500ms
}

// New connect to database and create a db instance
func New(c *Config) (db *gorm.DB) {
// New create a or multi database client
func New(names ...string) error {
if len(names) == 0 {
return fmt.Errorf("no set databasename")
}

clientManager := NewManager()
for _, name := range names {
_, err := clientManager.GetInstance(name)
if err != nil {
return fmt.Errorf("init database name: %+v, err: %+v", name, err)
}
}

return nil
}

// Manager define a manager
type Manager struct {
instances map[string]*gorm.DB
*sync.RWMutex
}

// NewManager create a database manager
func NewManager() *Manager {
return &Manager{
instances: make(map[string]*gorm.DB),
RWMutex: &sync.RWMutex{},
}
}

// GetDB get a database
func GetDB(name string) (*gorm.DB, error) {
DBLock.Lock()
defer DBLock.Unlock()

db, ok := DBMap[name]
if !ok {
db, err := NewManager().GetInstance(name)
if err != nil {
return nil, err
}
return db, nil
}

return db, nil
}

// GetInstance return a database client
func (m *Manager) GetInstance(name string) (*gorm.DB, error) {
// get client from map
m.RLock()
if ins, ok := m.instances[name]; ok {
m.RUnlock()
return ins, nil
}
m.RUnlock()

c, err := LoadConf(name)
if err != nil {
return nil, fmt.Errorf("load database conf err: %+v", err)
}

// create a database client
m.Lock()
defer m.Unlock()

instance := NewInstance(c)
m.instances[name] = instance
DBMap[name] = instance

return instance, nil
}

// NewInstance connect to database and create a db instance
func NewInstance(c *Config) (db *gorm.DB) {
var (
err error
sqlDB *sql.DB
Expand Down Expand Up @@ -83,6 +170,22 @@ func New(c *Config) (db *gorm.DB) {
return db
}

// LoadConf load database config
func LoadConf(name string) (ret *Config, err error) {
v, err := config.LoadWithType("database", "yaml")
if err != nil {
return nil, err
}

var c Config
err = v.UnmarshalKey(name, &c)
if err != nil {
return nil, err
}

return &c, nil
}

// getDSN return dsn string
func getDSN(c *Config) string {
// default mysql
Expand Down

0 comments on commit aa3b835

Please sign in to comment.