Skip to content

Commit

Permalink
judge where is nil
Browse files Browse the repository at this point in the history
  • Loading branch information
hasa1K committed Nov 9, 2023
1 parent c4ff98c commit 4bcc9d3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
22 changes: 21 additions & 1 deletion sqle/driver/mysql/audit_offline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) {
"select id from exist_db.exist_tb_1 where id > 1;",
noResult,
},
{
"select_from: has where condition",
"select id from exist_db.exist_tb_1;",
whereIsInvalid,
},
{
"select_from: no where condition(1)",
"select id from exist_db.exist_tb_1;",
Expand Down Expand Up @@ -129,7 +134,7 @@ func TestCheckWhereInvalidOffline(t *testing.T) {
{
"select_from: no where condition(6)",
"select id from (select * from exist_db.exist_tb_1 where exist_tb_1.id>1) t;",
noResult,
whereIsInvalid,
},
// UPDATE
{
Expand Down Expand Up @@ -212,6 +217,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) {
"select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where 1=1);",
whereIsInvalid,
},
{
"use exists",
"select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2);",
whereIsInvalid,
},
{
"use exists",
"select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3 t3 where t1.id=t2.id and t2.id=t3.id));",
Expand All @@ -222,6 +232,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) {
"select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3 t3 where 1=1));",
whereIsInvalid,
},
{
"use exists",
"select * from exist_db.exist_tb_1 t1 where exists (select 1 from exist_db.exist_tb_2 t2 where exists (select 1 from exist_db.exist_db_3));",
whereIsInvalid,
},
{
"use not exists",
"select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2 t2 where t1.id=t2.id);",
Expand All @@ -232,6 +247,11 @@ func TestCheckWhereInvalidOffline(t *testing.T) {
"select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2 t2 where 1=1);",
whereIsInvalid,
},
{
"use not exists",
"select * from exist_db.exist_tb_1 t1 where not exists (select 1 from exist_db.exist_tb_2);",
whereIsInvalid,
},
}
offlineInspect := DefaultMysqlInspectOffline()
for _, testCase := range testCases {
Expand Down
6 changes: 5 additions & 1 deletion sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ func TestCheckWhereInvalid(t *testing.T) {

runDefaultRulesInspectCase(t, "select_count: has where condition(2)", DefaultMysqlInspect(),
"select id from (select * from exist_db.exist_tb_1 where exist_tb_1.id>1) t LIMIT 999;",
newTestResult().add(driverV2.RuleLevelNotice, "", "LIMIT 查询建议使用ORDER BY"),
newTestResult().add(driverV2.RuleLevelNotice, "", "LIMIT 查询建议使用ORDER BY").addResult(rulepkg.DMLCheckWhereIsInvalid),
)

runDefaultRulesInspectCase(t, "select_count: has no where condition(3)", DefaultMysqlInspect(),
Expand Down Expand Up @@ -4121,6 +4121,10 @@ func Test_DMLCheckInQueryLimit(t *testing.T) {
paramValue := "5"
rule.Params.SetParamValue(rulepkg.DefaultSingleParamKeyName, paramValue)

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1",
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
"select * from exist_tb_1 where id in (1,2,3,4,5,6)",
newTestResult().addResult(rulepkg.DMLCheckInQueryNumber, 6, paramValue))
Expand Down
7 changes: 7 additions & 0 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -3122,6 +3122,9 @@ func isSelectCount(selectStmt *ast.SelectStmt) bool {
func checkSelectWhere(input *RuleHandlerInput) error {

visitor := util.WhereVisitor{}
if input.Rule.Name == DMLCheckWhereIsInvalid {
visitor.WhetherContainNil = true
}
switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if stmt.From == nil {
Expand Down Expand Up @@ -3150,6 +3153,10 @@ func checkWhere(rule driverV2.Rule, res *driverV2.AuditResults, whereList []ast.
addResult(res, rule, DMLCheckWhereIsInvalid)
}
for _, where := range whereList {
if where == nil {
addResult(res, rule, DMLCheckWhereIsInvalid)
break
}
if !util.WhereStmtHasOneColumn(where) {
addResult(res, rule, DMLCheckWhereIsInvalid)
break
Expand Down
5 changes: 4 additions & 1 deletion sqle/driver/mysql/util/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,15 @@ func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
}

type WhereVisitor struct {
WhereList []ast.ExprNode
WhereList []ast.ExprNode
WhetherContainNil bool
}

func (v *WhereVisitor) append(where ast.ExprNode) {
if where != nil {
v.WhereList = append(v.WhereList, where)
} else if v.WhetherContainNil {
v.WhereList = append(v.WhereList, nil)
}
}

Expand Down

0 comments on commit 4bcc9d3

Please sign in to comment.