Skip to content

Commit

Permalink
Merge pull request #2181 from actiontech/sync-v2.9999
Browse files Browse the repository at this point in the history
Sync v2.9999
  • Loading branch information
sjjian authored Dec 22, 2023
2 parents e69df38 + 19e5c92 commit 4ebfeaa
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 28 deletions.
3 changes: 2 additions & 1 deletion spelling_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -456,4 +456,5 @@ multipoint
multilinestring
multipolygon
geometrycollection
charlength
charlength
xmls
32 changes: 5 additions & 27 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
if stmt.From == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs)
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
default:
return nil, nil
Expand Down Expand Up @@ -696,28 +696,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context,
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte
// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
Expand Down Expand Up @@ -3259,7 +3237,7 @@ func checkWhereConditionUseIndex(ctx *session.Context, whereVisitor *util.WhereW
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
Expand Down Expand Up @@ -5465,7 +5443,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
// 如果SQL没有JOIN多表,则不需要审核
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
for tableName, createTableStmt := range tableNameCreateTableStmtMap {
tableIndexes[tableName] = createTableStmt.Constraints
Expand Down
22 changes: 22 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor {
func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) {
return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName))
}

func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

0 comments on commit 4ebfeaa

Please sign in to comment.