Skip to content

Commit

Permalink
Merge pull request #116 from actiontech/main
Browse files Browse the repository at this point in the history
Main
  • Loading branch information
rocky114 authored Jan 19, 2024
2 parents db9f127 + 1bc0e5b commit 19ee2c1
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 30 deletions.
125 changes: 98 additions & 27 deletions internal/dms/biz/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type CloudbeaverUser struct {

type CloudbeaverConnection struct {
DMSDBServiceID string `json:"dms_db_service_id"`
DMSUserId string `json:"dms_user_id"`
DMSDBServiceFingerprint string `json:"dms_db_service_fingerprint"`
CloudbeaverConnectionID string `json:"cloudbeaver_connection_id"`
}
Expand All @@ -61,8 +62,11 @@ type CloudbeaverRepo interface {
GetCloudbeaverUserByID(ctx context.Context, cloudbeaverUserId string) (*CloudbeaverUser, bool, error)
UpdateCloudbeaverUserCache(ctx context.Context, u *CloudbeaverUser) error
GetDbServiceIdByConnectionId(ctx context.Context, connectionId string) (string, error)
GetCloudbeaverConnectionByDMSDBServiceIds(ctx context.Context, dmsDBServiceIds []string) ([]*CloudbeaverConnection, error)
GetAllCloudbeaverConnections(ctx context.Context) ([]*CloudbeaverConnection, error)
GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx context.Context, userId string, dmsDBServiceIds []string) ([]*CloudbeaverConnection, error)
GetCloudbeaverConnectionsByUserId(ctx context.Context, userId string) ([]*CloudbeaverConnection, error)
UpdateCloudbeaverConnectionCache(ctx context.Context, u *CloudbeaverConnection) error
DeleteCloudbeaverConnectionCache(ctx context.Context, dbServiceId, userId string) error
}

type CloudbeaverUsecase struct {
Expand Down Expand Up @@ -561,6 +565,10 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
return err
}

if len(activeDBServices) == 0 {
return cu.clearConnection(ctx)
}

activeDBServices, err = cu.ResetDbServiceByAuth(ctx, activeDBServices, dmsUser.UID)
if err != nil {
return err
Expand Down Expand Up @@ -605,11 +613,7 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
activeDBServices = lastActiveDBServices
}

if len(activeDBServices) == 0 {
return nil
}

if err = cu.createConnection(ctx, activeDBServices); err != nil {
if err = cu.operateConnection(ctx, activeDBServices, dmsUser.UID); err != nil {
return err
}

Expand All @@ -627,7 +631,7 @@ func (cu *CloudbeaverUsecase) connectManagement(ctx context.Context, cloudbeaver
return nil
}

func (cu *CloudbeaverUsecase) createConnection(ctx context.Context, activeDBServices []*DBService) error {
func (cu *CloudbeaverUsecase) operateConnection(ctx context.Context, activeDBServices []*DBService, userId string) error {
dbServiceIds := make([]string, 0, len(activeDBServices))
dbServiceMap := map[string]*DBService{}
projectMap := map[string]string{}
Expand All @@ -644,28 +648,40 @@ func (cu *CloudbeaverUsecase) createConnection(ctx context.Context, activeDBServ
}
}

cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionByDMSDBServiceIds(ctx, dbServiceIds)
//获取当前用户所有已创建的连接
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserId(ctx, userId)
if err != nil {
return err
}

var deleteConnection []string

cloudbeaverConnectionMap := map[string]*CloudbeaverConnection{}
for _, service := range cloudbeaverConnections {
cloudbeaverConnectionMap[service.DMSDBServiceID] = service
for _, connection := range cloudbeaverConnections {
// 删除用户关联的连接
if connection.DMSUserId == userId {
cloudbeaverConnectionMap[connection.DMSDBServiceID] = connection

if _, ok := dbServiceMap[connection.DMSDBServiceID]; !ok {
deleteConnection = append(deleteConnection, connection.DMSDBServiceID)
}
}
}

var createConnection []string
var updateConnection []string

for dbServiceId, dbService := range dbServiceMap {
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[dbServiceId]; !ok {
if cloudbeaverConnection, ok := cloudbeaverConnectionMap[dbServiceId]; ok {
if cloudbeaverConnection.DMSDBServiceFingerprint != cu.dbServiceUsecase.GetDBServiceFingerprint(dbService) {
updateConnection = append(updateConnection, dbService.UID)
}
} else {
createConnection = append(createConnection, dbService.UID)
} else if cloudbeaverConnection.DMSDBServiceFingerprint != cu.dbServiceUsecase.GetDBServiceFingerprint(dbService) {
updateConnection = append(updateConnection, dbService.UID)
}
}

if len(createConnection) == 0 && len(updateConnection) == 0 {
if len(createConnection) == 0 && len(updateConnection) == 0 && len(deleteConnection) == 0 {
return nil
}

Expand All @@ -677,17 +693,46 @@ func (cu *CloudbeaverUsecase) createConnection(ctx context.Context, activeDBServ

// 同步实例连接信息
for _, dbServiceId := range createConnection {
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[dbServiceId], projectMap[dbServiceId]); err != nil {
if err = cu.createCloudbeaverConnection(ctx, cloudbeaverClient, dbServiceMap[dbServiceId], projectMap[dbServiceId], userId); err != nil {
cu.log.Errorf("create dnServerId %s connection failed: %v", dbServiceId, err)
}
}

for _, dbServiceId := range updateConnection {
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, cloudbeaverConnectionMap[dbServiceId].CloudbeaverConnectionID, dbServiceMap[dbServiceId], projectMap[dbServiceId]); err != nil {
if err = cu.updateCloudbeaverConnection(ctx, cloudbeaverClient, cloudbeaverConnectionMap[dbServiceId].CloudbeaverConnectionID, dbServiceMap[dbServiceId], projectMap[dbServiceId], userId); err != nil {
cu.log.Errorf("update dnServerId %s to connection failed: %v", dbServiceId, err)
}
}

for _, dbServiceId := range deleteConnection {
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, cloudbeaverConnectionMap[dbServiceId].CloudbeaverConnectionID, dbServiceId, userId); err != nil {
cu.log.Errorf("delete dbServerId %s to connection failed: %v", dbServiceId, err)
}
}

return nil
}

func (cu *CloudbeaverUsecase) clearConnection(ctx context.Context) error {
cloudbeaverConnections, err := cu.repo.GetAllCloudbeaverConnections(ctx)
if err != nil {
return err
}

// 获取管理员链接
cloudbeaverClient, err := cu.getGraphQLClientWithRootUser()
if err != nil {
return err
}

for _, item := range cloudbeaverConnections {
if err = cu.deleteCloudbeaverConnection(ctx, cloudbeaverClient, item.CloudbeaverConnectionID, item.DMSDBServiceID, ""); err != nil {
cu.log.Errorf("delete dbServerId %s to connection failed: %v", item.DMSDBServiceID, err)

return fmt.Errorf("delete dbServerId %s to connection failed: %v", item.DMSDBServiceID, err)
}
}

return nil
}

Expand All @@ -696,11 +741,16 @@ func (cu *CloudbeaverUsecase) grantAccessConnection(ctx context.Context, cloudbe
return fmt.Errorf("user information is not synchronized, unable to update connection information")
}

// 清空绑定能访问的数据库连接
if len(activeDBServices) == 0 {
return cu.bindUserAccessConnection(ctx, []*CloudbeaverConnection{}, cloudbeaverUser.CloudbeaverUserID)
}

dbServiceIds := make([]string, 0, len(activeDBServices))
for _, dbService := range activeDBServices {
dbServiceIds = append(dbServiceIds, dbService.UID)
}
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionByDMSDBServiceIds(ctx, dbServiceIds)
cloudbeaverConnections, err := cu.repo.GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx, dmsUser.UID, dbServiceIds)
if err != nil {
return err
}
Expand Down Expand Up @@ -742,7 +792,7 @@ func (cu *CloudbeaverUsecase) grantAccessConnection(ctx context.Context, cloudbe
}

func (cu *CloudbeaverUsecase) bindUserAccessConnection(ctx context.Context, cloudbeaverDBServices []*CloudbeaverConnection, cloudBeaverUserID string) error {
var cloudbeaverConnectionIds []string
var cloudbeaverConnectionIds = make([]string, 0, len(cloudbeaverDBServices))
for _, service := range cloudbeaverDBServices {
cloudbeaverConnectionIds = append(cloudbeaverConnectionIds, service.CloudbeaverConnectionID)
}
Expand All @@ -760,8 +810,8 @@ func (cu *CloudbeaverUsecase) bindUserAccessConnection(ctx context.Context, clou
return rootClient.Run(ctx, cloudbeaverConnReq, nil)
}

func (cu *CloudbeaverUsecase) createCloudbeaverConnection(ctx context.Context, client *cloudbeaver.Client, dbService *DBService, project string) error {
params, err := cu.GenerateCloudbeaverConnectionParams(dbService, project)
func (cu *CloudbeaverUsecase) createCloudbeaverConnection(ctx context.Context, client *cloudbeaver.Client, dbService *DBService, project, userId string) error {
params, err := cu.GenerateCloudbeaverConnectionParams(dbService, project, userId)
if err != nil {
return fmt.Errorf("%s unsupported", dbService.DBType)
}
Expand All @@ -782,14 +832,15 @@ func (cu *CloudbeaverUsecase) createCloudbeaverConnection(ctx context.Context, c
// 同步缓存
return cu.repo.UpdateCloudbeaverConnectionCache(ctx, &CloudbeaverConnection{
DMSDBServiceID: dbService.UID,
DMSUserId: userId,
DMSDBServiceFingerprint: cu.dbServiceUsecase.GetDBServiceFingerprint(dbService),
CloudbeaverConnectionID: resp.Connection.ID,
})
}

// UpdateCloudbeaverConnection 更新完毕后会同步缓存
func (cu *CloudbeaverUsecase) updateCloudbeaverConnection(ctx context.Context, client *cloudbeaver.Client, cloudbeaverConnectionId string, dbService *DBService, project string) error {
params, err := cu.GenerateCloudbeaverConnectionParams(dbService, project)
func (cu *CloudbeaverUsecase) updateCloudbeaverConnection(ctx context.Context, client *cloudbeaver.Client, cloudbeaverConnectionId string, dbService *DBService, project, userId string) error {
params, err := cu.GenerateCloudbeaverConnectionParams(dbService, project, userId)
if err != nil {
return fmt.Errorf("%s unsupported", dbService.DBType)
}
Expand All @@ -815,15 +866,33 @@ func (cu *CloudbeaverUsecase) updateCloudbeaverConnection(ctx context.Context, c

return cu.repo.UpdateCloudbeaverConnectionCache(ctx, &CloudbeaverConnection{
DMSDBServiceID: dbService.UID,
DMSUserId: userId,
DMSDBServiceFingerprint: cu.dbServiceUsecase.GetDBServiceFingerprint(dbService),
CloudbeaverConnectionID: resp.Connection.ID,
})
}

func (cu *CloudbeaverUsecase) generateCommonCloudbeaverConfigParams(dbService *DBService, project string) map[string]interface{} {
func (cu *CloudbeaverUsecase) deleteCloudbeaverConnection(ctx context.Context, client *cloudbeaver.Client, cloudbeaverConnectionId, dbServiceId, userId string) error {
variables := make(map[string]interface{})
variables["connectionId"] = cloudbeaverConnectionId
variables["projectId"] = cloudbeaverProjectId

req := cloudbeaver.NewRequest(cu.graphQl.DeleteConnectionQuery(), variables)
resp := struct {
DeleteConnection bool `json:"deleteConnection"`
}{}

if err := client.Run(ctx, req, &resp); err != nil {
return err
}

return cu.repo.DeleteCloudbeaverConnectionCache(ctx, dbServiceId, userId)
}

func (cu *CloudbeaverUsecase) generateCommonCloudbeaverConfigParams(dbService *DBService, project, userId string) map[string]interface{} {
return map[string]interface{}{
"configurationType": "MANUAL",
"name": fmt.Sprintf("%v: %v", project, dbService.Name),
"name": fmt.Sprintf("%s:%s:%s", project, dbService.Name, userId),
"template": false,
"host": dbService.Host,
"port": dbService.Port,
Expand All @@ -838,9 +907,11 @@ func (cu *CloudbeaverUsecase) generateCommonCloudbeaverConfigParams(dbService *D
}
}

func (cu *CloudbeaverUsecase) GenerateCloudbeaverConnectionParams(dbService *DBService, project string) (map[string]interface{}, error) {
const cloudbeaverProjectId = "g_GlobalConfiguration"

func (cu *CloudbeaverUsecase) GenerateCloudbeaverConnectionParams(dbService *DBService, project string, userId string) (map[string]interface{}, error) {
var err error
config := cu.generateCommonCloudbeaverConfigParams(dbService, project)
config := cu.generateCommonCloudbeaverConfigParams(dbService, project, userId)

dbType, err := constant.ParseDBType(dbService.DBType)
if err != nil {
Expand All @@ -866,7 +937,7 @@ func (cu *CloudbeaverUsecase) GenerateCloudbeaverConnectionParams(dbService *DBS
}

resp := map[string]interface{}{
"projectId": "g_GlobalConfiguration",
"projectId": cloudbeaverProjectId,
"config": config,
}
return resp, err
Expand Down
49 changes: 47 additions & 2 deletions internal/dms/storage/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,42 @@ func (cr *CloudbeaverRepo) GetDbServiceIdByConnectionId(ctx context.Context, con
return cloudbeaverConnection.DMSDBServiceID, nil
}

func (cr *CloudbeaverRepo) GetCloudbeaverConnectionByDMSDBServiceIds(ctx context.Context, dmsDBServiceIds []string) ([]*biz.CloudbeaverConnection, error) {
func (cr *CloudbeaverRepo) GetAllCloudbeaverConnections(ctx context.Context) ([]*biz.CloudbeaverConnection, error) {
var cloudbeaverConnections []*model.CloudbeaverConnectionCache
err := transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
if err := tx.Find(&cloudbeaverConnections, dmsDBServiceIds).Error; err != nil {
if err := tx.Find(&cloudbeaverConnections).Error; err != nil {
return fmt.Errorf("failed to get cloudbeaver db service: %v", err)
}
return nil
})

if err != nil {
return nil, err
}

return convertModelCloudbeaverConnection(cloudbeaverConnections), nil
}

func (cr *CloudbeaverRepo) GetCloudbeaverConnectionsByUserIdAndDBServiceIds(ctx context.Context, userId string, dmsDBServiceIds []string) ([]*biz.CloudbeaverConnection, error) {
var cloudbeaverConnections []*model.CloudbeaverConnectionCache
err := transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
if err := tx.Where("dms_user_id = ? and dms_db_service_id in (?)", userId, dmsDBServiceIds).Find(&cloudbeaverConnections).Error; err != nil {
return fmt.Errorf("failed to get cloudbeaver db service: %v", err)
}
return nil
})

if err != nil {
return nil, err
}

return convertModelCloudbeaverConnection(cloudbeaverConnections), nil
}

func (cr *CloudbeaverRepo) GetCloudbeaverConnectionsByUserId(ctx context.Context, userId string) ([]*biz.CloudbeaverConnection, error) {
var cloudbeaverConnections []*model.CloudbeaverConnectionCache
err := transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
if err := tx.Where("dms_user_id = ?", userId).Find(&cloudbeaverConnections).Error; err != nil {
return fmt.Errorf("failed to get cloudbeaver db service: %v", err)
}
return nil
Expand All @@ -92,3 +124,16 @@ func (cr *CloudbeaverRepo) UpdateCloudbeaverConnectionCache(ctx context.Context,
return nil
})
}

func (cr *CloudbeaverRepo) DeleteCloudbeaverConnectionCache(ctx context.Context, dbServiceId, userId string) error {
return transaction(cr.log, ctx, cr.db, func(tx *gorm.DB) error {
db := tx.WithContext(ctx).Where("dms_db_service_id = ?", dbServiceId)
if len(userId) > 0 {
db = db.Where("dms_user_id = ?", userId)
}
if err := db.Delete(&model.CloudbeaverConnectionCache{}).Error; err != nil {
return fmt.Errorf("failed to delete cloudbeaver db Service: %v", err)
}
return nil
})
}
2 changes: 2 additions & 0 deletions internal/dms/storage/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ func convertBizCloudbeaverUser(u *biz.CloudbeaverUser) *model.CloudbeaverUserCac
func convertBizCloudbeaverConnection(u *biz.CloudbeaverConnection) *model.CloudbeaverConnectionCache {
return &model.CloudbeaverConnectionCache{
DMSDBServiceID: u.DMSDBServiceID,
DMSUserID: u.DMSUserId,
DMSDBServiceFingerprint: u.DMSDBServiceFingerprint,
CloudbeaverConnectionID: u.CloudbeaverConnectionID,
}
Expand Down Expand Up @@ -295,6 +296,7 @@ func convertModelCloudbeaverConnection(items []*model.CloudbeaverConnectionCache
for _, item := range items {
res = append(res, &biz.CloudbeaverConnection{
DMSDBServiceID: item.DMSDBServiceID,
DMSUserId: item.DMSUserID,
DMSDBServiceFingerprint: item.DMSDBServiceFingerprint,
CloudbeaverConnectionID: item.CloudbeaverConnectionID,
})
Expand Down
1 change: 1 addition & 0 deletions internal/dms/storage/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ type CloudbeaverUserCache struct {

type CloudbeaverConnectionCache struct {
DMSDBServiceID string `json:"dms_db_service_id" gorm:"column:dms_db_service_id;primaryKey"`
DMSUserID string `json:"dms_user_id" gorm:"column:dms_user_id;primaryKey"`
DMSDBServiceFingerprint string `json:"dms_db_service_fingerprint" gorm:"size:255;column:dms_db_service_fingerprint"`
CloudbeaverConnectionID string `json:"cloudbeaver_connection_id" gorm:"size:255;column:cloudbeaver_connection_id"`
}
Expand Down
3 changes: 2 additions & 1 deletion internal/dms/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func NewStorage(logger pkgLog.Logger, conf *StorageConfig) (*Storage, error) {

db, err := gorm.Open(mysql.Open(fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local",
conf.User, conf.Password, conf.Host, conf.Port, conf.Schema)), &gorm.Config{
Logger: pkgLog.NewGormLogWrapper(pkgLog.NewKLogWrapper(logger), gormLog.Info),
Logger: pkgLog.NewGormLogWrapper(pkgLog.NewKLogWrapper(logger), gormLog.Info),
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
log.Errorf("connect to storage failed, error: %v", err)
Expand Down
12 changes: 12 additions & 0 deletions internal/pkg/cloudbeaver/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (v CBVersion) LessThan(version CBVersion) bool {
type GraphQLImpl interface {
CreateConnectionQuery() string
UpdateConnectionQuery() string
DeleteConnectionQuery() string
GetUserConnectionsQuery() string
SetUserConnectionsQuery() string
IsUserExistQuery(userId string) (string, map[string]interface{})
Expand Down Expand Up @@ -115,6 +116,17 @@ fragment DatabaseConnection on ConnectionInfo {
`
}

func (CloudBeaverV2215) DeleteConnectionQuery() string {
return `
mutation deleteConnection(
$projectId: ID!
$connectionId: ID!
) {
deleteConnection(projectId: $projectId, id: $connectionId)
}
`
}

func (CloudBeaverV2215) GetUserConnectionsQuery() string {
return `
query getUserConnections (
Expand Down

0 comments on commit 19ee2c1

Please sign in to comment.