From e6ea655d42e8f78f5e11e4063e8e042f66fdf606 Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Fri, 24 May 2024 16:15:37 +0800 Subject: [PATCH 1/2] do not use single quotes to wrap pk data --- sqle/driver/mysql/rollback.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 7e5647dbf7..0004c94d94 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -11,6 +11,7 @@ import ( "github.com/actiontech/sqle/sqle/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" _model "github.com/pingcap/parser/model" parserMysql "github.com/pingcap/parser/mysql" ) @@ -626,7 +627,7 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin colChanged = true if isPk { isPkChanged = true - pkValue = util.ExprFormat(l.Expr) + pkValue = restore(l.Expr) } } } @@ -647,7 +648,7 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin } if isPk { if isPkChanged { - where = append(where, fmt.Sprintf("%s = '%s'", name, pkValue)) + where = append(where, fmt.Sprintf("%s = %s", name, pkValue)) } else { where = append(where, fmt.Sprintf("%s = %s", name, v)) @@ -660,6 +661,18 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin return rollbackSql, "", nil } +// 还原抽象语法树节点至SQL +func restore(node ast.Node) (sql string) { + var buf strings.Builder + rc := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf) + + if err := node.Restore(rc); err != nil { + return + } + sql = buf.String() + return +} + // getRecords select all data which will be update or delete. func (i *MysqlDriverImpl) getRecords(tableName *ast.TableName, tableAlias string, where ast.ExprNode, order *ast.OrderByClause, limit int64) ([]map[string]sql.NullString, error) { From 98d8f32ad6b777505f7d416c1f82a84b8c65a91f Mon Sep 17 00:00:00 2001 From: bianyucheng Date: Fri, 24 May 2024 17:53:49 +0800 Subject: [PATCH 2/2] not need to use restore func --- sqle/driver/mysql/rollback.go | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/sqle/driver/mysql/rollback.go b/sqle/driver/mysql/rollback.go index 0004c94d94..9d4a7def67 100644 --- a/sqle/driver/mysql/rollback.go +++ b/sqle/driver/mysql/rollback.go @@ -11,7 +11,6 @@ import ( "github.com/actiontech/sqle/sqle/errors" "github.com/pingcap/parser/ast" - "github.com/pingcap/parser/format" _model "github.com/pingcap/parser/model" parserMysql "github.com/pingcap/parser/mysql" ) @@ -627,7 +626,7 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin colChanged = true if isPk { isPkChanged = true - pkValue = restore(l.Expr) + pkValue = util.ExprFormat(l.Expr) } } } @@ -661,18 +660,6 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin return rollbackSql, "", nil } -// 还原抽象语法树节点至SQL -func restore(node ast.Node) (sql string) { - var buf strings.Builder - rc := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf) - - if err := node.Restore(rc); err != nil { - return - } - sql = buf.String() - return -} - // getRecords select all data which will be update or delete. func (i *MysqlDriverImpl) getRecords(tableName *ast.TableName, tableAlias string, where ast.ExprNode, order *ast.OrderByClause, limit int64) ([]map[string]sql.NullString, error) {