Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: method of optimizer #1964

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 100 additions & 62 deletions sqle/driver/mysql/optimizer/index/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Optimizer struct {
l *logrus.Entry

// tables key is table name, use to match in execution plan.
tables map[string]*tableInSelect
tables map[string] /*table name*/ *tableInSelect

// optimizer options:
calculateCardinalityMaxRow int
Expand Down Expand Up @@ -101,7 +101,7 @@ type OptimizeResult struct {
// 1. when we find a table in single table select statement, we will store the select statement.
// 2. when we find a table in join statement, we will store the join on condition.
type tableInSelect struct {
// Pre-change declaration (removed side of this diff hunk): a single join column name.
joinOnColumn string
// joinOnColumn is used as a set: each key is a column name that appears in a
// join condition (ON or USING) for this table, and the value is always true.
joinOnColumn map[string]bool
// singleTableSel is the SELECT statement in which the table was found as the
// only table of the FROM clause; it stays nil for tables only seen in joins.
singleTableSel *ast.SelectStmt
}

Expand All @@ -112,7 +112,7 @@ func (o *Optimizer) Optimize(ctx context.Context, selectStmt *ast.SelectStmt) ([
return nil, nil
}

o.parseSelectStmt(selectStmt)
o.parseTableFromSelectStmt(selectStmt)

restoredSQL, err := restoreSelectStmt(selectStmt)
if err != nil {
Expand Down Expand Up @@ -142,21 +142,21 @@ func (o *Optimizer) Optimize(ctx context.Context, selectStmt *ast.SelectStmt) ([
o.l.Infof("need optimize tables: %v", needOptimizedTables)

var results []*OptimizeResult
for _, tbl := range needOptimizedTables {
table, ok := o.tables[tbl]
for _, tableName := range needOptimizedTables {
table, ok := o.tables[tableName]
if !ok {
// given SQL: select * from t1 join t2, there is no join on condition,
continue
}

var result *OptimizeResult
if table.joinOnColumn == "" {
result, err = o.optimizeSingleTable(ctx, tbl, table.singleTableSel)
if len(table.joinOnColumn) > 0 {
result = o.optimizeJoinTable(tableName)
} else {
result, err = o.optimizeSingleTable(ctx, tableName, table.singleTableSel)
if err != nil {
return nil, errors.Wrapf(err, "optimize single table %s", tbl)
return nil, errors.Wrapf(err, "optimize single table %s", tableName)
}
} else {
result = o.optimizeJoinTable(tbl)
}
if result != nil {
results = append(results, result)
Expand All @@ -166,64 +166,97 @@ func (o *Optimizer) Optimize(ctx context.Context, selectStmt *ast.SelectStmt) ([
return results, nil
}

// A SelectStmt takes one of these shapes:
// 1. a single select on a single table
// 2. a single select on multiple tables, such as a join
// 3. multiple selects on multiple tables, such as subqueries
func (o *Optimizer) parseSelectStmt(ss *ast.SelectStmt) {
visitor := util.SelectStmtExtractor{}
ss.Accept(&visitor)

for _, ss := range visitor.SelectStmts {
if ss.From == nil {
continue
// parseTableFromSelectNode inspects the FROM clause of one SELECT statement.
//
// When the FROM clause does not join two tables and its left side is a table
// source, stmt is stored as singleTableSel under the table's real name and
// under its alias (each only when non-empty). In every case except a missing
// FROM clause or a missing left side, the join tree is then handed to
// parseTableFromJoinNode so that join conditions are collected.
func (o *Optimizer) parseTableFromSelectNode(stmt *ast.SelectStmt) {
	if stmt.From == nil {
		return
	}

	refs := stmt.From.TableRefs
	if util.DoesNotJoinTables(refs) {
		if refs.Left == nil {
			return
		}
		if source, isSource := refs.Left.(*ast.TableSource); isSource {
			// Candidate keys, real table name first and alias second.
			keys := make([]string, 0, 2)
			if named, isTable := source.Source.(*ast.TableName); isTable {
				keys = append(keys, named.Name.O)
			}
			keys = append(keys, source.AsName.O)

			for _, key := range keys {
				if key == "" {
					continue
				}
				o.getTableInSelect(key).singleTableSel = stmt
			}
		}
	}

	o.parseTableFromJoinNode(refs)
}

left := ss.From.TableRefs.Left
right := ss.From.TableRefs.Right
// parseTableFromJoinNode walks a join tree and records, for every table that
// takes part in an ON or USING join condition, the joined column names in that
// table's tableInSelect.joinOnColumn set.
// NOTE(review): this span comes from a diff view, so lines of the replaced
// parseSelectStmt body are interleaved with the new body below.
func (o *Optimizer) parseTableFromJoinNode(joinNode *ast.Join) {
// Visit the left subtree first when it is itself an *ast.Join, so nested joins are handled depth-first.
if leftNode, ok := joinNode.Left.(*ast.Join); ok {
o.parseTableFromJoinNode(leftNode)
}

if right == nil { // means single table select
leftTable, ok := left.(*ast.TableSource)
if !ok {
// ON clause: collect every column name expression found in the condition.
if util.IsJoinConditionInOnClause(joinNode) {
columnNames := util.GetJoinedColumnNameExprInOnClause(joinNode)
for _, columnName := range columnNames {
if columnName.Name.Table.O == "" {
/*
not supported: columns without a table qualifier, such as
ON (column_1 = column_2)
because the table each column belongs to would have to be resolved first
*/
continue
} else {
/*
supported: table-qualified columns, such as
ON table_1.column_1=table_2.column_1
ON t1.column_1 = COALESCE(a.c1, b.c1)
ON (t1.column_1,t1.column_2 = t2.column_1,t2.column_2)
ON table_1.column_1=table_2.column_1 AND table_1.column_2=table_2.column_2
*/
tableInSelect := o.getTableInSelect(columnName.Name.Table.O)
tableInSelect.joinOnColumn[columnName.Name.Name.O] = true
}
}

if leftTable.AsName.L != "" {
o.tables[leftTable.AsName.O] = &tableInSelect{singleTableSel: ss}
}
// may appear: select * from (select v1,v2 from t1 where v1 = 2) as t1
if source, ok := leftTable.Source.(*ast.TableName); ok {
o.tables[source.Name.O] = &tableInSelect{singleTableSel: ss}
}
} else {
if ss.From.TableRefs.On != nil {
boe, ok := ss.From.TableRefs.On.Expr.(*ast.BinaryOperationExpr)
if !ok {
continue
}

leftCNE, ok := boe.L.(*ast.ColumnNameExpr)
if !ok {
continue
}
rightCNE, ok := boe.R.(*ast.ColumnNameExpr)
if !ok {
continue
}
o.tables[leftCNE.Name.Table.O] = &tableInSelect{joinOnColumn: leftCNE.Name.Name.L}
o.tables[rightCNE.Name.Table.O] = &tableInSelect{joinOnColumn: rightCNE.Name.Name.L}

} else if ss.From.TableRefs.Using != nil {
// FIXME: panics here for SQL such as
// SELECT * FROM table_1 JOIN table_2 on table_1.id = table_2.id JOIN table_3 USING (column_name);
// i.e. when USING comes last and belongs to the final join of a multi-table JOIN.
leftTableName := left.(*ast.TableSource).Source.(*ast.TableName).Name.O
rightTableName := right.(*ast.TableSource).Source.(*ast.TableName).Name.O
for _, col := range ss.From.TableRefs.Using {
o.tables[leftTableName] = &tableInSelect{joinOnColumn: col.Name.L}
o.tables[rightTableName] = &tableInSelect{joinOnColumn: col.Name.L}
}
}
}
// USING clause: each listed column is recorded for both joined tables, keyed
// by the underlying table names. When either side is not a plain table name,
// GetJoinedTableName returns nil and nothing is recorded.
if util.IsJoinConditionInUsingClause(joinNode) {
left, right := util.GetJoinedTableName(joinNode)
if left == nil || right == nil {
return
}
leftInSelect := o.getTableInSelect(left.Name.O)
rightInSelect := o.getTableInSelect(right.Name.O)
for _, columnInUsing := range joinNode.Using {
leftInSelect.joinOnColumn[columnInUsing.Name.O] = true
rightInSelect.joinOnColumn[columnInUsing.Name.O] = true
}
}

}

// getTableInSelect returns the entry stored in o.tables for tableName. When no
// entry exists yet (or the stored pointer is nil), a fresh entry with an empty
// joinOnColumn set is created, stored under tableName and returned.
func (o *Optimizer) getTableInSelect(tableName string) *tableInSelect {
	if existing := o.tables[tableName]; existing != nil {
		return existing
	}

	created := &tableInSelect{joinOnColumn: make(map[string]bool)}
	o.tables[tableName] = created
	return created
}

// parseTableFromSelectStmt collects every SELECT statement reachable from
// selectStmt with util.SelectStmtExtractor and passes each one to
// parseTableFromSelectNode, which fills o.tables:
//  1. a FROM clause that does not join two tables caches the statement in
//     tableInSelect.singleTableSel and leaves joinOnColumn untouched;
//  2. a join of two tables with an ON or USING condition only fills
//     tableInSelect.joinOnColumn with the joined columns.
func (o *Optimizer) parseTableFromSelectStmt(selectStmt *ast.SelectStmt) {
	extractor := &util.SelectStmtExtractor{}
	selectStmt.Accept(extractor)

	for i := range extractor.SelectStmts {
		o.parseTableFromSelectNode(extractor.SelectStmts[i])
	}
}

func (o *Optimizer) optimizeSingleTable(ctx context.Context, tbl string, ss *ast.SelectStmt) (*OptimizeResult, error) {
Expand Down Expand Up @@ -286,10 +319,15 @@ func (o *Optimizer) optimizeSingleTable(ctx context.Context, tbl string, ss *ast
}

// optimizeJoinTable builds the index suggestion for a table that takes part in
// a join: every column recorded in the table's joinOnColumn set is returned in
// IndexedColumns. An unknown tbl gets an empty entry created by
// getTableInSelect, which yields an empty column list.
// NOTE(review): the columns are gathered by ranging over a map, so their order
// is not stable between calls. The Reason string formats the slice with %s,
// which prints it as "[c1 c2]"; the review thread embedded below suggests
// joining the names explicitly instead.
func (o *Optimizer) optimizeJoinTable(tbl string) *OptimizeResult {
table := o.getTableInSelect(tbl)
// Flatten the column set into a slice sized to the number of joined columns.
indexColumns := make([]string, 0, len(table.joinOnColumn))
for columnName := range table.joinOnColumn {
indexColumns = append(indexColumns, columnName)
}
return &OptimizeResult{
TableName: tbl,
IndexedColumns: []string{o.tables[tbl].joinOnColumn},
Reason: fmt.Sprintf("字段 %s 为被驱动表 %s 上的关联字段", o.tables[tbl].joinOnColumn, tbl),
IndexedColumns: indexColumns,
Reason: fmt.Sprintf("字段 %s 为被驱动表 %s 上的关联字段", indexColumns, tbl),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用%s format []string 最终打印的格式是什么样的

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于字符串切片indexColumns:=[]string{column_1,column_2,column_3}
使用%s format []string 最终打印的格式是:[column_1 column_2 column_3]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用string 的 join 手工格式一下

}
}

Expand Down
24 changes: 13 additions & 11 deletions sqle/driver/mysql/optimizer/index/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func TestOptimizer_Optimize(t *testing.T) {
{"select count(distinct `v2`)", [][]string{cardinalityHead, {"1000"}}},
},
nil,
[]*OptimizeResult{{"EXIST_TB_5", []string{"v2","v1"}, ""}},
[]*OptimizeResult{{"EXIST_TB_5", []string{"v2", "v1"}, ""}},
},
{
"select v2 from EXIST_TB_5 as tb5 where v1 = '1'",
Expand Down Expand Up @@ -362,9 +362,11 @@ func TestOptimizer_Optimize(t *testing.T) {
optimizeResults, err := o.Optimize(context.TODO(), ss.(*ast.SelectStmt))
assert.NoError(t, err)
assert.Equal(t, len(tt.output), len(optimizeResults))
for i, want := range tt.output {
assert.Equal(t, want.TableName, optimizeResults[i].TableName)
assert.Equal(t, want.IndexedColumns, optimizeResults[i].IndexedColumns)
if len(tt.output) == len(optimizeResults) {
for i, want := range tt.output {
assert.Equal(t, want.TableName, optimizeResults[i].TableName)
assert.Equal(t, want.IndexedColumns, optimizeResults[i].IndexedColumns)
}
}
mocker.MatchExpectationsInOrder(true)
assert.NoError(t, mocker.ExpectationsWereMet())
Expand All @@ -376,18 +378,18 @@ func TestOptimizer_parseSelectStmt(t *testing.T) {
t.Parallel()
tests := []struct {
input string
sel map[string] /*table name*/ string /*select SQL*/
join map[string] /*table name*/ string /*join on column*/
sel map[string] /*table name*/ string /*select SQL*/
join map[string] /*table name*/ map[string]bool /*join on column*/
}{
// single select(single table)
{"select 1", nil, nil},
{"select * from t1", map[string]string{"t1": "SELECT * FROM t1"}, nil},
{"select * from t1 as t2", map[string]string{"t2": "SELECT * FROM t1 AS t2", "t1": "SELECT * FROM t1 AS t2"}, nil},
// single select(multi table/join)
{"select * from t1 join t2 on t1.id = t2.id", nil, map[string]string{"t1": "id", "t2": "id"}},
{"select * from t1 left join t2 on t1.id = t2.id", nil, map[string]string{"t1": "id", "t2": "id"}},
{"select * from t1 right join t2 on t1.id = t2.id", nil, map[string]string{"t1": "id", "t2": "id"}},
{"select * from t1 as t1_alias join t2 as t2_alias on t1_alias.id = t2_alias.id", nil, map[string]string{"t1_alias": "id", "t2_alias": "id"}},
{"select * from t1 join t2 on t1.id = t2.id", nil, map[string]map[string]bool{"t1": {"id": true}, "t2": {"id": true}}},
{"select * from t1 left join t2 on t1.id = t2.id", nil, map[string]map[string]bool{"t1": {"id": true}, "t2": {"id": true}}},
{"select * from t1 right join t2 on t1.id = t2.id", nil, map[string]map[string]bool{"t1": {"id": true}, "t2": {"id": true}}},
{"select * from t1 as t1_alias join t2 as t2_alias on t1_alias.id = t2_alias.id", nil, map[string]map[string]bool{"t1_alias": {"id": true}, "t2_alias": {"id": true}}},
// multi select
{"select * from (select * from t1) as t2", map[string]string{"t2": "SELECT * FROM (SELECT * FROM (t1)) AS t2", "t1": "SELECT * FROM t1"}, nil},
{"select * from t1 where id = (select * from t2)", map[string]string{"t1": "SELECT * FROM t1 WHERE id=(SELECT * FROM t2)", "t2": "SELECT * FROM t2"}, nil},
Expand All @@ -398,7 +400,7 @@ func TestOptimizer_parseSelectStmt(t *testing.T) {
assert.NoError(t, err)

o := Optimizer{tables: map[string]*tableInSelect{}}
o.parseSelectStmt(stmt.(*ast.SelectStmt))
o.parseTableFromSelectStmt(stmt.(*ast.SelectStmt))
for n, tbl := range o.tables {
if tbl.singleTableSel == nil {
c, ok := tt.join[n]
Expand Down
18 changes: 3 additions & 15 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2621,15 +2621,11 @@ func checkHasJoinCondition(input *RuleHandlerInput) error {
return nil
}

// doesNotJoinTables reports whether tableRefs lacks a left or a right side,
// i.e. the node does not join two tables.
func doesNotJoinTables(tableRefs *ast.Join) bool {
	hasBothSides := tableRefs.Left != nil && tableRefs.Right != nil
	return !hasBothSides
}

func checkJoinConditionInJoinNode(ctx *session.Context, whereStmt ast.ExprNode, joinNode *ast.Join) (joinTables, hasCondition bool) {
if joinNode == nil {
return false, false
}
if doesNotJoinTables(joinNode) {
if util.DoesNotJoinTables(joinNode) {
// 非JOIN两表的JOIN节点 一般是叶子节点 不检查
return false, false
}
Expand All @@ -2643,10 +2639,10 @@ func checkJoinConditionInJoinNode(ctx *session.Context, whereStmt ast.ExprNode,
}

// 判断该节点是否有显式声明连接条件
if isJoinConditionInOnClause(joinNode) {
if util.IsJoinConditionInOnClause(joinNode) {
return true, true
}
if isJoinConditionInUsingClause(joinNode) {
if util.IsJoinConditionInUsingClause(joinNode) {
return true, true
}
if isJoinConditionInWhereStmt(ctx, whereStmt, joinNode) {
Expand All @@ -2655,14 +2651,6 @@ func checkJoinConditionInJoinNode(ctx *session.Context, whereStmt ast.ExprNode,
return true, false
}

// isJoinConditionInOnClause reports whether joinNode carries an ON condition.
func isJoinConditionInOnClause(joinNode *ast.Join) bool {
	onCondition := joinNode.On
	return onCondition != nil
}

// isJoinConditionInUsingClause reports whether joinNode lists at least one
// column in its USING clause.
func isJoinConditionInUsingClause(joinNode *ast.Join) bool {
	usingColumns := joinNode.Using
	return len(usingColumns) != 0
}

func isJoinConditionInWhereStmt(ctx *session.Context, stmt ast.ExprNode, node *ast.Join) bool {
if stmt == nil {
return false
Expand Down
41 changes: 41 additions & 0 deletions sqle/driver/mysql/util/parser_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,44 @@ func ConvertAliasToTable(alias string, tables []*ast.TableSource) (*ast.TableNam
}
return nil, errors.New("can not find table")
}

// DoesNotJoinTables reports whether tableRefs lacks a left or a right side,
// i.e. the node does not join two tables.
func DoesNotJoinTables(tableRefs *ast.Join) bool {
	hasBothSides := tableRefs.Left != nil && tableRefs.Right != nil
	return !hasBothSides
}

// IsJoinConditionInOnClause reports whether joinNode carries an ON condition.
func IsJoinConditionInOnClause(joinNode *ast.Join) bool {
	onCondition := joinNode.On
	return onCondition != nil
}

// IsJoinConditionInUsingClause reports whether joinNode lists at least one
// column in its USING clause.
func IsJoinConditionInUsingClause(joinNode *ast.Join) bool {
	usingColumns := joinNode.Using
	return len(usingColumns) != 0
}

// GetJoinedTableName returns the table names on the left and right side of
// joinNode, in that order. Both results are nil unless each side is a table
// source whose source is a plain table name; aliases are not consulted.
func GetJoinedTableName(joinNode *ast.Join) (*ast.TableName, *ast.TableName) {
	// tableNameOf yields the table name behind a join operand, or nil when the
	// operand is not a table source wrapping a table name.
	tableNameOf := func(operand ast.ResultSetNode) *ast.TableName {
		source, isSource := operand.(*ast.TableSource)
		if !isSource {
			return nil
		}
		name, _ := source.Source.(*ast.TableName)
		return name
	}

	// The right operand is resolved before the left one, as before.
	right := tableNameOf(joinNode.Right)
	left := tableNameOf(joinNode.Left)
	if left != nil && right != nil {
		return left, right
	}
	return nil, nil
}

// GetJoinedColumnNameExprInOnClause returns the column name expressions that
// ColumnNameVisitor collects from the ON condition of joinNode, or nil when
// joinNode has no ON condition.
//
// It is meant to cover ON clauses such as:
//
//	ON (column_1, column_2)
//	ON table_1.column_1 = COALESCE(table_1.column_1, table_2.column_1)
//	ON (table_1.column_1,table_1.column_2 = table_2.column_1,table_2.column_2)
//	ON table_1.column_1=table_2.column_2
//	ON table_1.column_1=table_2.column_2 AND table_2.column_1=table_3.column_2
func GetJoinedColumnNameExprInOnClause(joinNode *ast.Join) []*ast.ColumnNameExpr {
	if IsJoinConditionInOnClause(joinNode) {
		collector := &ColumnNameVisitor{}
		joinNode.On.Accept(collector)
		return collector.ColumnNameList
	}
	return nil
}
Loading