Skip to content

Commit

Permalink
Merge pull request #2126 from actiontech/fix-issue2114
Browse files Browse the repository at this point in the history
check xml sql
  • Loading branch information
taolx0 authored Dec 7, 2023
2 parents 688a2a6 + 7b76f30 commit be67d04
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 14 deletions.
39 changes: 39 additions & 0 deletions sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7526,3 +7526,42 @@ func TestDDLCheckCharLength(t *testing.T) {
})
}
}

func Test_CheckMybatisSQLIndex(t *testing.T) {
e, _, err := executor.NewMockExecutor()
assert.NoError(t, err)

inspect1 := NewMockInspect(e)
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, "select * from exist_tb_1", newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, "select * from exist_tb_1 where id=?", newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, "select * from exist_tb_2 where v1=?", newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, "select * from exist_tb_2 t2 where t2.id=?", newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, "select * from exist_tb_2 t2 where t2.v1=?", newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id where t2.v1=?`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id where t1.id=?`, newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `select * from exist_tb_2 where id in (select id from exist_tb_1 where v1=?)`, newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `select * from exist_tb_2 where v2 in (select v1 from exist_tb_1 where id=?)`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `delete from exist_tb_2`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `delete from exist_tb_2 where id=?`, newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `delete from exist_tb_2 where v1=?`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `update exist_tb_2 set id=1`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `update exist_tb_2 set v1=? where id=?`, newTestResult())
runSingleRuleInspectCase(rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule, t,
"", inspect1, `update exist_tb_2 set v1=? where v2=?`, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
}
99 changes: 86 additions & 13 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -5365,6 +5365,58 @@ func checkIndexOption(input *RuleHandlerInput) error {
return nil
}

func isColumnUsingIndex(column string, constraints []*ast.Constraint) bool {
for _, constraint := range constraints {
for _, key := range constraint.Keys {
if key.Column.Name.L == column {
return true
}
}
}
return false
}

func checkWhereCondationUseIndex(ctx *session.Context, whereVisitor *util.WhereWithTableVisitor) bool {
for _, whereExpr := range whereVisitor.WhereStmts {
if whereExpr.WhereStmt == nil {
return false
}

isUsingIndex := false

if whereExpr.TableRef == nil {
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
tableName := x.Name.Table.L
columnName := x.Name.Name.L
// 代表单表查询并且没有使用表别名
if tableName == "" {
for _, createTableStmt := range tableNameCreateTableStmtMap {
if isColumnUsingIndex(columnName, createTableStmt.Constraints) {
isUsingIndex = true
}
}
} else {
createStmt, ok := tableNameCreateTableStmtMap[tableName]
if ok && isColumnUsingIndex(columnName, createStmt.Constraints) {
isUsingIndex = true
}
}
}
return false
}, *whereExpr.WhereStmt)
if !isUsingIndex {
return false
}
}
return true
}

func checkExplain(input *RuleHandlerInput) error {
// sql from MyBatis XML file is not the executable sql. so can't do explain for it.
// TODO(@wy) ignore explain when audit Mybatis file
Expand All @@ -5381,6 +5433,25 @@ func checkExplain(input *RuleHandlerInput) error {
if err != nil {
// TODO: check dml related table or database is created, if not exist, explain will executed failure.
log.NewEntry().Errorf("get execution plan failed, sqle: %v, error: %v", input.Node.Text(), err)

// xml解析出来的sql获取执行计划会失败
// 需要根据查询条件中的字段判断是否使用了索引
if input.Rule.Name != DMLCheckExplainUsingIndex {
return nil
}
// 验证where条件是否使用了索引字段
wv := &util.WhereWithTableVisitor{}
input.Node.Accept(wv)
if !checkWhereCondationUseIndex(input.Ctx, wv) {
addResult(input.Res, input.Rule, input.Rule.Name)
return nil
}
// 验证连表查询中连接字段是否使用索引
isUsingIndex, err := judgeJoinFieldUseIndex(input)
if err == nil && !isUsingIndex {
addResult(input.Res, input.Rule, input.Rule.Name)
}

return nil
}
for _, record := range epRecords {
Expand Down Expand Up @@ -7486,15 +7557,20 @@ func isColumnUseLeftMostPrefix(allCols []string, constraints []*ast.Constraint)
return true
}

func checkJoinFieldUseIndex(input *RuleHandlerInput) error {
isUsingIndex, err := judgeJoinFieldUseIndex(input)
if err == nil && !isUsingIndex {
addResult(input.Res, input.Rule, input.Rule.Name)
}
return nil
}

/*
checkJoinFieldUseIndex 判断Join语句中被驱动表中作为连接条件的列是否属于索引
judgeJoinFieldUseIndex 判断Join语句中被驱动表中作为连接条件的列是否属于索引
触发条件:
A. CrossJoin和RightJoin (选择驱动表的情况复杂:随着数据变化而变化,因此都判断)
A. CrossJoin,RightJoin和LeftJoin
1. 分别判断ON USING WHERE中的连接条件是否有索引
B. LeftJoin (选择驱动表的情况固定:Join右侧的表为被驱动表)
1. ON和USING,判断LeftJoin的被驱动表 (右侧的表) 的连接条件是否有索引
2. 判断WHERE中的连接条件是否有索引
连接条件:等值条件两侧为不同表的列
支持情况:
支持:
Expand All @@ -7503,12 +7579,11 @@ checkJoinFieldUseIndex 判断Join语句中被驱动表中作为连接条件的
不支持:
1. 子查询中JOIN多表的判断
*/
func checkJoinFieldUseIndex(input *RuleHandlerInput) error {

func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
joinNode := getJoinNodeFromNode(input.Node)
if doesNotJoinTables(joinNode) {
// 如果SQL没有JOIN多表,则不需要审核
return nil
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
Expand All @@ -7517,16 +7592,14 @@ func checkJoinFieldUseIndex(input *RuleHandlerInput) error {
}

if joinNodes, hasIndex := joinConditionInJoinNodeHasIndex(input.Ctx, joinNode, tableIndexes); joinNodes && !hasIndex {
addResult(input.Res, input.Rule, input.Rule.Name)
return nil
return false, nil
}

whereStmt := getWhereStmtFromNode(input.Node)
if joinNodes, hasIndex := joinConditionInWhereStmtHasIndex(input.Ctx, joinNode, whereStmt, tableIndexes); joinNodes && !hasIndex {
addResult(input.Res, input.Rule, input.Rule.Name)
return nil
return false, nil
}
return nil
return true, nil
}

func joinConditionInWhereStmtHasIndex(ctx *session.Context, joinNode *ast.Join, whereStmt ast.ExprNode, tableIndex map[string][]*ast.Constraint) (joinTables, hasIndex bool) {
Expand Down
30 changes: 29 additions & 1 deletion sqle/driver/mysql/util/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {

type WhereVisitor struct {
WhereList []ast.ExprNode
WhetherContainNil bool // 是否需要包含空的where,例如select * from t1 该语句的where为空
WhetherContainNil bool // 是否需要包含空的where,例如select * from t1 该语句的where为空
}

func (v *WhereVisitor) append(where ast.ExprNode) {
Expand Down Expand Up @@ -328,3 +328,31 @@ func (v *FuncCallExprVisitor) Enter(in ast.Node) (out ast.Node, skipChildren boo
func (v *FuncCallExprVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

type WhereWithTable struct {
WhereStmt *ast.ExprNode
TableRef *ast.Join
}

type WhereWithTableVisitor struct {
WhereStmts []*WhereWithTable
}

func (v *WhereWithTableVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch stmt := in.(type) {
case *ast.SelectStmt:
if stmt.From == nil { //If from is null, skip check. EX: select 1;select version;
return in, false
}
v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.From.TableRefs})
case *ast.DeleteStmt:
v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.TableRefs.TableRefs})
case *ast.UpdateStmt:
v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.TableRefs.TableRefs})
}
return in, false
}

func (v *WhereWithTableVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

0 comments on commit be67d04

Please sign in to comment.