diff --git a/sqle/driver/mysql/audit_test.go b/sqle/driver/mysql/audit_test.go
index a2b2811184..827ed62907 100644
--- a/sqle/driver/mysql/audit_test.go
+++ b/sqle/driver/mysql/audit_test.go
@@ -7526,3 +7526,43 @@ func TestDDLCheckCharLength(t *testing.T) {
 		})
 	}
 }
+
+// Test_CheckMybatisSQLIndex runs the DMLCheckExplainUsingIndex rule against a
+// mock executor for SELECT, DELETE and UPDATE statements and checks, case by
+// case, whether the rule reports a result.
+func Test_CheckMybatisSQLIndex(t *testing.T) {
+	e, _, err := executor.NewMockExecutor()
+	assert.NoError(t, err)
+
+	inspect1 := NewMockInspect(e)
+	rule := rulepkg.RuleHandlerMap[rulepkg.DMLCheckExplainUsingIndex].Rule
+
+	cases := []struct {
+		sql       string
+		triggered bool // whether the rule is expected to report a result
+	}{
+		{"select * from exist_tb_1", true},
+		{"select * from exist_tb_1 where id=?", false},
+		{"select * from exist_tb_2 where v1=?", true},
+		{"select * from exist_tb_2 t2 where t2.id=?", false},
+		{"select * from exist_tb_2 t2 where t2.v1=?", true},
+		{`select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id where t2.v1=?`, true},
+		{`select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id`, true},
+		{`select * from exist_tb_2 t2 left join exist_tb_1 t1 on t1.id=t2.id where t1.id=?`, false},
+		{`select * from exist_tb_2 where id in (select id from exist_tb_1 where v1=?)`, false},
+		{`select * from exist_tb_2 where v2 in (select v1 from exist_tb_1 where id=?)`, true},
+		{`delete from exist_tb_2`, true},
+		{`delete from exist_tb_2 where id=?`, false},
+		{`delete from exist_tb_2 where v1=?`, true},
+		{`update exist_tb_2 set id=1`, true},
+		{`update exist_tb_2 set v1=? where id=?`, false},
+		{`update exist_tb_2 set v1=? where v2=?`, true},
+	}
+	for _, c := range cases {
+		if c.triggered {
+			runSingleRuleInspectCase(rule, t, "", inspect1, c.sql, newTestResult().addResult(rulepkg.DMLCheckExplainUsingIndex))
+			continue
+		}
+		runSingleRuleInspectCase(rule, t, "", inspect1, c.sql, newTestResult())
+	}
+}
diff --git a/sqle/driver/mysql/rule/rule.go b/sqle/driver/mysql/rule/rule.go
index d8a34e2d17..e371a8f2eb 100644
--- a/sqle/driver/mysql/rule/rule.go
+++ b/sqle/driver/mysql/rule/rule.go
@@ -5365,6 +5365,70 @@ func checkIndexOption(input *RuleHandlerInput) error {
 	return nil
 }
 
+// isColumnUsingIndex reports whether column equals the lower-cased name
+// (Name.L) of a key column in any of the given constraints.
+func isColumnUsingIndex(column string, constraints []*ast.Constraint) bool {
+	for _, constraint := range constraints {
+		for _, key := range constraint.Keys {
+			if key.Column.Name.L == column {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// checkWhereCondationUseIndex reports whether every statement collected by
+// whereVisitor has a WHERE clause referencing at least one column that is a
+// key column of a constraint on one of the statement's tables. Entries whose
+// TableRef is nil are skipped; the first statement that fails the check makes
+// the function return false.
+func checkWhereCondationUseIndex(ctx *session.Context, whereVisitor *util.WhereWithTableVisitor) bool {
+	for _, whereExpr := range whereVisitor.WhereStmts {
+		if whereExpr.WhereStmt == nil {
+			return false
+		}
+		if whereExpr.TableRef == nil {
+			continue
+		}
+		// WhereStmt holds the address of the statement's Where field, so the
+		// pointer itself is set even when that field is nil (no WHERE clause);
+		// test the pointed-to expression explicitly.
+		if *whereExpr.WhereStmt == nil {
+			return false
+		}
+
+		tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
+		isUsingIndex := false
+		util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
+			col, ok := expr.(*ast.ColumnNameExpr)
+			if !ok {
+				return false
+			}
+			tableName := col.Name.Table.L
+			columnName := col.Name.Name.L
+			if tableName == "" {
+				// No table qualifier (single-table query without an alias):
+				// look the column up in every table of the statement.
+				for _, createTableStmt := range tableNameCreateTableStmtMap {
+					if isColumnUsingIndex(columnName, createTableStmt.Constraints) {
+						isUsingIndex = true
+					}
+				}
+				return false
+			}
+			if createStmt, ok := tableNameCreateTableStmtMap[tableName]; ok && isColumnUsingIndex(columnName, createStmt.Constraints) {
+				isUsingIndex = true
+			}
+			return false
+		}, *whereExpr.WhereStmt)
+		if !isUsingIndex {
+			return false
+		}
+	}
+	return true
+}
+
 func checkExplain(input *RuleHandlerInput) error {
 	// sql from MyBatis XML file is not the executable sql. so can't do explain for it.
 	// TODO(@wy) ignore explain when audit Mybatis file
@@ -5381,6 +5445,25 @@ func checkExplain(input *RuleHandlerInput) error {
 	if err != nil {
 		// TODO: check dml related table or database is created, if not exist, explain will executed failure.
 		log.NewEntry().Errorf("get execution plan failed, sqle: %v, error: %v", input.Node.Text(), err)
+
+		// Fetching the execution plan fails for SQL parsed out of an XML file,
+		// so judge from the columns in the query conditions whether an index is used.
+		if input.Rule.Name != DMLCheckExplainUsingIndex {
+			return nil
+		}
+		// Check that the WHERE condition uses an indexed column.
+		wv := &util.WhereWithTableVisitor{}
+		input.Node.Accept(wv)
+		if !checkWhereCondationUseIndex(input.Ctx, wv) {
+			addResult(input.Res, input.Rule, input.Rule.Name)
+			return nil
+		}
+		// Check that the join columns of a multi-table query use an index.
+		isUsingIndex, err := judgeJoinFieldUseIndex(input)
+		if err == nil && !isUsingIndex {
+			addResult(input.Res, input.Rule, input.Rule.Name)
+		}
+		return nil
 	}
 
 	for _, record := range epRecords {
@@ -7486,15 +7569,23 @@ func isColumnUseLeftMostPrefix(allCols []string, constraints []*ast.Constraint)
 	return true
 }
 
+// checkJoinFieldUseIndex adds a rule result when judgeJoinFieldUseIndex
+// reports, without an error, that a join condition is not covered by an index.
+// It always returns nil.
+func checkJoinFieldUseIndex(input *RuleHandlerInput) error {
+	isUsingIndex, err := judgeJoinFieldUseIndex(input)
+	if err == nil && !isUsingIndex {
+		addResult(input.Res, input.Rule, input.Rule.Name)
+	}
+	return nil
+}
+
 /*
-checkJoinFieldUseIndex 判断Join语句中被驱动表中作为连接条件的列是否属于索引
+judgeJoinFieldUseIndex reports whether the join-condition columns of the driven table in a JOIN statement belong to an index.
 
 触发条件:
-	A. CrossJoin和RightJoin (选择驱动表的情况复杂:随着数据变化而变化,因此都判断)
+	A. CrossJoin, RightJoin and LeftJoin
 		1. 分别判断ON USING WHERE中的连接条件是否有索引
-	B. LeftJoin (选择驱动表的情况固定:Join右侧的表为被驱动表)
-		1. ON和USING,判断LeftJoin的被驱动表 (右侧的表) 的连接条件是否有索引
-		2. 判断WHERE中的连接条件是否有索引
 连接条件:等值条件两侧为不同表的列
 支持情况:
 	支持:
@@ -7503,12 +7594,11 @@ checkJoinFieldUseIndex 判断Join语句中被驱动表中作为连接条件的
 	不支持:
 		1. 子查询中JOIN多表的判断
 */
-func checkJoinFieldUseIndex(input *RuleHandlerInput) error {
-
+func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
 	joinNode := getJoinNodeFromNode(input.Node)
 	if doesNotJoinTables(joinNode) {
 		// 如果SQL没有JOIN多表,则不需要审核
-		return nil
+		return true, fmt.Errorf("sql have not join node")
 	}
 	tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
 	tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
@@ -7517,16 +7607,14 @@ func checkJoinFieldUseIndex(input *RuleHandlerInput) error {
 	}
 
 	if joinNodes, hasIndex := joinConditionInJoinNodeHasIndex(input.Ctx, joinNode, tableIndexes); joinNodes && !hasIndex {
-		addResult(input.Res, input.Rule, input.Rule.Name)
-		return nil
+		return false, nil
 	}
 
 	whereStmt := getWhereStmtFromNode(input.Node)
 	if joinNodes, hasIndex := joinConditionInWhereStmtHasIndex(input.Ctx, joinNode, whereStmt, tableIndexes); joinNodes && !hasIndex {
-		addResult(input.Res, input.Rule, input.Rule.Name)
-		return nil
+		return false, nil
 	}
-	return nil
+	return true, nil
 }
 
 func joinConditionInWhereStmtHasIndex(ctx *session.Context, joinNode *ast.Join, whereStmt ast.ExprNode, tableIndex map[string][]*ast.Constraint) (joinTables, hasIndex bool) {
@@ -7613,6 +7701,15 @@ func (m tableColumnMap) add(tableName, columnName string) {
 	m[tableName][columnName] = struct{}{}
 }
 
+// initMap only prints each table source with the builtin println; it does not
+// modify m.
+// NOTE(review): looks like leftover debugging code - confirm before merging.
+func (m tableColumnMap) initMap(tables []*ast.TableSource) {
+	for _, t := range tables {
+		println(t)
+	}
+}
+
 /*
 IsIndex
 
diff --git a/sqle/driver/mysql/util/visitor.go b/sqle/driver/mysql/util/visitor.go
index 1da043deeb..a19d9f4fcc 100644
--- a/sqle/driver/mysql/util/visitor.go
+++ b/sqle/driver/mysql/util/visitor.go
@@ -247,7 +247,7 @@ func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
 
 type WhereVisitor struct {
 	WhereList         []ast.ExprNode
-	WhetherContainNil bool           // 是否需要包含空的where,例如select * from t1 该语句的where为空
+	WhetherContainNil bool // 是否需要包含空的where,例如select * from t1 该语句的where为空
 }
 
 func (v *WhereVisitor) append(where ast.ExprNode) {
@@ -328,3 +328,49 @@ func (v *FuncCallExprVisitor) Enter(in ast.Node) (out ast.Node, skipChildren boo
 func (v *FuncCallExprVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
 	return in, true
 }
+
+// WhereWithTable pairs the WHERE expression of one statement with the table
+// references of that same statement.
+type WhereWithTable struct {
+	// WhereStmt is the address of the statement's Where field.
+	WhereStmt *ast.ExprNode
+	// TableRef is the join tree taken from the statement's table references.
+	TableRef *ast.Join
+}
+
+// WhereWithTableVisitor collects a WhereWithTable for each SELECT that has a
+// FROM clause and for each DELETE and UPDATE statement it enters. Enter never
+// asks to skip children.
+type WhereWithTableVisitor struct {
+	WhereStmts []*WhereWithTable
+}
+
+// Enter records the WHERE expression and table references of in when it is a
+// SELECT, DELETE or UPDATE statement, and returns in unchanged.
+func (v *WhereWithTableVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
+	switch stmt := in.(type) {
+	case *ast.SelectStmt:
+		// No FROM clause means no table to inspect, e.g. "select 1".
+		if stmt.From == nil {
+			return in, false
+		}
+		v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.From.TableRefs})
+	case *ast.DeleteStmt:
+		// Guard against a missing table clause before dereferencing it.
+		if stmt.TableRefs == nil {
+			return in, false
+		}
+		v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.TableRefs.TableRefs})
+	case *ast.UpdateStmt:
+		if stmt.TableRefs == nil {
+			return in, false
+		}
+		v.WhereStmts = append(v.WhereStmts, &WhereWithTable{WhereStmt: &stmt.Where, TableRef: stmt.TableRefs.TableRefs})
+	}
+	return in, false
+}
+
+// Leave returns in unchanged with ok set to true.
+func (v *WhereWithTableVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
+	return in, true
+}