Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DMLCheckJoinFieldType rule not triggered #1892

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3825,6 +3825,35 @@ ALTER TABLE exist_db.exist_tb_1 ADD INDEX idx_v3(v3);

func Test_DMLCheckJoinFieldType(t *testing.T) {
rule := rulepkg.RuleHandlerMap[rulepkg.DMLCheckJoinFieldType].Rule
runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`SELECT * FROM exist_tb_1 t1
LEFT JOIN (SELECT id FROM exist_tb_2 WHERE id < 100) t2
ON t1.id = t2.id`,
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`SELECT * FROM exist_tb_1 t1
LEFT JOIN (SELECT id FROM exist_tb_2 WHERE id < 100) t2
ON CAST(t1.id AS FLOAT) = t2.id`,
newTestResult().addResult(rulepkg.DMLCheckJoinFieldType))

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`SELECT * FROM exist_tb_1 t1
LEFT JOIN (SELECT id FROM exist_tb_2 WHERE id < 100) t2
ON CAST(t1.id AS FLOAT) = CONVERT(t2.id, FLOAT)`,
newTestResult())

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`SELECT * FROM exist_tb_1 t1 LEFT JOIN
(SELECT id FROM exist_tb_2 t2 JOIN exist_tb_1 t1 ON t2.id = t1.id WHERE t2.id < 100 ) t3
ON CAST(t1.id AS FLOAT) = t3.id`,
newTestResult()) // 不支持子查询涉及多表作为临时表的来源,不会触发

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`SELECT * FROM exist_tb_1 t1
LEFT JOIN (SELECT id FROM exist_tb_2 WHERE id < 100) t2
ON (t1.id,t1.v1) = (t2.v2,t2.id)`,
newTestResult()) // 连接键中包含多列,不会触发

runSingleRuleInspectCase(rule, t, "", DefaultMysqlInspect(),
`select * from exist_tb_1 t1 left join exist_tb_2 t2 on t1.id = t2.id left join exist_tb_3 t3
Expand Down
89 changes: 83 additions & 6 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2534,8 +2534,35 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
return tableNameCreateTableStmtMap, onConditions
}

func getCreateTableAndOnConditionForJoinType(input *RuleHandlerInput) (map[string]*ast.CreateTableStmt, []*ast.OnCondition) {
var ctx *session.Context = input.Ctx
var joinStmt *ast.Join
switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if stmt.From == nil {
return nil, nil
}
joinStmt = stmt.From.TableRefs
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
joinStmt = stmt.TableRefs.TableRefs
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
joinStmt = stmt.TableRefs.TableRefs
default:
return nil, nil
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMapForJoinType(ctx, joinStmt)
onConditions := util.GetTableFromOnCondition(joinStmt)
return tableNameCreateTableStmtMap, onConditions
}

func checkJoinFieldType(input *RuleHandlerInput) error {
tableNameCreateTableStmtMap, onConditions := getCreateTableAndOnCondition(input)
tableNameCreateTableStmtMap, onConditions := getCreateTableAndOnConditionForJoinType(input)
if tableNameCreateTableStmtMap == nil && onConditions == nil {
return nil
}
Expand Down Expand Up @@ -2602,6 +2629,33 @@ func checkOnCondition(resultSetNode ast.ResultSetNode) (checkSuccessfully, conti
return true, true
}

func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context, joinStmt *ast.Join) map[string]*ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
tableNameExtractor := util.TableNameExtractor{TableNames: map[string]*ast.TableName{}}
tableSource.Source.Accept(&tableNameExtractor)
if len(tableNameExtractor.TableNames) > 1 {
log.Logger().Warn("规则:建议JOIN字段类型保持一致,不支持JOIN的表由多表构成")
continue
}
for tableName, tableNameStmt := range tableNameExtractor.TableNames {
createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
tableNameCreateTableStmtMap[tableName] = createTableStmt
// !临时方案:只支持别名对应的临时表只含有一个表,不支持JOIN的表由多表构成
// TODO AS语句中的别名作为表的别名时,表别名所对应的表可能是数据库的库表,也有可能是语句中构建的临时表。其中,临时表的可能性有很多种,例如:子查询的结果作为表,JOIN得到的表,其中还可能存在层层嵌套的关系。如果要获取到ON语句块中列的实际表名称,需要递归地构建别名:列名:表名(这个表名可能还是别名)的映射关系
if tableSource.AsName.String() != "" {
tableNameCreateTableStmtMap[tableSource.AsName.String()] = createTableStmt
}
// TODO: 跨库的 JOIN 无法区分
}
}
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string]*ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
Expand All @@ -2625,15 +2679,38 @@ func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *a

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte

// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
if binaryOperation, ok := onCondition.Expr.(*ast.BinaryOperationExpr); ok {
if columnName, ok := binaryOperation.L.(*ast.ColumnNameExpr); ok {
leftType = getColumnType(columnName, createTableStmtMap)
switch node := binaryOperation.L.(type) {
// 当使用类型转换时 列的类型被显式转化为对应类型 支持CAST和CONVERT函数
case *ast.FuncCastExpr:
leftType = node.Tp.Tp
default:
// 默认获取子树的所有列 对应等号一侧 一般连接键只会有一个 不支持多个列的组合
lVisitor := util.ColumeNameVisitor{}
binaryOperation.L.Accept(&lVisitor)
if len(lVisitor.ColumeNameList) > 1 {
log.Logger().Warn("规则:建议JOIN字段类型保持一致,连接键不支持多个列的组合")
}
if len(lVisitor.ColumeNameList) == 1 {
leftType = getColumnType(lVisitor.ColumeNameList[0], createTableStmtMap)
}
}

if columnName, ok := binaryOperation.R.(*ast.ColumnNameExpr); ok {
rightType = getColumnType(columnName, createTableStmtMap)
switch node := binaryOperation.R.(type) {
case *ast.FuncCastExpr:
rightType = node.Tp.Tp
default:
rVisitor := util.ColumeNameVisitor{}
binaryOperation.R.Accept(&rVisitor)
if len(rVisitor.ColumeNameList) > 1 {
log.Logger().Warn("规则:建议JOIN字段类型保持一致,连接键不支持多个列的组合")
}
if len(rVisitor.ColumeNameList) > 0 {
rightType = getColumnType(rVisitor.ColumeNameList[0], createTableStmtMap)
}
}

}

return leftType, rightType
Expand Down
16 changes: 16 additions & 0 deletions sqle/driver/mysql/util/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,19 @@ func (v *SelectVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
func (v *SelectVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

type ColumeNameVisitor struct {
ColumeNameList []*ast.ColumnNameExpr
}

func (v *ColumeNameVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch stmt := in.(type) {
case *ast.ColumnNameExpr:
v.ColumeNameList = append(v.ColumeNameList, stmt)
}
return in, false
}

func (v *ColumeNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}
32 changes: 32 additions & 0 deletions sqle/driver/mysql/util/visitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,35 @@ func TestSelectFieldExtractor(t *testing.T) {
})
}
}

func TestColumeNameVisitor(t *testing.T) {
tests := []struct {
input string
columnCount uint
}{
{"SELECT * FROM t1", 0}, //不包含列
{"SELECT a,b,c FROM t1 WHERE id > 1", 4}, //使用不等号
{"SELECT COUNT(*) FROM t1", 0}, //使用函数并不包含列
{"SELECT a,COUNT(*) FROM t1 GROUP BY a", 2}, //使用函数包含列
{"SELECT * FROM table1 INNER JOIN table2 ON table1.id = table2.table1_id", 2}, //使用JOIN
{"SELECT * FROM table1 WHERE id IN ( SELECT id FROM table2 WHERE age > 30)", 3}, //使用子查询
{"SELECT UPPER(name), LENGTH(comments) FROM table1", 2}, //使用函数
{"SELECT CAST(price AS DECIMAL(10,2))FROM products", 1}, //使用类型转换
{"SELECT * FROM table1 INNER JOIN table2 ON table1.id = table2.table1_id INNER JOIN table3 ON table2.id = table3.table2_id", 4}, //使用JOIN嵌套
{"SELECT column1 AS alias1, column2 AS alias2 FROM table1", 2}, //使用列别名
{"SELECT column1 + column2 AS sum_columns FROM table1", 2},
{"SELECT t1.column1 AS t1_col1, t2.column2 AS t2_col2 FROM table1 t1 INNER JOIN table2 t2 ON t1.id = t2.t1_id", 4}, //不带AS的表别名
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
stmt, err := parser.New().ParseOneStmt(tt.input, "", "")
assert.NoError(t, err)

visitor := &ColumeNameVisitor{}
stmt.Accept(visitor)

assert.Equal(t, tt.columnCount, uint(len(visitor.ColumeNameList)))
})
}
}
Loading