Skip to content

Commit

Permalink
Merge pull request #1914 from actiontech/fix-issue1913
Browse files Browse the repository at this point in the history
fix issue 1913
  • Loading branch information
sjjian authored Oct 18, 2023
2 parents 4d440ea + add3ae5 commit 13eae2c
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 12 deletions.
44 changes: 44 additions & 0 deletions sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3954,6 +3954,50 @@ func Test_DMLCheckInQueryLimit(t *testing.T) {
runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"update exist_tb_1 set v1 = 'v1_next' where id in (1,2,3,4,5,6,7)",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4))",
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4))",
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4))",
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c'))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select 1 in (1,2,3,4,5,6);",
newTestResult())
}

func TestCheckIndexOption(t *testing.T) {
Expand Down
29 changes: 17 additions & 12 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2801,24 +2801,29 @@ func hasDefaultValueCurrentTimeStamp(options []*ast.ColumnOption) bool {
}

func checkInQueryLimit(input *RuleHandlerInput) error {
where := getWhereExpr(input.Node)
if where == nil {
dmlNode, ok := input.Node.(ast.DMLNode)
if !ok {
return nil
}

whereVisitor := &util.WhereVisitor{}
dmlNode.Accept(whereVisitor)
paramThresholdNumber := input.Rule.Params.GetParam(DefaultSingleParamKeyName).Int()
util.ScanWhereStmt(func(expr ast.ExprNode) bool {
switch stmt := expr.(type) {
case *ast.PatternInExpr:
inQueryParamActualNumber := len(stmt.List)
if inQueryParamActualNumber > paramThresholdNumber {
addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber)

for _, whereExpr := range whereVisitor.WhereList {
util.ScanWhereStmt(func(expr ast.ExprNode) bool {
switch stmt := expr.(type) {
case *ast.PatternInExpr:
inQueryParamActualNumber := len(stmt.List)
if inQueryParamActualNumber > paramThresholdNumber {
addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber)
}
return true
}
return true
}

return false
}, where)
return false
}, whereExpr)
}

return nil
}
Expand Down
29 changes: 29 additions & 0 deletions sqle/driver/mysql/util/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,32 @@ func (v *ColumnNameVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool)
func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

type WhereVisitor struct {
WhereList []ast.ExprNode
}

func (v *WhereVisitor) append(where ast.ExprNode) {
if where != nil {
v.WhereList = append(v.WhereList, where)
}
}

func (v *WhereVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch stmt := in.(type) {
case *ast.SelectStmt:
if stmt.From == nil { //If from is null skip check. EX: select 1;select version
return in, false
}
v.append(stmt.Where)
case *ast.UpdateStmt:
v.append(stmt.Where)
case *ast.DeleteStmt:
v.append(stmt.Where)
}
return in, false
}

func (v *WhereVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

0 comments on commit 13eae2c

Please sign in to comment.