From 793060523908578803e513332dbfd31a4610d4e9 Mon Sep 17 00:00:00 2001 From: make123 Date: Thu, 5 Sep 2024 19:39:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(dbm-services):=20=E8=AF=AD=E6=B3=95?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E7=A6=81=E6=AD=A2=E6=93=8D=E4=BD=9C=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E5=BA=93=E8=A1=A8=20#6682?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/syntax/alter_table_rule.go | 8 +- .../app/syntax/create_db_rule.go | 13 +-- .../app/syntax/create_table_rule.go | 15 ++- .../app/syntax/spider_create_table_rule.go | 5 +- .../mysql/db-simulation/app/syntax/syntax.go | 106 +++++++++--------- .../{tmysqlpase.go => tmysqlpase_schema.go} | 42 ++++--- .../db-simulation/handler/dbsimulation.go | 2 +- dbm-services/mysql/db-simulation/main.go | 32 +++++- .../mysql/db-simulation/model/model.go | 8 +- .../db-simulation/model/tb_simulation_task.go | 2 +- .../mysql/db-simulation/pkg/util/spider.go | 5 +- 11 files changed, 130 insertions(+), 108 deletions(-) rename dbm-services/mysql/db-simulation/app/syntax/{tmysqlpase.go => tmysqlpase_schema.go} (90%) diff --git a/dbm-services/mysql/db-simulation/app/syntax/alter_table_rule.go b/dbm-services/mysql/db-simulation/app/syntax/alter_table_rule.go index 2b95233268..0ab49f7654 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/alter_table_rule.go +++ b/dbm-services/mysql/db-simulation/app/syntax/alter_table_rule.go @@ -11,8 +11,7 @@ package syntax import ( - util "dbm-services/common/go-pubpkg/cmutil" - "dbm-services/common/go-pubpkg/logger" + "github.com/samber/lo" ) // Checker syntax checker @@ -24,7 +23,6 @@ func (c AlterTableResult) Checker(mysqlVersion string) (r *CheckerResult) { r.Parse(R.AlterTableRule.AlterUseAfter, altercmd.After, "") // 如果是增加字段,需要判断增加的字段名称是否是关键字 if altercmd.Type == AlterTypeAddColumn { - logger.Info("col name is %s", altercmd.ColDef.ColName) r.ParseBultinRisk(func() (bool, string) { return KeyWordValidator(mysqlVersion, altercmd.ColDef.ColName) }) @@ -41,11 +39,11 @@ func (c AlterTableResult) Checker(mysqlVersion string) (r *CheckerResult) { // 去重后得到所有的alter types func (c AlterTableResult) GetAllAlterType() (alterTypes []string) { for _, a := range c.AlterCommands { - if !util.StringsHas([]string{"algorithm", "lock"}, a.Type) { + if !lo.Contains([]string{"algorithm", "lock"}, a.Type) { alterTypes = append(alterTypes, a.Type) } } - return util.RemoveDuplicate(alterTypes) + return lo.Uniq(alterTypes) } // GetPkAlterType get the primary key change type diff --git a/dbm-services/mysql/db-simulation/app/syntax/create_db_rule.go b/dbm-services/mysql/db-simulation/app/syntax/create_db_rule.go index e9cd49f4ed..4a822ac5d5 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/create_db_rule.go +++ b/dbm-services/mysql/db-simulation/app/syntax/create_db_rule.go @@ -10,9 +10,7 @@ package syntax -import "dbm-services/common/go-pubpkg/cmutil" - -// Checker TODO +// Checker create db syntax checker func (c CreateDBResult) Checker(mysqlVersion string) (r *CheckerResult) { r = &CheckerResult{} // 检查库名规范 @@ -26,17 +24,10 @@ func (c CreateDBResult) Checker(mysqlVersion string) (r *CheckerResult) { return SpecialCharValidator(c.DbName) }) } - // 不允许包含系统库 - r.ParseBultinBan(func() (bool, string) { - if cmutil.HasElem(c.DbName, cmutil.GetGcsSystemDatabasesIgnoreTest(mysqlVersion)) { - return true, "不允许操作系统库" + c.DbName - } - return false, "" - }) return } -// SpiderChecker TODO +// SpiderChecker spider create db syntax checker func (c CreateDBResult) SpiderChecker(mysqlVersion string) (r *CheckerResult) { return c.Checker(mysqlVersion) } diff --git a/dbm-services/mysql/db-simulation/app/syntax/create_table_rule.go b/dbm-services/mysql/db-simulation/app/syntax/create_table_rule.go index f0a3f4aeaa..64d38f591e 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/create_table_rule.go +++ b/dbm-services/mysql/db-simulation/app/syntax/create_table_rule.go @@ -11,10 +11,10 @@ package syntax import ( - "fmt" "strings" - "dbm-services/common/go-pubpkg/cmutil" + "github.com/samber/lo" + "dbm-services/common/go-pubpkg/logger" ) @@ -85,24 +85,27 @@ func (c CreateTableResult) GetTableCharset() (engine string) { // GetAllColCharsets get columns define charset func (c CreateTableResult) GetAllColCharsets() (charsets []string) { for _, colDef := range c.CreateDefinitions.ColDefs { - if !cmutil.IsEmpty(colDef.CharacterSet) { + if lo.IsNotEmpty(colDef.CharacterSet) { charsets = append(charsets, colDef.CharacterSet) } } - return cmutil.RemoveDuplicate(charsets) + return lo.Uniq(charsets) } // ColCharsetNotEqTbCharset 字段的字符集合和表的字符集合相同 func (c CreateTableResult) ColCharsetNotEqTbCharset() bool { colCharsets := c.GetAllColCharsets() - fmt.Println("colCharsets", colCharsets, len(colCharsets)) if len(colCharsets) == 0 { return false } if len(colCharsets) > 1 { return true } - if strings.Compare(strings.ToUpper(colCharsets[0]), c.GetTableCharset()) == 0 { + tableDefineCharset := c.GetTableCharset() + if lo.IsEmpty(tableDefineCharset) { + return false + } + if strings.Compare(strings.ToUpper(colCharsets[0]), tableDefineCharset) == 0 { return false } return true diff --git a/dbm-services/mysql/db-simulation/app/syntax/spider_create_table_rule.go b/dbm-services/mysql/db-simulation/app/syntax/spider_create_table_rule.go index a35a29e767..7ac48b613a 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/spider_create_table_rule.go +++ b/dbm-services/mysql/db-simulation/app/syntax/spider_create_table_rule.go @@ -14,7 +14,8 @@ import ( "slices" "strings" - "dbm-services/common/go-pubpkg/cmutil" + "github.com/samber/lo" + "dbm-services/common/go-pubpkg/logger" "dbm-services/mysql/db-simulation/pkg/util" ) @@ -67,7 +68,7 @@ func (c CreateTableResult) shardKeyChecker(r *CheckerResult) { } tableComment := c.GetComment() logger.Info("tableComment is %s", tableComment) - if cmutil.IsNotEmpty(tableComment) { + if lo.IsNotEmpty(tableComment) { // table comment 不为空的时候 先校验comment 格式是否合法 legal, msg := c.validateSpiderComment(tableComment) if !legal { diff --git a/dbm-services/mysql/db-simulation/app/syntax/syntax.go b/dbm-services/mysql/db-simulation/app/syntax/syntax.go index fba5c3e297..66f9751b1f 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/syntax.go +++ b/dbm-services/mysql/db-simulation/app/syntax/syntax.go @@ -40,9 +40,6 @@ type CheckSyntax interface { Do() (result map[string]*CheckInfo, err error) } -type inputFileName = string -type outputFileName = string - // TmysqlParseSQL execution parsing sql type TmysqlParseSQL struct { TmysqlParse @@ -65,7 +62,7 @@ type CheckSQLFileParam struct { // TmysqlParse TODO type TmysqlParse struct { - runtimeCtx + tmpWorkdir string result map[string]*CheckInfo bkRepoClient *bkrepo.BkRepoClient TmysqlParseBinPath string @@ -87,11 +84,6 @@ func (t *TmysqlParse) AddFileResult(fileName string, result *CheckInfo, failedIn t.mu.Unlock() } -type runtimeCtx struct { - fileMap map[inputFileName]outputFileName - tmpWorkdir string -} - // CheckInfo 语法检查结果信息汇总 type CheckInfo struct { SyntaxFailInfos []FailedInfo `json:"syntax_fails"` @@ -125,7 +117,6 @@ const DdlMapFileSubffix = ".tbl.map" // @return err func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[string]*CheckInfo, err error) { logger.Info("doing....") - tf.fileMap = make(map[inputFileName]outputFileName) tf.result = make(map[string]*CheckInfo) tf.tmpWorkdir = tf.BaseWorkdir tf.mu = sync.Mutex{} @@ -141,7 +132,6 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri } } // 最后删除临时目录,不会返回错误 - // 暂时屏蔽 观察过程文件 defer tf.delTempDir() var errs []error @@ -149,7 +139,6 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri if err = tf.doSingleVersion(dbtype, version); err != nil { logger.Error("when do [%s],syntax check,failed:%s", version, err.Error()) errs = append(errs, err) - // return tf.result, err } } @@ -158,22 +147,21 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) (err error) { errChan := make(chan error) - resultfileChan := make(chan string, 10) + alreadExecutedSqlfileChan := make(chan string, 10) signalChan := make(chan struct{}) go func() { - if err = tf.Execute(resultfileChan, mysqlVersion); err != nil { + if err = tf.Execute(alreadExecutedSqlfileChan, mysqlVersion); err != nil { logger.Error("failed to execute tmysqlparse: %s", err.Error()) errChan <- err } - close(resultfileChan) + close(alreadExecutedSqlfileChan) }() // 对tmysqlparse的处理结果进行分析,为json文件,后面用到了rule go func() { logger.Info("start to analyze the parsing result") - - if err = tf.AnalyzeParseResult(resultfileChan, mysqlVersion, dbtype); err != nil { + if err = tf.AnalyzeParseResult(alreadExecutedSqlfileChan, mysqlVersion, dbtype); err != nil { logger.Error("failed to analyze the parsing result:%s", err.Error()) errChan <- err } @@ -193,7 +181,6 @@ func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) ( // CreateAndUploadDDLTblFile CreateAndUploadDDLTblFile func (tf *TmysqlParseFile) CreateAndUploadDDLTblFile() (err error) { logger.Info("start to create and upload ddl table file") - logger.Info("doing....") if err = tf.Init(); err != nil { logger.Error("Do init failed %s", err.Error()) return err @@ -218,7 +205,7 @@ func (tf *TmysqlParseFile) CreateAndUploadDDLTblFile() (err error) { }() for inputFileName := range resultfileChan { - if err = tf.analyzeDDLTbls(inputFileName); err != nil { + if err = tf.analyzeDDLTbls(inputFileName, ""); err != nil { logger.Error("failed to analyzeDDLTbls %s,err:%s", inputFileName, err.Error()) return err } @@ -260,7 +247,6 @@ func (t *TmysqlParse) Init() (err error) { return fmt.Errorf("failed to initialize tmysqlparse temporary directory(%s).detail:%s", t.tmpWorkdir, err.Error()) } t.bkRepoClient = getbkrepoClient() - t.fileMap = make(map[inputFileName]outputFileName) t.result = make(map[string]*CheckInfo) return nil } @@ -271,22 +257,6 @@ func (t *TmysqlParse) delTempDir() { } } -func (t *TmysqlParse) getCommand(filename, version string) (cmd string) { - var in, out string - in = path.Join(t.tmpWorkdir, filename) - if outputFileName, ok := t.fileMap[filename]; ok { - out = path.Join(t.tmpWorkdir, outputFileName) - } - - cmd = fmt.Sprintf(`%s --sql-file=%s --output-path=%s --print-query-mode=2 --output-format='JSON_LINE_PER_OBJECT'`, - t.TmysqlParseBinPath, in, out) - - if strings.TrimSpace(version) != "" { - cmd += fmt.Sprintf(" --mysql-version=%s ", version) - } - return -} - // Downloadfile download sqlfile func (tf *TmysqlParseFile) Downloadfile() (err error) { wg := &sync.WaitGroup{} @@ -333,21 +303,40 @@ func (tf *TmysqlParseFile) UploadDdlTblMapFile() (err error) { return } +func getSQLParseResultFile(fileName, version string) string { + return fmt.Sprintf("%s-%s.json", version, fileName) +} + +func (t *TmysqlParse) getCommand(filename, version string) (cmd string) { + var in, out string + in = path.Join(t.tmpWorkdir, filename) + outputFileName := getSQLParseResultFile(filename, version) + out = path.Join(t.tmpWorkdir, outputFileName) + + cmd = fmt.Sprintf(`%s --sql-file=%s --output-path=%s --print-query-mode=2 --output-format='JSON_LINE_PER_OBJECT'`, + t.TmysqlParseBinPath, in, out) + + if lo.IsNotEmpty(version) { + cmd += fmt.Sprintf(" --mysql-version=%s ", version) + } + + return cmd +} + // Execute 运行tmysqlpase // // @receiver tf // @return err -func (tf *TmysqlParseFile) Execute(resultFile chan string, version string) (err error) { +func (tf *TmysqlParseFile) Execute(alreadExecutedSqlfileCh chan string, version string) (err error) { var wg sync.WaitGroup var errs []error c := make(chan struct{}, 10) errChan := make(chan error, 5) - for _, fileName := range tf.Param.FileNames { wg.Add(1) c <- struct{}{} - tf.fileMap[fileName] = version + "-" + fileName + ".json" go func(sqlfile, ver string) { + //nolint command := exec.Command("/bin/bash", "-c", tf.getCommand(sqlfile, ver)) logger.Info("command is %s", command) @@ -355,9 +344,8 @@ func (tf *TmysqlParseFile) Execute(resultFile chan string, version string) (err if err != nil { errChan <- fmt.Errorf("tmysqlparse.sh command run failed. error info:" + err.Error() + "," + string(output)) } else { - resultFile <- sqlfile + alreadExecutedSqlfileCh <- sqlfile } - <-c wg.Done() }(fileName, version) @@ -375,19 +363,20 @@ func (tf *TmysqlParseFile) Execute(resultFile chan string, version string) (err return errors.Join(errs...) } -func (t *TmysqlParse) getAbsoutputfilePath(inputFileName string) string { - fileAbPath, _ := filepath.Abs(path.Join(t.tmpWorkdir, t.fileMap[inputFileName])) +func (t *TmysqlParse) getAbsoutputfilePath(sqlFile, version string) string { + fileAbPath, _ := filepath.Abs(path.Join(t.tmpWorkdir, getSQLParseResultFile(sqlFile, version))) return fileAbPath } // AnalyzeParseResult 分析tmysqlparse 解析的结果 -func (t *TmysqlParse) AnalyzeParseResult(resultFile chan string, mysqlVersion string, dbtype string) (err error) { +func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, mysqlVersion string, + dbtype string) (err error) { var errs []error c := make(chan struct{}, 10) errChan := make(chan error, 5) wg := &sync.WaitGroup{} - for inputFileName := range resultFile { + for sqlfile := range alreadExecutedSqlfileCh { wg.Add(1) c <- struct{}{} go func(fileName string) { @@ -397,7 +386,7 @@ func (t *TmysqlParse) AnalyzeParseResult(resultFile chan string, mysqlVersion st errChan <- err } <-c - }(inputFileName) + }(sqlfile) } go func() { @@ -434,16 +423,15 @@ func (c *CheckInfo) parseResult(rule *RuleItem, res ParseLineQueryBase, ver stri } // analyzeDDLTbls 分析DDL语句 -func (t *TmysqlParse) analyzeDDLTbls(inputfileName string) (err error) { +func (t *TmysqlParse) analyzeDDLTbls(inputfileName, mysqlVersion string) (err error) { ddlTbls := make(map[string][]string) defer func() { if r := recover(); r != nil { logger.Error("panic error:%v,stack:%s", r, string(debug.Stack())) - logger.Error("Recovered. Error: %v", r) } }() t.result[inputfileName] = &CheckInfo{} - f, err := os.Open(t.getAbsoutputfilePath(inputfileName)) + f, err := os.Open(t.getAbsoutputfilePath(inputfileName, mysqlVersion)) if err != nil { logger.Error("open file failed %s", err.Error()) return err @@ -524,7 +512,7 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er ddlTbls := make(map[string][]string) checkResult := &CheckInfo{} - f, err := os.Open(t.getAbsoutputfilePath(inputfileName)) + f, err := os.Open(t.getAbsoutputfilePath(inputfileName, mysqlVersion)) if err != nil { logger.Error("open file failed %s", err.Error()) return err @@ -546,7 +534,6 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er if isPrefix { continue } - // 清空 bs := buf buf = []byte{} @@ -555,21 +542,29 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er logger.Info("blank line skip") continue } - if err = json.Unmarshal(bs, &res); err != nil { logger.Error("json unmasrshal line:%s failed %s", string(bs), err.Error()) return err } - // ErrorCode !=0 就是语法错误 if res.ErrorCode != 0 { syntaxFailInfos = append(syntaxFailInfos, t.getSyntaxErrorResult(res, mysqlVersion)) continue } - + // 判断是否变更的是系统数据库 + if res.IsSysDb() { + t.mu.Lock() + checkResult.BanWarnings = append(checkResult.BanWarnings, RiskInfo{ + Line: int64(res.QueryId), + Sqltext: res.QueryString, + WarnInfo: fmt.Sprintf("disable operating sys db: %s", res.DbName), + }) + t.mu.Unlock() + continue + } + // tmysqlparse检查结果全部正确,开始判断语句是否符合定义的规则(即虽然语法正确,但语句可能是高危语句或禁用的命令) switch dbtype { case app.MySQL: - // tmysqlparse检查结果全部正确,开始判断语句是否符合定义的规则(即虽然语法正确,但语句可能是高危语句或禁用的命令) checkResult.parseResult(R.CommandRule.HighRiskCommandRule, res, mysqlVersion) checkResult.parseResult(R.CommandRule.BanCommandRule, res, mysqlVersion) err = checkResult.runcheck(res, bs, mysqlVersion) @@ -577,7 +572,6 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er goto END } case app.Spider: - // tmysqlparse检查结果全部正确,开始判断语句是否符合定义的规则(即虽然语法正确,但语句可能是高危语句或禁用的命令) checkResult.parseResult(SR.CommandRule.HighRiskCommandRule, res, mysqlVersion) checkResult.parseResult(SR.CommandRule.BanCommandRule, res, mysqlVersion) err = checkResult.runSpidercheck(ddlTbls, res, bs, mysqlVersion) diff --git a/dbm-services/mysql/db-simulation/app/syntax/tmysqlpase.go b/dbm-services/mysql/db-simulation/app/syntax/tmysqlpase_schema.go similarity index 90% rename from dbm-services/mysql/db-simulation/app/syntax/tmysqlpase.go rename to dbm-services/mysql/db-simulation/app/syntax/tmysqlpase_schema.go index b99420a202..64032c102c 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/tmysqlpase.go +++ b/dbm-services/mysql/db-simulation/app/syntax/tmysqlpase_schema.go @@ -10,7 +10,13 @@ package syntax -import util "dbm-services/common/go-pubpkg/cmutil" +import ( + "strings" + + "github.com/samber/lo" + + util "dbm-services/common/go-pubpkg/cmutil" +) const ( // AlterTypeAddColumn add_column @@ -39,7 +45,7 @@ const ( SQLTypeUpdate = "update" ) -// ColDef TODO +// ColDef mysql column definition type ColDef struct { Type string `json:"type"` ColName string `json:"col_name"` @@ -59,7 +65,7 @@ type ColDef struct { ReferenceDefinition interface{} `json:"reference_definition"` } -// KeyDef TODO +// KeyDef mysql index definition type KeyDef struct { Type string `json:"type"` KeyName string `json:"key_name"` @@ -75,13 +81,13 @@ type KeyDef struct { ReferenceDefinition interface{} `json:"reference_definition"` } -// TableOption TODO +// TableOption mysql table option definition type TableOption struct { Key string `json:"key"` Value interface{} `json:"value"` } -// ConverTableOptionToMap TODO +// ConverTableOptionToMap convert table option to map func ConverTableOptionToMap(options []TableOption) map[string]interface{} { r := make(map[string]interface{}) for _, v := range options { @@ -92,7 +98,7 @@ func ConverTableOptionToMap(options []TableOption) map[string]interface{} { return r } -// CommDDLResult TODO +// CommDDLResult mysql common ddl tmysqlparse result type CommDDLResult struct { QueryID int `json:"query_id"` Command string `json:"command"` @@ -100,7 +106,7 @@ type CommDDLResult struct { TableName string `json:"table_name"` } -// CreateTableResult TODO +// CreateTableResult tmysqlparse create table result type CreateTableResult struct { QueryID int `json:"query_id"` Command string `json:"command"` @@ -119,7 +125,7 @@ type CreateTableResult struct { PartitionOptions interface{} `json:"partition_options"` } -// CreateDBResult TODO +// CreateDBResult tmysqlparse create db result type CreateDBResult struct { QueryID int `json:"query_id"` Command string `json:"command"` @@ -128,7 +134,7 @@ type CreateDBResult struct { Collate string `json:"collate"` } -// AlterTableResult TODO +// AlterTableResult tmysqlparse alter table result type AlterTableResult struct { QueryID int `json:"query_id"` Command string `json:"command"` @@ -138,7 +144,7 @@ type AlterTableResult struct { PartitionOptions interface{} `json:"partition_options"` } -// AlterCommand TODO +// AlterCommand tmysqlparse alter table result type AlterCommand struct { Type string `json:"type"` ColDef ColDef `json:"col_def,omitempty"` @@ -155,14 +161,14 @@ type AlterCommand struct { Lock string `json:"lock,omitempty"` } -// ChangeDbResult TODO +// ChangeDbResult mysqlparse change db result type ChangeDbResult struct { QueryID int `json:"query_id"` Command string `json:"command"` DbName string `json:"db_name"` } -// ErrorResult TODO +// ErrorResult syntax error result type ErrorResult struct { QueryID int `json:"query_id"` Command string `json:"command"` @@ -170,17 +176,18 @@ type ErrorResult struct { ErrorMsg string `json:"error_msg,omitempty"` } -// ParseBase TODO +// ParseBase parse base type ParseBase struct { QueryId int `json:"query_id"` Command string `json:"command"` QueryString string `json:"query_string,omitempty"` } -// ParseLineQueryBase TODO +// ParseLineQueryBase parse line query base type ParseLineQueryBase struct { QueryId int `json:"query_id"` Command string `json:"command"` + DbName string `json:"db_name,omitempty"` QueryString string `json:"query_string,omitempty"` ErrorCode int `json:"error_code,omitempty"` ErrorMsg string `json:"error_msg,omitempty"` @@ -188,7 +195,12 @@ type ParseLineQueryBase struct { MaxMySQLVersion int `json:"max_my_sql_version"` } -// UserHost TODO +// IsSysDb sql modify target db is sys db +func (p ParseLineQueryBase) IsSysDb() bool { + return lo.Contains([]string{"mysql", "information_schema", "performance_schema", "sys"}, strings.ToLower(p.DbName)) +} + +// UserHost user host type UserHost struct { User string `json:"user"` Host string `json:"host"` diff --git a/dbm-services/mysql/db-simulation/handler/dbsimulation.go b/dbm-services/mysql/db-simulation/handler/dbsimulation.go index b14f35048e..610b4d0268 100644 --- a/dbm-services/mysql/db-simulation/handler/dbsimulation.go +++ b/dbm-services/mysql/db-simulation/handler/dbsimulation.go @@ -121,7 +121,7 @@ func TendbClusterSimulation(r *gin.Context) { BaseParam: ¶m.BaseParam, Version: version, } - rootPwd := cmutil.RandStr(10) + rootPwd := cmutil.RandomString(10) if !service.DelPod { logger.Info("the pwd %s", rootPwd) } diff --git a/dbm-services/mysql/db-simulation/main.go b/dbm-services/mysql/db-simulation/main.go index c26ea38a36..7ca7d43cc6 100644 --- a/dbm-services/mysql/db-simulation/main.go +++ b/dbm-services/mysql/db-simulation/main.go @@ -12,9 +12,12 @@ package main import ( "bytes" + "context" "io" "net/http" "os" + "os/signal" + "syscall" "time" "github.com/gin-contrib/pprof" @@ -52,13 +55,38 @@ func main() { ctx.SecureJSON(http.StatusOK, map[string]interface{}{"buildstamp": buildstamp, "githash": githash, "version": version}) }) - if err := app.Run(config.GAppConfig.ListenAddr); err != nil { - logger.Fatal("app run error: %v", err) + + srv := &http.Server{ + Addr: config.GAppConfig.ListenAddr, + Handler: app, + ReadHeaderTimeout: 5 * time.Second, + } + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Fatal("listen: %s\n", err) + } + }() + // Wait for interrupt signal to gracefully shutdown the server with + // a timeout of 5 seconds. + quit := make(chan os.Signal, 1) + // kill (no param) default send syscall.SIGTERM + // kill -2 is syscall.SIGINT + // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + logger.Info("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + //nolint + logger.Fatal("Server forced to shutdown: %v ", err) } + logger.Info("Server exiting\n") } func init() { logger.New(os.Stdout, true, logger.InfoLevel, map[string]string{}) + //nolint: errcheck defer logger.Sync() } diff --git a/dbm-services/mysql/db-simulation/model/model.go b/dbm-services/mysql/db-simulation/model/model.go index 8eadd3e54e..f85cf5195e 100644 --- a/dbm-services/mysql/db-simulation/model/model.go +++ b/dbm-services/mysql/db-simulation/model/model.go @@ -28,9 +28,6 @@ import ( // DB TODO var DB *gorm.DB -// SqlDB TODO -var SqlDB *sql.DB - func init() { user := config.GAppConfig.DbConf.User pwd := config.GAppConfig.DbConf.Pwd @@ -75,14 +72,13 @@ func openDB(username, password, addr, name string) *gorm.DB { true, "Local") var err error - // SqlDB是上面定义了全局变量 - SqlDB, err = sql.Open("mysql", dsn) + dbc, err := sql.Open("mysql", dsn) if err != nil { log.Fatalf("connect to mysql failed %s", err.Error()) return nil } db, err := gorm.Open(mysql.New(mysql.Config{ - Conn: SqlDB, + Conn: dbc, }), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: newLogger, diff --git a/dbm-services/mysql/db-simulation/model/tb_simulation_task.go b/dbm-services/mysql/db-simulation/model/tb_simulation_task.go index d5e34fb71f..b8f1c777c9 100644 --- a/dbm-services/mysql/db-simulation/model/tb_simulation_task.go +++ b/dbm-services/mysql/db-simulation/model/tb_simulation_task.go @@ -104,7 +104,7 @@ func CreateTask(taskid, requestid, version string, billTaskId string) (err error return fmt.Errorf("this task exists:%s", taskid) } if !errors.Is(err, gorm.ErrRecordNotFound) { - logger.Error("") + logger.Error("create task failed %s", err.Error()) return err } return DB.Create(&TbSimulationTask{ diff --git a/dbm-services/mysql/db-simulation/pkg/util/spider.go b/dbm-services/mysql/db-simulation/pkg/util/spider.go index b6907864d1..3d92bfdea5 100644 --- a/dbm-services/mysql/db-simulation/pkg/util/spider.go +++ b/dbm-services/mysql/db-simulation/pkg/util/spider.go @@ -33,11 +33,10 @@ func ParseGetShardKeyForSpider(tableComment string) (string, error) { } // find the beginning " - if pos < len(tableComment) && tableComment[pos] == '"' { - pos++ - } else { + if !(pos < len(tableComment) && tableComment[pos] == '"') { return "", errors.New("parse error") } + pos++ // find the ending " end := strings.Index(tableComment[pos:], "\"")