From 719df6a30a1895e78cee29973f2a349629fc05ca Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Mon, 16 Oct 2023 05:13:52 +0000 Subject: [PATCH 1/3] use inexpr visitor --- sqle/driver/mysql/audit_test.go | 29 +++++++++++++++++++++++++++++ sqle/driver/mysql/rule/rule.go | 24 +++++++++++------------- sqle/driver/mysql/util/visitor.go | 16 ++++++++++++++++ 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go index 85d79c3620..202e10fa26 100644 --- a/sqle/driver/mysql/audit_test.go +++ b/sqle/driver/mysql/audit_test.go @@ -3954,6 +3954,35 @@ func Test_DMLCheckInQueryLimit(t *testing.T) { runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), "update exist_tb_1 set v1 = 'v1_next' where id in (1,2,3,4,5,6,7)", newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4))", + newTestResult()) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4))", + newTestResult()) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4))", + newTestResult()) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c'))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + } func TestCheckIndexOption(t *testing.T) { diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go index 00b0539794..7d9b7d4af6 100644 --- a/sqle/driver/mysql/rule/rule.go +++ b/sqle/driver/mysql/rule/rule.go @@ -2801,24 +2801,22 @@ func hasDefaultValueCurrentTimeStamp(options []*ast.ColumnOption) bool { } func checkInQueryLimit(input *RuleHandlerInput) error { - where := getWhereExpr(input.Node) - if where == nil { + dmlNode, ok := input.Node.(ast.DMLNode) + if !ok { return nil } + inVisitor := &util.PatternInVisitor{} + dmlNode.Accept(inVisitor) paramThresholdNumber := input.Rule.Params.GetParam(DefaultSingleParamKeyName).Int() - util.ScanWhereStmt(func(expr ast.ExprNode) bool { - switch stmt := expr.(type) { - case *ast.PatternInExpr: - inQueryParamActualNumber := len(stmt.List) - if inQueryParamActualNumber > paramThresholdNumber { - addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber) - } - return true - } - return false - }, where) + for _, inExpr := range inVisitor.PatternInList { + inQueryParamActualNumber := len(inExpr.List) + if inQueryParamActualNumber > paramThresholdNumber { + addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber) + return nil + } + } return nil } diff --git a/sqle/driver/mysql/util/visitor.go b/sqle/driver/mysql/util/visitor.go index e31bd403d5..262f2c9e7c 100644 --- a/sqle/driver/mysql/util/visitor.go +++ b/sqle/driver/mysql/util/visitor.go @@ -243,3 +243,19 @@ func (v *ColumnNameVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } + +type PatternInVisitor struct { + PatternInList []*ast.PatternInExpr +} + +func (v *PatternInVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch stmt := in.(type) { + case *ast.PatternInExpr: + v.PatternInList = append(v.PatternInList, stmt) + } + return in, false +} + +func (v *PatternInVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} From 949fdb5d593f1009704535fb553aafa2d3ab5c24 Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Wed, 18 Oct 2023 02:05:51 +0000 Subject: [PATCH 2/3] use wherevisitor instead of invisitor --- sqle/driver/mysql/audit_test.go | 3 +++ sqle/driver/mysql/rule/rule.go | 23 +++++++++++++++-------- sqle/driver/mysql/util/visitor.go | 25 +++++++++++++++++++------ 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go index 202e10fa26..c9657de7be 100644 --- a/sqle/driver/mysql/audit_test.go +++ b/sqle/driver/mysql/audit_test.go @@ -3983,6 +3983,9 @@ func Test_DMLCheckInQueryLimit(t *testing.T) { "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c'))", newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select 1 in (1,2,3,4,5,6);", + newTestResult()) } func TestCheckIndexOption(t *testing.T) { diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go index 7d9b7d4af6..f91ad1b862 100644 --- a/sqle/driver/mysql/rule/rule.go +++ b/sqle/driver/mysql/rule/rule.go @@ -2806,16 +2806,23 @@ func checkInQueryLimit(input *RuleHandlerInput) error { return nil } - inVisitor := &util.PatternInVisitor{} - dmlNode.Accept(inVisitor) + whereVisitor := &util.WhereVisitor{} + dmlNode.Accept(whereVisitor) paramThresholdNumber := input.Rule.Params.GetParam(DefaultSingleParamKeyName).Int() - for _, inExpr := range inVisitor.PatternInList { - inQueryParamActualNumber := len(inExpr.List) - if inQueryParamActualNumber > paramThresholdNumber { - addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber) - return nil - } + for _, whereExpr := range whereVisitor.WhereList { + util.ScanWhereStmt(func(expr ast.ExprNode) bool { + switch stmt := expr.(type) { + case *ast.PatternInExpr: + inQueryParamActualNumber := len(stmt.List) + if inQueryParamActualNumber > paramThresholdNumber { + addResult(input.Res, input.Rule, DMLCheckInQueryNumber, inQueryParamActualNumber, paramThresholdNumber) + } + return true + } + + return false + }, whereExpr) } return nil diff --git a/sqle/driver/mysql/util/visitor.go b/sqle/driver/mysql/util/visitor.go index 262f2c9e7c..9cd65b4bb5 100644 --- a/sqle/driver/mysql/util/visitor.go +++ b/sqle/driver/mysql/util/visitor.go @@ -244,18 +244,31 @@ func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } -type PatternInVisitor struct { - PatternInList []*ast.PatternInExpr +type WhereVisitor struct { + WhereList []ast.ExprNode } -func (v *PatternInVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { +func (v *WhereVisitor) append(where ast.ExprNode) { + if where != nil { + v.WhereList = append(v.WhereList, where) + } +} + +func (v *WhereVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { switch stmt := in.(type) { - case *ast.PatternInExpr: - v.PatternInList = append(v.PatternInList, stmt) + case *ast.SelectStmt: + if stmt.From == nil { //If from is null skip check. EX: select 1;select version + return in, false + } + v.append(stmt.Where) + case *ast.UpdateStmt: + v.append(stmt.Where) + case *ast.DeleteStmt: + v.append(stmt.Where) } return in, false } -func (v *PatternInVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { +func (v *WhereVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } From add3ae5eacab0282ebb390ada046946882fe690a Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Wed, 18 Oct 2023 02:12:33 +0000 Subject: [PATCH 3/3] improve unit test --- sqle/driver/mysql/audit_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go index c9657de7be..d8bd33a028 100644 --- a/sqle/driver/mysql/audit_test.go +++ b/sqle/driver/mysql/audit_test.go @@ -3983,6 +3983,18 @@ func Test_DMLCheckInQueryLimit(t *testing.T) { "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c'))", newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue)) + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select * from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "delete from exist_tb_1 where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue)) + + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "update exist_tb_1 set v1 = 'v1_next' where id in (select id from exist_tb_1 where id in (1,2,3,4,5,6,7) and v1 in ('a', 'b', 'c', 'd', 'e', 'f'))", + newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 7, paramValue).addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue)) + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), "select 1 in (1,2,3,4,5,6);", newTestResult())