diff --git a/CHANGELOG.md b/CHANGELOG.md index 1964bec258..07f67308fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## v1.8.2 - feat: support PostgreSQL +- feat: support config multiple databases ## v1.8.1 - fix: GitHub workflow badge URL diff --git a/config/docker/database.yaml b/config/docker/database.yaml index a8ed86f127..f36f6b406e 100644 --- a/config/docker/database.yaml +++ b/config/docker/database.yaml @@ -1,10 +1,11 @@ -Driver: mysql # 驱动名称,目前支持 mysql,postgres,默认: mysql -Name: eagle # 数据库名称 -Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306 -UserName: root -Password: root -ShowLog: true # 是否打印所有SQL日志 -MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池 -MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数 -ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 -SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms \ No newline at end of file +default: + Driver: mysql # 驱动名称,目前支持 mysql,postgres,默认: mysql + Name: eagle # 数据库名称 + Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306 + UserName: root + Password: root + ShowLog: true # 是否打印所有SQL日志 + MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池 + MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数 + ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 + SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms \ No newline at end of file diff --git a/config/local/database.yaml b/config/local/database.yaml index 50be4d5d2b..d814ffc365 100644 --- a/config/local/database.yaml +++ b/config/local/database.yaml @@ -1,10 +1,22 @@ -Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql -Name: eagle # 数据库名称 -Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432 -UserName: root -Password: 123456 -ShowLog: true # 是否打印所有SQL日志 -MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池 -MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数 -ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 -SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms +default: + Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql + Name: eagle # 数据库名称 + Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432 + UserName: root + Password: 123456 + ShowLog: true # 是否打印所有SQL日志 + MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池 + MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数 + ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 + SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms +user: + Driver: mysql # 驱动名称,目前支持: mysql,postgres,默认: mysql + Name: eagle # 数据库名称 + Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称,eg: db:3306, pg:5432 + UserName: root + Password: 123456 + ShowLog: true # 是否打印所有SQL日志 + MaxIdleConn: 10 # 最大闲置的连接数,0意味着使用默认的大小2, 小于0表示不使用连接池 + MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数 + ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 + SlowThreshold: 500ms # 慢查询阈值,设置后只打印慢查询日志,默认为200ms diff --git a/internal/handler/v1/user/get.go b/internal/handler/v1/user/get.go index c46a28038a..68cde4a6dc 100644 --- a/internal/handler/v1/user/get.go +++ b/internal/handler/v1/user/get.go @@ -2,7 +2,6 @@ package user import ( "errors" - "time" "github.com/gin-gonic/gin" "github.com/spf13/cast" @@ -44,7 +43,7 @@ func Get(c *gin.Context) { return } - time.Sleep(5 * time.Second) + //time.Sleep(5 * time.Second) response.Success(c, u) } diff --git a/internal/model/init.go b/internal/model/init.go index e2effff088..0ef0e6316e 100644 --- a/internal/model/init.go +++ b/internal/model/init.go @@ -5,35 +5,33 @@ import ( "gorm.io/gorm" - "github.com/go-eagle/eagle/pkg/config" "github.com/go-eagle/eagle/pkg/storage/orm" ) -// DB 数据库全局变量 -var DB *gorm.DB +const ( + // DefaultDatabase default database + DefaultDatabase = "default" + // UserDatabase user database + UserDatabase = "user" +) // Init 初始化数据库 -func Init() *gorm.DB { - cfg, err := loadConf() +func Init() { + err := orm.New( + DefaultDatabase, + UserDatabase, + ) if err != nil { - panic(fmt.Sprintf("load orm conf err: %v", err)) + panic(fmt.Sprintf("new orm database err: %v", err)) } - - DB = orm.New(cfg) - return DB } // GetDB 返回默认的数据库 -func GetDB() *gorm.DB { - return DB +func GetDB() (*gorm.DB, error) { + return orm.GetDB(DefaultDatabase) } -// loadConf load database config -func loadConf() (ret *orm.Config, err error) { - var cfg orm.Config - if err := config.Load("database", &cfg); err != nil { - return nil, err - } - - return &cfg, nil +// GetUserDB 获取用户数据库实例 +func GetUserDB() (*gorm.DB, error) { + return orm.GetDB(UserDatabase) } diff --git a/internal/repository/user_follow_repo.go b/internal/repository/user_follow_repo.go index c205bc9fdc..9fff4232c0 100644 --- a/internal/repository/user_follow_repo.go +++ b/internal/repository/user_follow_repo.go @@ -40,7 +40,7 @@ func (d *repository) UpdateUserFansStatus(ctx context.Context, db *gorm.DB, user // GetFollowingUserList . func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFollowModel, error) { userFollowList := make([]*model.UserFollowModel, 0) - db := model.GetDB() + db, _ := model.GetDB() result := db.Where("user_id=? AND id<=? and status=1", userID, lastID). Order("id desc"). Limit(limit).Find(&userFollowList) @@ -56,7 +56,7 @@ func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID ui // GetFollowerUserList get follower user list func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFansModel, error) { userFollowerList := make([]*model.UserFansModel, 0) - db := model.GetDB() + db, _ := model.GetDB() result := db.Where("user_id=? AND id<=? and status=1", userID, lastID). Order("id desc"). Limit(limit).Find(&userFollowerList) @@ -73,8 +73,8 @@ func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uin func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followingUID []uint64) (map[uint64]*model.UserFollowModel, error) { userFollowModel := make([]*model.UserFollowModel, 0) retMap := make(map[uint64]*model.UserFollowModel) - - err := model.GetDB(). + db, _ := model.GetDB() + err := db. Where("user_id=? AND followed_uid in (?) ", userID, followingUID). Find(&userFollowModel).Error @@ -93,12 +93,12 @@ func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followi func (d *repository) GetFansByUIds(ctx context.Context, userID uint64, followerUID []uint64) (map[uint64]*model.UserFansModel, error) { userFansModel := make([]*model.UserFansModel, 0) retMap := make(map[uint64]*model.UserFansModel) - - err := model.GetDB(). + db, _ := model.GetDB() + err := db. Where("user_id=? AND follower_uid in (?) ", userID, followerUID). Find(&userFansModel).Error - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return retMap, errors.Wrap(err, "[user_follow] get user fans err") } diff --git a/internal/service/relation_service.go b/internal/service/relation_service.go index 071e2c0e4d..1913fc3c13 100644 --- a/internal/service/relation_service.go +++ b/internal/service/relation_service.go @@ -39,7 +39,8 @@ func newRelations(svc *service) *relationService { // IsFollowing 是否正在关注某用户 func (s *relationService) IsFollowing(ctx context.Context, userID uint64, followedUID uint64) bool { userFollowModel := &model.UserFollowModel{} - result := model.GetDB(). + db, _ := model.GetDB() + result := db. Where("user_id=? AND followed_uid=? ", userID, followedUID). Find(userFollowModel) @@ -57,7 +58,7 @@ func (s *relationService) IsFollowing(ctx context.Context, userID uint64, follow // Follow 关注目标用户 func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID uint64) error { - db := model.GetDB() + db, _ := model.GetDB() tx := db.Begin() defer func() { if r := recover(); r != nil { @@ -103,7 +104,7 @@ func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID // Unfollow 取消用户关注 func (s *relationService) Unfollow(ctx context.Context, userID uint64, followedUID uint64) error { - db := model.GetDB() + db, _ := model.GetDB() tx := db.Begin() defer func() { if r := recover(); r != nil { diff --git a/main.go b/main.go index 5882a3ac85..ee3cbcd6f3 100644 --- a/main.go +++ b/main.go @@ -76,7 +76,8 @@ func main() { // redis.Init() // init service - service.Svc = service.New(repository.New(model.GetDB())) + db, _ := model.GetDB() + service.Svc = service.New(repository.New(db)) gin.SetMode(cfg.Mode) diff --git a/pkg/storage/orm/orm.go b/pkg/storage/orm/orm.go index f61dbd37dc..46e1bbbdb4 100644 --- a/pkg/storage/orm/orm.go +++ b/pkg/storage/orm/orm.go @@ -5,8 +5,11 @@ import ( "fmt" "log" "os" + "sync" "time" + "github.com/go-eagle/eagle/pkg/config" + otelgorm "github.com/1024casts/gorm-opentelemetry" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -21,6 +24,16 @@ const ( DriverMySQL = "mysql" // DriverPostgres postgresSQL driver DriverPostgres = "postgres" + + // DefaultDatabase default db name + DefaultDatabase = "default" +) + +var ( + // DBMap store database instance + DBMap = make(map[string]*gorm.DB) + // DBLock database locker + DBLock sync.Mutex ) // Config database config @@ -37,8 +50,82 @@ type Config struct { SlowThreshold time.Duration // 慢查询时长,默认500ms } -// New connect to database and create a db instance -func New(c *Config) (db *gorm.DB) { +// New create a or multi database client +func New(names ...string) error { + if len(names) == 0 { + return fmt.Errorf("no set databasename") + } + + clientManager := NewManager() + for _, name := range names { + _, err := clientManager.GetInstance(name) + if err != nil { + return fmt.Errorf("init database name: %+v, err: %+v", name, err) + } + } + + return nil +} + +// Manager define a manager +type Manager struct { + instances map[string]*gorm.DB + *sync.RWMutex +} + +// NewManager create a database manager +func NewManager() *Manager { + return &Manager{ + instances: make(map[string]*gorm.DB), + RWMutex: &sync.RWMutex{}, + } +} + +// GetDB get a database +func GetDB(name string) (*gorm.DB, error) { + DBLock.Lock() + defer DBLock.Unlock() + + db, ok := DBMap[name] + if !ok { + db, err := NewManager().GetInstance(name) + if err != nil { + return nil, err + } + return db, nil + } + + return db, nil +} + +// GetInstance return a database client +func (m *Manager) GetInstance(name string) (*gorm.DB, error) { + // get client from map + m.RLock() + if ins, ok := m.instances[name]; ok { + m.RUnlock() + return ins, nil + } + m.RUnlock() + + c, err := LoadConf(name) + if err != nil { + return nil, fmt.Errorf("load database conf err: %+v", err) + } + + // create a database client + m.Lock() + defer m.Unlock() + + instance := NewInstance(c) + m.instances[name] = instance + DBMap[name] = instance + + return instance, nil +} + +// NewInstance connect to database and create a db instance +func NewInstance(c *Config) (db *gorm.DB) { var ( err error sqlDB *sql.DB @@ -83,6 +170,22 @@ func New(c *Config) (db *gorm.DB) { return db } +// LoadConf load database config +func LoadConf(name string) (ret *Config, err error) { + v, err := config.LoadWithType("database", "yaml") + if err != nil { + return nil, err + } + + var c Config + err = v.UnmarshalKey(name, &c) + if err != nil { + return nil, err + } + + return &c, nil +} + // getDSN return dsn string func getDSN(c *Config) string { // default mysql