From 4bcc9d322228d10341efd7eb12ae2d55df6b4719 Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Thu, 9 Nov 2023 02:20:19 +0000 Subject: [PATCH] judge where is nil --- sqle/driver/mysql/audit_offline_test.go | 22 +++++++++++++++++++++- sqle/driver/mysql/audit_test.go | 6 +++++- sqle/driver/mysql/rule/rule.go | 7 +++++++ sqle/driver/mysql/util/visitor.go | 5 ++++- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/sqle/driver/mysql/audit_offline_test.go b/sqle/driver/mysql/audit_offline_test.go index 097402c6e4..ab65f8e335 100644 --- a/sqle/driver/mysql/audit_offline_test.go +++ b/sqle/driver/mysql/audit_offline_test.go @@ -101,6 +101,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) { "select id from exist_db.exist_tb_1 where id > 1;", noResult, }, + { + "select_from: has where condition", + "select id from exist_db.exist_tb_1;", + whereIsInvalid, + }, { "select_from: no where condition(1)", "select id from exist_db.exist_tb_1;", @@ -129,7 +134,7 @@ func TestCheckWhereInvalidOffline(t *testing.T) { { "select_from: no where condition(6)", "select id from (select * from exist_db.exist_tb_1 where exist_tb_1.id>1) t;", - noResult, + whereIsInvalid, }, // UPDATE { @@ -212,6 +217,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) { "select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where 1=1);", whereIsInvalid, }, + { + "use exists", + "select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2);", + whereIsInvalid, + }, { "use exists", "select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3 t3 where t1.id=t2.id and t2.id=t3.id));", @@ -222,6 +232,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) { "select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3 t3 where 1=1));", whereIsInvalid, }, + { + "use exists", + "select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3));", + whereIsInvalid, + }, { "use not exists", "select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2 t2 where t1.id=t2.id);", @@ -232,6 +247,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) { "select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2 t2 where 1=1);", whereIsInvalid, }, + { + "use not exists", + "select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2);", + whereIsInvalid, + }, } offlineInspect := DefaultMysqlInspectOffline() for _, testCase := range testCases { diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go index 5b6d28edd4..93382b3414 100644 --- a/sqle/driver/mysql/audit_test.go +++ b/sqle/driver/mysql/audit_test.go @@ -1197,7 +1197,7 @@ func TestCheckWhereInvalid(t *testing.T) { runDefaultRulesInspectCase(t, "select_count: has where condition(2)", DefaultMysqlInspect(), "select id from (select * from exist_db.exist_tb_1 where exist_tb_1.id>1) t LIMIT 999;", - newTestResult().add(driverV2.RuleLevelNotice, "", "LIMIT 查询建议使用ORDER BY"), + newTestResult().add(driverV2.RuleLevelNotice, "", "LIMIT 查询建议使用ORDER BY").addResult(rulepkg.DMLCheckWhereIsInvalid), ) runDefaultRulesInspectCase(t, "select_count: has no where condition(3)", DefaultMysqlInspect(), @@ -4121,6 +4121,10 @@ func Test_DMLCheckInQueryLimit(t *testing.T) { paramValue := "5" rule.Params.SetParamValue(rulepkg.DefaultSingleParamKeyName, paramValue) + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), + "select * from exist_tb_1", + newTestResult()) + runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(), "select * from exist_tb_1 where id in (1,2,3,4,5,6)", newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue)) diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go index 8aa40f995b..36fbd5b600 100644 --- a/sqle/driver/mysql/rule/rule.go +++ b/sqle/driver/mysql/rule/rule.go @@ -3122,6 +3122,9 @@ func isSelectCount(selectStmt *ast.SelectStmt) bool { func checkSelectWhere(input *RuleHandlerInput) error { visitor := util.WhereVisitor{} + if input.Rule.Name == DMLCheckWhereIsInvalid { + visitor.WhetherContainNil = true + } switch stmt := input.Node.(type) { case *ast.SelectStmt: if stmt.From == nil { @@ -3150,6 +3153,10 @@ func checkWhere(rule driverV2.Rule, res *driverV2.AuditResults, whereList []ast. addResult(res, rule, DMLCheckWhereIsInvalid) } for _, where := range whereList { + if where == nil { + addResult(res, rule, DMLCheckWhereIsInvalid) + break + } if !util.WhereStmtHasOneColumn(where) { addResult(res, rule, DMLCheckWhereIsInvalid) break diff --git a/sqle/driver/mysql/util/visitor.go b/sqle/driver/mysql/util/visitor.go index 76ac90bb2f..fb3ac11367 100644 --- a/sqle/driver/mysql/util/visitor.go +++ b/sqle/driver/mysql/util/visitor.go @@ -246,12 +246,15 @@ func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { } type WhereVisitor struct { - WhereList []ast.ExprNode + WhereList []ast.ExprNode + WhetherContainNil bool } func (v *WhereVisitor) append(where ast.ExprNode) { if where != nil { v.WhereList = append(v.WhereList, where) + } else if v.WhetherContainNil { + v.WhereList = append(v.WhereList, nil) } }