Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix count affect row #1904

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions sqle/driver/mysql/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/actiontech/sqle/sqle/driver/mysql/executor"
"github.com/actiontech/sqle/sqle/log"
"github.com/actiontech/sqle/sqle/utils"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/format"
Expand All @@ -26,7 +27,7 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe

var newNode ast.Node
var affectRowSql string
var hasGroupByOrGroupByAndHavingBoth bool
var cannotConvert bool

// 语法规则文档
// select: https://dev.mysql.com/doc/refman/8.0/en/select.html
Expand All @@ -36,9 +37,8 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
switch stmt := node.(type) {
case *ast.SelectStmt:
isGroupByAndHavingBothExist := stmt.GroupBy != nil && stmt.Having != nil
// 包含group by或者group by和having都存在的select语句
if stmt.GroupBy != nil || isGroupByAndHavingBothExist {
hasGroupByOrGroupByAndHavingBoth = true
if stmt.GroupBy != nil || isGroupByAndHavingBothExist || stmt.Limit != nil {
cannotConvert = true
}

newNode = getSelectNodeFromSelect(stmt)
Expand All @@ -62,9 +62,10 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
return 0, ErrUnsupportedSqlType
}

// 存在group by或者group by和having都存在的select语句,无法转换为select count语句
// 使用子查询 select count(*) from (输入的sql) as t的方式来获取影响行数
if hasGroupByOrGroupByAndHavingBoth {
// 1. 存在group by或者group by和having都存在的select语句,无法转换为select count语句
// 2. SELECT COUNT(1) FROM test LIMIT 10,10 类型的SQL结果集为空
// 已上两种情况,使用子查询 select count(*) from (输入的sql) as t的方式来获取影响行数
if cannotConvert {
// 移除后缀分号,避免sql语法错误
trimSuffix := strings.TrimRight(originSql, ";")
affectRowSql = fmt.Sprintf("select count(*) from (%s) as t", trimSuffix)
Expand All @@ -82,16 +83,23 @@ func GetAffectedRowNum(ctx context.Context, originSql string, conn *executor.Exe
// 避免在客户机器上执行不符合预期的sql语句
err = checkSql(affectRowSql)
if err != nil {
return 0, err
return 0, fmt.Errorf("check sql(%v) failed, origin sql(%v), err: %v", affectRowSql, originSql, err)
}

_, row, err := conn.Db.QueryWithContext(ctx, affectRowSql)
if err != nil {
return 0, err
}

// 如果下发的 SELECT COUNT(1) 的SQL,返回的结果集为空, 则返回0
// 例: SELECT COUNT(1) FROM test LIMIT 10,10 结果集为空
if len(row) == 0 {
log.NewEntry().Errorf("affected row sql(%v) result row count is 0", affectRowSql)
return 0, nil
}

if len(row) != 1 {
return 0, errors.New("affectRowSql error")
return 0, fmt.Errorf("affected row sql(%v) result row count(%v) is not 1", affectRowSql, len(row))
}

affectCount, err := strconv.ParseInt(row[0][0].String, 10, 64)
Expand Down
Loading