diff --git a/sqle/api/controller/v1/audit_plan.go b/sqle/api/controller/v1/audit_plan.go index bec71ee26b..0eeb2a8e23 100644 --- a/sqle/api/controller/v1/audit_plan.go +++ b/sqle/api/controller/v1/audit_plan.go @@ -6,9 +6,7 @@ import ( "encoding/csv" "fmt" "mime" - "net" "net/http" - "regexp" "strconv" "strings" "time" @@ -849,93 +847,6 @@ func GetAuditPlanReport(c echo.Context) error { }) } -func filterSQLsByBlackList(sqls []*AuditPlanSQLReqV1, blackList []*model.BlackListAuditPlanSQL) []*AuditPlanSQLReqV1 { - if len(blackList) == 0 { - return sqls - } - filteredSQLs := []*AuditPlanSQLReqV1{} - filter := ConvertToBlackFilter(blackList) - for _, sql := range sqls { - if filter.HasEndpointInBlackList([]string{sql.Endpoint}) || filter.IsSqlInBlackList(sql.LastReceiveText) { - continue - } - filteredSQLs = append(filteredSQLs, sql) - } - return filteredSQLs -} - -func ConvertToBlackFilter(blackList []*model.BlackListAuditPlanSQL) *BlackFilter { - var blackFilter BlackFilter - for _, filter := range blackList { - switch filter.FilterType { - case model.FilterTypeSQL: - blackFilter.BlackSqlList = append(blackFilter.BlackSqlList, utils.FullFuzzySearchRegexp(filter.FilterContent)) - case model.FilterTypeHost: - blackFilter.BlackHostList = append(blackFilter.BlackHostList, utils.FullFuzzySearchRegexp(filter.FilterContent)) - case model.FilterTypeIP: - ip := net.ParseIP(filter.FilterContent) - if ip == nil { - log.Logger().Errorf("wrong ip in black list,ip:%s", filter.FilterContent) - continue - } - blackFilter.BlackIpList = append(blackFilter.BlackIpList, ip) - case model.FilterTypeCIDR: - _, cidr, err := net.ParseCIDR(filter.FilterContent) - if err != nil { - log.Logger().Errorf("wrong cidr in black list,cidr:%s,err:%v", filter.FilterContent, err) - continue - } - blackFilter.BlackCidrList = append(blackFilter.BlackCidrList, cidr) - } - } - return &blackFilter -} - -// 构造BlackFilter的目的是缓存黑名单中需要使用的结构体,在每个循环中复用 -type BlackFilter struct { - BlackSqlList []*regexp.Regexp //更换正则匹配提高效率 - BlackIpList []net.IP - BlackHostList []*regexp.Regexp - BlackCidrList []*net.IPNet -} - -func (f BlackFilter) IsSqlInBlackList(checkSql string) bool { - for _, blackSql := range f.BlackSqlList { - if blackSql.MatchString(checkSql) { - return true - } - } - return false -} - -// 输入一组ip若其中有一个ip在黑名单中则返回true -func (f BlackFilter) HasEndpointInBlackList(checkIps []string) bool { - var checkNetIp net.IP - for _, checkIp := range checkIps { - checkNetIp = net.ParseIP(checkIp) - if checkNetIp == nil { - // 无法解析IP,可能是域名,需要正则匹配 - for _, blackHost := range f.BlackHostList { - if blackHost.MatchString(checkIp) { - return true - } - } - } else { - for _, blackIp := range f.BlackIpList { - if blackIp.Equal(checkNetIp) { - return true - } - } - for _, blackCidr := range f.BlackCidrList { - if blackCidr.Contains(checkNetIp) { - return true - } - } - } - } - return false -} - type FullSyncAuditPlanSQLsReqV1 struct { SQLs []*AuditPlanSQLReqV1 `json:"audit_plan_sql_list" form:"audit_plan_sql_list" valid:"dive"` } @@ -989,13 +900,7 @@ func FullSyncAuditPlanSQLs(c echo.Context) error { l := log.NewEntry() reqSQLs := req.SQLs - blackList, err := s.GetBlackListAuditPlanSQLs() - if err == nil { - reqSQLs = filterSQLsByBlackList(reqSQLs, blackList) - } else { - l.Warnf("blacklist is not used, err:%v", err) - } - if len(reqSQLs) == 0 { + if len(req.SQLs) == 0 { return controller.JSONBaseErrorReq(c, nil) } sqls, err := convertToModelAuditPlanSQL(c, ap, reqSQLs) @@ -1045,12 +950,6 @@ func PartialSyncAuditPlanSQLs(c echo.Context) error { l := log.NewEntry() reqSQLs := req.SQLs - blackList, err := s.GetBlackListAuditPlanSQLs() - if err == nil { - reqSQLs = filterSQLsByBlackList(reqSQLs, blackList) - } else { - l.Warnf("blacklist is not used, err:%v", err) - } if len(reqSQLs) == 0 { return controller.JSONBaseErrorReq(c, nil) } diff --git a/sqle/api/controller/v1/blacklist.go b/sqle/api/controller/v1/blacklist.go index b251aec59b..39f0d0bd29 100644 --- a/sqle/api/controller/v1/blacklist.go +++ b/sqle/api/controller/v1/blacklist.go @@ -1,9 +1,15 @@ package v1 import ( + "context" + "fmt" + "net/http" "time" "github.com/actiontech/sqle/sqle/api/controller" + "github.com/actiontech/sqle/sqle/dms" + "github.com/actiontech/sqle/sqle/errors" + "github.com/actiontech/sqle/sqle/model" "github.com/labstack/echo/v4" ) @@ -25,7 +31,27 @@ type CreateBlacklistReqV1 struct { // @Success 200 {object} controller.BaseRes // @router /v1/projects/{project_name}/blacklist [post] func CreateBlacklist(c echo.Context) error { - return nil + req := new(CreateBlacklistReqV1) + if err := controller.BindAndValidateReq(c, req); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + projectUid, err := dms.GetPorjectUIDByName(context.TODO(), c.Param("project_name"), true) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + s := model.GetStorage() + err = s.Save(&model.BlackListAuditPlanSQL{ + ProjectId: model.ProjectUID(projectUid), + FilterType: model.BlacklistFilterType(req.Type), + FilterContent: req.Content, + Desc: req.Desc, + }) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + return c.JSON(http.StatusOK, controller.NewBaseReq(nil)) } // DeleteBlacklist @@ -38,7 +64,28 @@ func CreateBlacklist(c echo.Context) error { // @Success 200 {object} controller.BaseRes // @router /v1/projects/{project_name}/blacklist/{blacklist_id}/ [delete] func DeleteBlacklist(c echo.Context) error { - return nil + blacklistId := c.Param("blacklist_id") + + projectUid, err := dms.GetPorjectUIDByName(context.TODO(), c.Param("project_name")) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + s := model.GetStorage() + blacklist, exist, err := s.GetBlacklistByID(model.ProjectUID(projectUid), blacklistId) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + if !exist { + return controller.JSONBaseErrorReq(c, errors.New(errors.DataNotExist, + fmt.Errorf("blacklist is not exist"))) + } + + if err := s.Delete(blacklist); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + return c.JSON(http.StatusOK, controller.NewBaseReq(nil)) } type UpdateBlacklistReqV1 struct { @@ -60,7 +107,43 @@ type UpdateBlacklistReqV1 struct { // @Success 200 {object} controller.BaseRes // @router /v1/projects/{project_name}/blacklist/{blacklist_id}/ [patch] func UpdateBlacklist(c echo.Context) error { - return nil + req := new(UpdateBlacklistReqV1) + if err := controller.BindAndValidateReq(c, req); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + blacklistId := c.Param("blacklist_id") + projectUid, err := dms.GetPorjectUIDByName(context.TODO(), c.Param("project_name")) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + s := model.GetStorage() + blacklist, exist, err := s.GetBlacklistByID(model.ProjectUID(projectUid), blacklistId) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + if !exist { + return controller.JSONBaseErrorReq(c, errors.New(errors.DataNotExist, + fmt.Errorf("blacklist is not exist"))) + } + + if req.Content != nil { + blacklist.FilterContent = *req.Content + } + if req.Type != nil { + blacklist.FilterType = model.BlacklistFilterType(*req.Type) + } + if req.Desc != nil { + blacklist.Desc = *req.Desc + } + + err = s.Save(blacklist) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + return c.JSON(http.StatusOK, controller.NewBaseReq(nil)) } type GetBlacklistReqV1 struct { @@ -99,5 +182,37 @@ type BlacklistResV1 struct { // @Success 200 {object} v1.GetBlacklistResV1 // @router /v1/projects/{project_name}/blacklist [get] func GetBlacklist(c echo.Context) error { - return nil + req := new(GetBlacklistReqV1) + if err := controller.BindAndValidateReq(c, req); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + projectUid, err := dms.GetPorjectUIDByName(context.TODO(), c.Param("project_name")) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + s := model.GetStorage() + blacklistList, count, err := s.GetBlacklistList(model.ProjectUID(projectUid), model.BlacklistFilterType(req.FilterType), req.FuzzySearchContent, req.PageIndex, req.PageSize) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + res := make([]*BlacklistResV1, 0, len(blacklistList)) + for _, blacklist := range blacklistList { + res = append(res, &BlacklistResV1{ + BlacklistID: blacklist.ID, + Content: blacklist.FilterContent, + Desc: blacklist.Desc, + Type: string(blacklist.FilterType), + MatchedCount: blacklist.MatchedCount, + LastMatchTime: blacklist.LastMatchTime, + }) + } + + return c.JSON(http.StatusOK, &GetBlacklistResV1{ + BaseRes: controller.NewBaseReq(nil), + Data: res, + TotalNums: count, + }) } diff --git a/sqle/api/controller/v1/sql_whitelist.go b/sqle/api/controller/v1/sql_whitelist.go index 63ae9e7f08..c553d6d8b6 100644 --- a/sqle/api/controller/v1/sql_whitelist.go +++ b/sqle/api/controller/v1/sql_whitelist.go @@ -153,10 +153,10 @@ func DeleteAuditWhitelistById(c echo.Context) error { } type GetAuditWhitelistReqV1 struct { - FuzzySearchValue string `json:"fuzzy_value" query:"fuzzy_value" valid:"omitempty"` - FilterMatchType string `json:"filter_match_type" query:"filter_match_type" valid:"omitempty,oneof=exact_match fp_match" enums:"exact_match,fp_match"` - PageIndex uint32 `json:"page_index" query:"page_index" valid:"required"` - PageSize uint32 `json:"page_size" query:"page_size" valid:"required"` + FuzzySearchValue *string `json:"fuzzy_value" query:"fuzzy_value" valid:"omitempty"` + FilterMatchType *string `json:"filter_match_type" query:"filter_match_type" valid:"omitempty,oneof=exact_match fp_match" enums:"exact_match,fp_match"` + PageIndex uint32 `json:"page_index" query:"page_index" valid:"required"` + PageSize uint32 `json:"page_size" query:"page_size" valid:"required"` } type GetAuditWhitelistResV1 struct { @@ -197,17 +197,19 @@ func GetSqlWhitelist(c echo.Context) error { } s := model.GetStorage() - sqlWhitelist, count, err := s.GetSqlWhitelistByProjectUID(req.PageIndex, req.PageSize, model.ProjectUID(projectUid)) + sqlWhitelist, count, err := s.GetSqlWhitelistByProjectUID(req.PageIndex, req.PageSize, model.ProjectUID(projectUid), req.FuzzySearchValue, req.FilterMatchType) if err != nil { return controller.JSONBaseErrorReq(c, err) } whitelistRes := make([]*AuditWhitelistResV1, 0, len(sqlWhitelist)) for _, v := range sqlWhitelist { whitelistRes = append(whitelistRes, &AuditWhitelistResV1{ - Id: v.ID, - Value: v.Value, - Desc: v.Desc, - MatchType: v.MatchType, + Id: v.ID, + Value: v.Value, + Desc: v.Desc, + MatchType: v.MatchType, + MatchedCount: uint(v.MatchedCount), + LastMatchTime: v.LastMatchedTime, }) } return c.JSON(http.StatusOK, &GetAuditWhitelistResV1{ diff --git a/sqle/api/controller/v2/audit_plan.go b/sqle/api/controller/v2/audit_plan.go index 35a40adebe..63a4268df3 100644 --- a/sqle/api/controller/v2/audit_plan.go +++ b/sqle/api/controller/v2/audit_plan.go @@ -312,21 +312,6 @@ type AuditPlanSQLReqV2 struct { Endpoints []string `json:"endpoints" from:"endpoints"` } -func filterSQLsByBlackList(sqls []*AuditPlanSQLReqV2, blackList []*model.BlackListAuditPlanSQL) []*AuditPlanSQLReqV2 { - if len(blackList) == 0 { - return sqls - } - filteredSQLs := []*AuditPlanSQLReqV2{} - filter := v1.ConvertToBlackFilter(blackList) - for _, sql := range sqls { - if filter.HasEndpointInBlackList(sql.Endpoints) || filter.IsSqlInBlackList(sql.LastReceiveText) { - continue - } - filteredSQLs = append(filteredSQLs, sql) - } - return filteredSQLs -} - func convertToModelAuditPlanSQL(dbType string, reqSQLs []*AuditPlanSQLReqV2) ([]*auditplan.SQL, error) { var p driver.Plugin var err error @@ -448,12 +433,6 @@ func PartialSyncAuditPlanSQLs(c echo.Context) error { l := log.NewEntry() reqSQLs := req.SQLs - blackList, err := s.GetBlackListAuditPlanSQLs() - if err == nil { - reqSQLs = filterSQLsByBlackList(reqSQLs, blackList) - } else { - l.Warnf("blacklist is not used, err:%v", err) - } if len(reqSQLs) == 0 { return controller.JSONBaseErrorReq(c, nil) } @@ -502,12 +481,6 @@ func FullSyncAuditPlanSQLs(c echo.Context) error { l := log.NewEntry() reqSQLs := req.SQLs - blackList, err := s.GetBlackListAuditPlanSQLs() - if err == nil { - reqSQLs = filterSQLsByBlackList(reqSQLs, blackList) - } else { - l.Warnf("blacklist is not used, err:%v", err) - } if len(reqSQLs) == 0 { return controller.JSONBaseErrorReq(c, nil) } @@ -543,6 +516,7 @@ func UploadInstanceAuditPlanSQLs(c echo.Context) error { if err != nil { return controller.JSONBaseErrorReq(c, err) } + s := model.GetStorage() ap, exist, err := s.GetActiveAuditPlanDetail(uint(apID)) @@ -554,13 +528,17 @@ func UploadInstanceAuditPlanSQLs(c echo.Context) error { } l := log.NewEntry() - reqSQLs := req.SQLs - blackList, err := s.GetBlackListAuditPlanSQLs() - if err == nil { - reqSQLs = filterSQLsByBlackList(reqSQLs, blackList) + instance, exist, err := dms.GetInstancesById(c.Request().Context(), ap.InstanceID) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + if exist { + ap.Instance = instance } else { - l.Warnf("blacklist is not used, err:%v", err) + l.Errorf("instance not found, instance id: %s", ap.InstanceID) } + + reqSQLs := req.SQLs if len(reqSQLs) == 0 { return controller.JSONBaseErrorReq(c, nil) } diff --git a/sqle/model/audit_plan.go b/sqle/model/audit_plan.go index cab8bff3aa..3b0c59b8c9 100644 --- a/sqle/model/audit_plan.go +++ b/sqle/model/audit_plan.go @@ -2,6 +2,7 @@ package model import ( "encoding/json" + e "errors" "fmt" "strings" "time" @@ -54,29 +55,79 @@ type AuditPlanSQLV2 struct { Schema string `json:"schema" gorm:"type:varchar(512);not null"` } +type BlacklistFilterType string + const ( - FilterTypeSQL string = "SQL" - FilterTypeIP string = "IP" - FilterTypeCIDR string = "CIDR" - FilterTypeHost string = "HOST" + FilterTypeSQL BlacklistFilterType = "sql" + FilterTypeFpSQL BlacklistFilterType = "fp_sql" + FilterTypeIP BlacklistFilterType = "ip" + FilterTypeCIDR BlacklistFilterType = "cidr" + FilterTypeHost BlacklistFilterType = "host" + FilterTypeInstance BlacklistFilterType = "instance" ) type BlackListAuditPlanSQL struct { Model - FilterContent string `json:"filter_content" gorm:"type:varchar(512);not null;"` - FilterType string `json:"filter_type" gorm:"type:enum('SQL','IP','CIDR','HOST');default:'SQL';not null;"` + ProjectId ProjectUID `gorm:"index; not null"` + FilterContent string `json:"filter_content" gorm:"type:varchar(3000);not null;"` + Desc string `json:"desc" gorm:"type:varchar(512)"` + FilterType BlacklistFilterType `json:"filter_type" gorm:"type:enum('sql','fp_sql','ip','cidr','host','instance');default:'SQL';not null;"` + MatchedCount uint `json:"matched_count" gorm:"default:0"` + LastMatchTime *time.Time `json:"last_match_time"` } func (a BlackListAuditPlanSQL) TableName() string { return "black_list_audit_plan_sqls" } -func (s *Storage) GetBlackListAuditPlanSQLs() ([]*BlackListAuditPlanSQL, error) { +func (s *Storage) GetBlacklistByID(projectID ProjectUID, id string) (*BlackListAuditPlanSQL, bool, error) { + bl := &BlackListAuditPlanSQL{} + err := s.db.Model(BlackListAuditPlanSQL{}).Where("project_id = ? AND id = ?", projectID, id).First(bl).Error + if e.Is(err, gorm.ErrRecordNotFound) { + return bl, false, nil + } + return bl, true, errors.New(errors.ConnectStorageError, err) +} + +func (s *Storage) GetBlackListByProjectID(projectID ProjectUID) ([]*BlackListAuditPlanSQL, error) { var blackListAPS []*BlackListAuditPlanSQL - err := s.db.Model(BlackListAuditPlanSQL{}).Find(&blackListAPS).Error + err := s.db.Model(BlackListAuditPlanSQL{}).Where("project_id = ?", projectID).Find(&blackListAPS).Error return blackListAPS, errors.New(errors.ConnectStorageError, err) } +func (s *Storage) GetBlacklistList(projectID ProjectUID, FilterType BlacklistFilterType, fuzzySearchContent string, pageIndex, pageSize uint32) ([]*BlackListAuditPlanSQL, uint64, error) { + var count int64 + var blackListAPS []*BlackListAuditPlanSQL + query := s.db.Model(BlackListAuditPlanSQL{}).Where("project_id = ?", projectID) + if FilterType != "" { + query = query.Where("filter_type = ?", FilterType) + } + if fuzzySearchContent != "" { + query = query.Where("filter_content LIKE ?", "%"+fuzzySearchContent+"%") + } + err := query.Count(&count).Error + if err != nil { + return blackListAPS, uint64(count), errors.New(errors.ConnectStorageError, err) + } + + if count == 0 { + return blackListAPS, uint64(count), errors.New(errors.ConnectStorageError, err) + } + + err = query.Offset(int((pageIndex - 1) * pageSize)).Limit(int(pageSize)).Order("id desc").Find(&blackListAPS).Error + return blackListAPS, uint64(count), errors.New(errors.ConnectStorageError, err) +} + +func (s *Storage) BatchUpdateBlackListCount(IdList []uint, matchedCount uint, lastMatchTime time.Time) error { + m := map[string]interface{}{ + "matched_count": gorm.Expr("matched_count + ?", matchedCount), + "last_match_time": lastMatchTime, + } + + err := s.db.Model(BlackListAuditPlanSQL{}).Where("id in (?)", IdList).Updates(m).Error + return errors.New(errors.ConnectStorageError, err) +} + func (a AuditPlanSQLV2) TableName() string { return "audit_plan_sqls_v2" } diff --git a/sqle/model/instance_audit_plan.go b/sqle/model/instance_audit_plan.go index b92bd547ee..2b5f22f723 100644 --- a/sqle/model/instance_audit_plan.go +++ b/sqle/model/instance_audit_plan.go @@ -50,6 +50,13 @@ type AuditPlanDetail struct { Instance *Instance `gorm:"-"` } +func (a AuditPlanDetail) GetInstanceName() string { + if a.Instance == nil { + return "" + } + return a.Instance.Name +} + func (s *Storage) ListActiveAuditPlanDetail() ([]*AuditPlanDetail, error) { var aps []*AuditPlanDetail err := s.db.Model(AuditPlanV2{}).Joins("JOIN instance_audit_plans ON instance_audit_plans.id = audit_plans_v2.instance_audit_plan_id"). @@ -95,6 +102,7 @@ func (s *Storage) getAuditPlanDetailByID(id uint, status string) (*AuditPlanDeta if ap == nil { return nil, false, nil } + return ap, true, nil } diff --git a/sqle/model/sql_whitelist.go b/sqle/model/sql_whitelist.go index df86e53fa2..e3cd009ab5 100644 --- a/sqle/model/sql_whitelist.go +++ b/sqle/model/sql_whitelist.go @@ -2,6 +2,7 @@ package model import ( "strings" + "time" "github.com/actiontech/sqle/sqle/errors" @@ -21,8 +22,10 @@ type SqlWhitelist struct { CapitalizedValue string `json:"-" gorm:"-"` Desc string `json:"desc" gorm:"type:varchar(255)"` // MessageDigest deprecated after 1.1.0, keep it for compatibility. - MessageDigest string `json:"message_digest" gorm:"type:char(32) not null comment 'md5 data';" ` - MatchType string `json:"match_type" gorm:"default:\"exact_match\""` + MessageDigest string `json:"message_digest" gorm:"type:char(32) not null comment 'md5 data';" ` + MatchType string `json:"match_type" gorm:"default:\"exact_match\""` + MatchedCount int `json:"matched_count" gorm:"default:0"` + LastMatchedTime *time.Time `json:"last_matched_time"` } // BeforeSave is a hook implement gorm model before exec create @@ -84,11 +87,20 @@ func (s *Storage) GetSqlWhitelistByIdAndProjectUID(sqlWhiteId string, projectUID // return sqlWhitelist, count, errors.New(errors.ConnectStorageError, err) // } -func (s *Storage) GetSqlWhitelistByProjectUID(pageIndex, pageSize uint32, projectUID ProjectUID) ([]SqlWhitelist, int64, error) { +func (s *Storage) GetSqlWhitelistByProjectUID(pageIndex, pageSize uint32, projectUID ProjectUID, fuzzyValue, matchType *string) ([]SqlWhitelist, int64, error) { var count int64 sqlWhitelist := []SqlWhitelist{} query := s.db.Table("sql_whitelist"). Where("project_id = ?", projectUID).Where("deleted_at IS NULL") + + if fuzzyValue != nil { + query = query.Where("value LIKE ?", "%"+*fuzzyValue+"%") + } + + if matchType != nil { + query = query.Where("match_type = ?", *matchType) + } + if pageSize == 0 { err := query.Order("id desc").Find(&sqlWhitelist).Count(&count).Error return sqlWhitelist, count, errors.New(errors.ConnectStorageError, err) @@ -109,6 +121,16 @@ func (s *Storage) GetSqlWhitelistByProjectId(projectId string) ([]SqlWhitelist, return sqlWhitelist, errors.New(errors.ConnectStorageError, err) } +func (s *Storage) UpdateSqlWhitelistMatchedInfo(id uint, count int, lastMatchedTime time.Time) error { + m := map[string]interface{}{ + "matched_count": gorm.Expr("matched_count + ?", count), + "last_matched_time": lastMatchedTime, + } + + err := s.db.Model(&SqlWhitelist{}).Where("sql_whitelist.id = ?", id).UpdateColumns(m).Error + return errors.New(errors.ConnectStorageError, err) +} + // func (s *Storage) GetSqlWhitelistTotalByProjectName(projectName string) (uint64, error) { // var count uint64 // err := s.db. diff --git a/sqle/server/audit.go b/sqle/server/audit.go index b5946e17cd..605465326a 100644 --- a/sqle/server/audit.go +++ b/sqle/server/audit.go @@ -6,6 +6,7 @@ import ( "math" "runtime/debug" "strings" + "time" "github.com/actiontech/sqle/sqle/driver" "github.com/actiontech/sqle/sqle/driver/mysql/session" @@ -158,6 +159,7 @@ func hookAudit(l *logrus.Entry, task *model.Task, p driver.Plugin, hook AuditHoo return err } var whitelistMatch bool + var matchedWhitelistID uint for _, wl := range whitelist { if wl.MatchType == model.SQLWhitelistFPMatch { wlNode, err := parse(l, p, wl.Value) @@ -165,21 +167,26 @@ func hookAudit(l *logrus.Entry, task *model.Task, p driver.Plugin, hook AuditHoo l.Errorf("parse whitelist sql error: %v,please check the accuracy of whitelist SQL: %s", err, wl.Value) } if node.Fingerprint == wlNode.Fingerprint { + matchedWhitelistID = wl.ID whitelistMatch = true } } else { if wl.CapitalizedValue == strings.ToUpper(node.Text) { + matchedWhitelistID = wl.ID whitelistMatch = true } } } if whitelistMatch { result := driverV2.NewAuditResults() - result.Add(driverV2.RuleLevelNormal, "", "白名单") + result.Add(driverV2.RuleLevelNormal, "", "审核SQL例外") executeSQL.AuditStatus = model.SQLAuditStatusFinished executeSQL.AuditLevel = string(result.Level()) executeSQL.AuditFingerprint = utils.Md5String(string(append([]byte(result.Message()), []byte(node.Fingerprint)...))) appendExecuteSqlResults(executeSQL, result) + if err := st.UpdateSqlWhitelistMatchedInfo(matchedWhitelistID, 1, time.Now()); err != nil { + l.Errorf("update sql whitelist matched info error: %v", err) + } } else { auditSqls = append(auditSqls, executeSQL) sqls = append(sqls, executeSQL.Content) diff --git a/sqle/server/auditplan/task_wrap.go b/sqle/server/auditplan/task_wrap.go index b84e720f25..f595f45a99 100644 --- a/sqle/server/auditplan/task_wrap.go +++ b/sqle/server/auditplan/task_wrap.go @@ -1,10 +1,17 @@ package auditplan import ( + "context" + e "errors" + "net" + "regexp" "sync" "time" + "github.com/actiontech/sqle/sqle/dms" + "github.com/actiontech/sqle/sqle/log" "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/utils" "github.com/sirupsen/logrus" ) @@ -173,10 +180,11 @@ func (at *TaskWrapper) FullSyncSQLs(sqls []*SQL) error { sqlList = append(sqlList, ConvertSQLV2ToMangerSQLQueue(sql)) } - err := at.persist.PushSQLToManagerSQLQueue(sqlList) + err := at.pushSQLToManagerSQLQueue(sqlList) if err != nil { at.logger.Errorf("push sql to manager sql queue failed, error : %v", err) } + return err } @@ -233,10 +241,10 @@ func (at *TaskWrapper) extractSQL() { for _, sql := range sqls { sqlQueues = append(sqlQueues, ConvertSQLV2ToMangerSQLQueue(sql)) } - err = at.persist.PushSQLToManagerSQLQueue(sqlQueues) + + err = at.pushSQLToManagerSQLQueue(sqlQueues) if err != nil { at.logger.Errorf("push sql to manager sql queue failed, error : %v", err) - return } } @@ -259,3 +267,258 @@ func (at *TaskWrapper) loop(cancel chan struct{}, interval time.Duration) { } } } + +func (at *TaskWrapper) pushSQLToManagerSQLQueue(sqlList []*model.SQLManageQueue) error { + if len(sqlList) == 0 { + return nil + } + + matchedCount, SqlQueueList, err := at.filterSqlManageQueue(sqlList) + if err != nil { + return err + } + + err = at.updateBlacklistInfo(matchedCount) + if err != nil { + return err + } + + err = at.persist.PushSQLToManagerSQLQueue(SqlQueueList) + if err != nil { + return err + } + + return nil +} + +func (at *TaskWrapper) filterSqlManageQueue(sqlList []*model.SQLManageQueue) (map[uint]uint, []*model.SQLManageQueue, error) { + blackListMap := make(map[string][]*model.BlackListAuditPlanSQL) + instanceMap := make(map[string]string) + var err error + matchedCount := make(map[uint]uint) + SqlQueueList := make([]*model.SQLManageQueue, 0) + for _, sql := range sqlList { + blacklist, ok := blackListMap[sql.ProjectId] + if !ok { + blacklist, err = at.persist.GetBlackListByProjectID(model.ProjectUID(sql.ProjectId)) + if err != nil { + return nil, nil, err + } + blackListMap[sql.ProjectId] = blacklist + } + + instName, ok := instanceMap[sql.InstanceID] + if !ok { + instance, exist, err := dms.GetInstancesById(context.TODO(), sql.InstanceID) + if err != nil { + return nil, nil, err + } + if !exist { + return nil, nil, e.New("instance not exist") + } + instName = instance.Name + instanceMap[sql.InstanceID] = instName + } + + matchedID, isInBlacklist := filterSQLsByBlackList(sql.EndPoint, sql.SqlText, sql.SqlFingerprint, instName, blacklist) + if isInBlacklist { + matchedCount[matchedID]++ + continue + } + + SqlQueueList = append(SqlQueueList, sql) + } + + return matchedCount, SqlQueueList, nil +} + +func filterSQLsByBlackList(endpoint, sqlText, sqlFp, instName string, blacklist []*model.BlackListAuditPlanSQL) (uint, bool) { + if len(blacklist) == 0 { + return 0, false + } + + filter := ConvertToBlackFilter(blacklist) + + matchedID, hasEndpointInBlacklist := filter.HasEndpointInBlackList([]string{endpoint}) + if hasEndpointInBlacklist { + return matchedID, true + } + + matchedID, isSqlInBlackList := filter.IsSqlInBlackList(sqlText) + if isSqlInBlackList { + return matchedID, true + } + + matchedID, isFpInBlackList := filter.IsFpInBlackList(sqlFp) + if isFpInBlackList { + return matchedID, true + } + + matchedID, isInstNameInBlackList := filter.IsInstNameInBlackList(instName) + if isInstNameInBlackList { + return matchedID, true + } + + return 0, false +} + +func (at *TaskWrapper) updateBlacklistInfo(matchedCount map[uint]uint) error { + m := make(map[uint] /*count*/ []uint /*blacklist id list*/) + for id, count := range matchedCount { + m[count] = append(m[count], id) + } + + lastMatchedTime := time.Now() + for count, idList := range m { + err := at.persist.BatchUpdateBlackListCount(idList, count, lastMatchedTime) + if err != nil { + return err + } + } + + return nil +} + +func ConvertToBlackFilter(blackList []*model.BlackListAuditPlanSQL) *BlackFilter { + var blackFilter BlackFilter + for _, filter := range blackList { + switch filter.FilterType { + case model.FilterTypeSQL: + blackFilter.BlackSqlList = append(blackFilter.BlackSqlList, BlackSqlList{ + ID: filter.ID, + Regexp: utils.FullFuzzySearchRegexp(filter.FilterContent), + }) + case model.FilterTypeFpSQL: + blackFilter.BlackFpList = append(blackFilter.BlackFpList, BlackFpList{ + ID: filter.ID, + Regexp: utils.FullFuzzySearchRegexp(filter.FilterContent), + }) + case model.FilterTypeHost: + blackFilter.BlackHostList = append(blackFilter.BlackHostList, BlackHostList{ + ID: filter.ID, + Regexp: utils.FullFuzzySearchRegexp(filter.FilterContent), + }) + case model.FilterTypeIP: + ip := net.ParseIP(filter.FilterContent) + if ip == nil { + log.Logger().Errorf("wrong ip in black list,ip:%s", filter.FilterContent) + continue + } + blackFilter.BlackIpList = append(blackFilter.BlackIpList, BlackIpList{ + ID: filter.ID, + Ip: ip, + }) + case model.FilterTypeCIDR: + _, cidr, err := net.ParseCIDR(filter.FilterContent) + if err != nil { + log.Logger().Errorf("wrong cidr in black list,cidr:%s,err:%v", filter.FilterContent, err) + continue + } + blackFilter.BlackCidrList = append(blackFilter.BlackCidrList, BlackCidrList{ + ID: filter.ID, + Cidr: cidr, + }) + case model.FilterTypeInstance: + blackFilter.BlackInstList = append(blackFilter.BlackInstList, BlackInstList{ + ID: filter.ID, + InstName: filter.FilterContent, + }) + } + } + return &blackFilter +} + +type BlackFilter struct { + BlackSqlList []BlackSqlList + BlackFpList []BlackFpList + BlackIpList []BlackIpList + BlackHostList []BlackHostList + BlackCidrList []BlackCidrList + BlackInstList []BlackInstList +} + +type BlackSqlList struct { + ID uint + Regexp *regexp.Regexp +} + +type BlackFpList struct { + ID uint + Regexp *regexp.Regexp +} + +type BlackIpList struct { + ID uint + Ip net.IP +} + +type BlackHostList struct { + ID uint + Regexp *regexp.Regexp +} + +type BlackCidrList struct { + ID uint + Cidr *net.IPNet +} + +type BlackInstList struct { + ID uint + InstName string +} + +func (f BlackFilter) IsSqlInBlackList(checkSql string) (uint, bool) { + for _, blackSql := range f.BlackSqlList { + if blackSql.Regexp.MatchString(checkSql) { + return blackSql.ID, true + } + } + return 0, false +} + +func (f BlackFilter) IsFpInBlackList(fp string) (uint, bool) { + for _, blackFp := range f.BlackFpList { + if blackFp.Regexp.MatchString(fp) { + return blackFp.ID, true + } + } + return 0, false +} + +func (f BlackFilter) IsInstNameInBlackList(instName string) (uint, bool) { + for _, blackInstName := range f.BlackInstList { + if blackInstName.InstName == instName { + return blackInstName.ID, true + } + } + return 0, false +} + +// 输入一组ip若其中有一个ip在黑名单中则返回true +func (f BlackFilter) HasEndpointInBlackList(checkIps []string) (uint, bool) { + var checkNetIp net.IP + for _, checkIp := range checkIps { + checkNetIp = net.ParseIP(checkIp) + if checkNetIp == nil { + // 无法解析IP,可能是域名,需要正则匹配 + for _, blackHost := range f.BlackHostList { + if blackHost.Regexp.MatchString(checkIp) { + return blackHost.ID, true + } + } + } else { + for _, blackIp := range f.BlackIpList { + if blackIp.Ip.Equal(checkNetIp) { + return blackIp.ID, true + } + } + for _, blackCidr := range f.BlackCidrList { + if blackCidr.Cidr.Contains(checkNetIp) { + return blackCidr.ID, true + } + } + } + } + + return 0, false +} diff --git a/sqle/api/controller/v1/audit_plan_test.go b/sqle/server/auditplan/task_wrap_test.go similarity index 65% rename from sqle/api/controller/v1/audit_plan_test.go rename to sqle/server/auditplan/task_wrap_test.go index f6f13ca493..edb6d3b82a 100644 --- a/sqle/api/controller/v1/audit_plan_test.go +++ b/sqle/server/auditplan/task_wrap_test.go @@ -1,23 +1,22 @@ -package v1_test +package auditplan import ( "testing" - v1 "github.com/actiontech/sqle/sqle/api/controller/v1" "github.com/actiontech/sqle/sqle/model" ) func TestIsSqlInBlackList(t *testing.T) { - filter := v1.ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ + filter := ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ { FilterContent: "SELECT", - FilterType: "SQL", + FilterType: "sql", }, { FilterContent: "table_1", - FilterType: "SQL", - },{ + FilterType: "sql", + }, { FilterContent: "ignored_service", - FilterType: "SQL", + FilterType: "sql", }, }) @@ -30,7 +29,7 @@ func TestIsSqlInBlackList(t *testing.T) { `/* this is a comment, Service: ignored_service */ update * from table_ignored where id < 123;`, } for _, matchSql := range matchSqls { - if !filter.IsSqlInBlackList(matchSql) { + if _, isSqlInBlackList := filter.IsSqlInBlackList(matchSql); !isSqlInBlackList { t.Error("Expected SQL to match blacklist") } } @@ -42,20 +41,20 @@ func TestIsSqlInBlackList(t *testing.T) { service */ update * from table_ignored where id < 123;`, } for _, notMatchSql := range notMatchSqls { - if filter.IsSqlInBlackList(notMatchSql) { + if _, isSqlInBlackList := filter.IsSqlInBlackList(notMatchSql); isSqlInBlackList { t.Error("Did not expect SQL to match blacklist") } } } func TestIsIpInBlackList(t *testing.T) { - filter := v1.ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ + filter := ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ { FilterContent: "192.168.1.23", - FilterType: "IP", + FilterType: "ip", }, { FilterContent: "10.0.5.67", - FilterType: "IP", + FilterType: "ip", }, }) @@ -64,7 +63,7 @@ func TestIsIpInBlackList(t *testing.T) { "192.168.1.23", } for _, matchIp := range matchIps { - if !filter.HasEndpointInBlackList([]string{matchIp}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{matchIp}); !hasEndpointInBlackList { t.Error("Expected Ip to match blacklist") } } @@ -75,20 +74,20 @@ func TestIsIpInBlackList(t *testing.T) { "50.67.89.12", } for _, notMatchIp := range notMatchIps { - if filter.HasEndpointInBlackList([]string{notMatchIp}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{notMatchIp}); hasEndpointInBlackList { t.Error("Did not expect Ip to match blacklist") } } } func TestIsCidrInBlackList(t *testing.T) { - filter := v1.ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ + filter := ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ { FilterContent: "192.168.0.0/24", - FilterType: "CIDR", + FilterType: "cidr", }, { FilterContent: "10.100.0.0/16", - FilterType: "CIDR", + FilterType: "cidr", }, }) @@ -99,7 +98,7 @@ func TestIsCidrInBlackList(t *testing.T) { "192.168.0.45", } for _, matchIp := range matchIps { - if !filter.HasEndpointInBlackList([]string{matchIp}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{matchIp}); !hasEndpointInBlackList { t.Error("Expected CIDR to match blacklist") } } @@ -112,20 +111,20 @@ func TestIsCidrInBlackList(t *testing.T) { "172.30.30.45", } for _, notMatchIp := range notMatchIps { - if filter.HasEndpointInBlackList([]string{notMatchIp}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{notMatchIp}); hasEndpointInBlackList { t.Error("Did not expect CIDR to match blacklist") } } } func TestIsHostInBlackList(t *testing.T) { - filter := v1.ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ + filter := ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ { FilterContent: "host", - FilterType: "HOST", + FilterType: "host", }, { FilterContent: "some_site", - FilterType: "HOST", + FilterType: "host", }, }) @@ -138,7 +137,7 @@ func TestIsHostInBlackList(t *testing.T) { } for _, matchHost := range matchHosts { - if !filter.HasEndpointInBlackList([]string{matchHost}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{matchHost}); !hasEndpointInBlackList { t.Error("Expected HOST to match blacklist") } } @@ -148,7 +147,7 @@ func TestIsHostInBlackList(t *testing.T) { "any_other_site/local", } for _, noMatchHost := range notMatchHosts { - if filter.HasEndpointInBlackList([]string{noMatchHost}) { + if _, hasEndpointInBlackList := filter.HasEndpointInBlackList([]string{noMatchHost}); hasEndpointInBlackList { t.Error("Did not expect HOST to match blacklist") } } diff --git a/sqle/server/sqled_test.go b/sqle/server/sqled_test.go index 489ec75582..4042006b85 100644 --- a/sqle/server/sqled_test.go +++ b/sqle/server/sqled_test.go @@ -173,6 +173,11 @@ func Test_action_audit_UpdateTask(t *testing.T) { WithArgs(""). WillReturnRows(sqlmock.NewRows([]string{"value", "match_type"}).AddRow(whitelist.Value, whitelist.MatchType)) + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("UPDATE `sql_whitelist` SET `last_matched_time`=?,`matched_count`=matched_count + ? WHERE sql_whitelist.id = ? AND `sql_whitelist`.`deleted_at` IS NULL")). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() mock.ExpectExec(regexp.QuoteMeta("INSERT INTO `execute_sql_detail`")). // WithArgs(model.MockTime, model.MockTime, nil, 0, 0, act.task.ExecuteSQLs[0].Content, "", "", 0, "", 0, 0, "", "", "", 0, "", model.SQLAuditStatusFinished, `[{"level":"normal","message":"白名单","rule_name":""}]`, "2882fdbb7d5bcda7b49ea0803493467e", "normal").