Skip to content

Commit

Permalink
delete && insert generate rollback sql func add binary data parse
Browse files Browse the repository at this point in the history
  • Loading branch information
hasa1K committed May 6, 2024
1 parent 8b25c87 commit 91bc0ca
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions sqle/driver/mysql/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysql

import (
"database/sql"
"encoding/hex"
"fmt"
"strconv"
"strings"
Expand All @@ -11,6 +12,7 @@ import (

"github.com/pingcap/parser/ast"
_model "github.com/pingcap/parser/model"
parserMysql "github.com/pingcap/parser/mysql"
)

func (i *MysqlDriverImpl) GenerateRollbackSql(node ast.Node) (string, string, error) {
Expand Down Expand Up @@ -449,6 +451,12 @@ func (i *MysqlDriverImpl) generateInsertRollbackSql(stmt *ast.InsertStmt) (strin
return rollbackSql, "", nil
}

// 将二进制字段转化为十六进制字段
func getHexStrFromBytesStr(byteStr string) string {
encode := []byte(byteStr)
return hex.EncodeToString(encode)
}

// generateDeleteRollbackSql generate insert SQL for delete.
func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (string, string, error) {
// not support multi-table syntax
Expand Down Expand Up @@ -497,8 +505,10 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin
values := []string{}

columnsName := []string{}
colNameDefMap := make(map[string]*ast.ColumnDef)
for _, col := range createTableStmt.Cols {
columnsName = append(columnsName, col.Name.Name.String())
colNameDefMap[col.Name.Name.String()] = col
}
for _, record := range records {
if len(record) != len(columnsName) {
Expand All @@ -508,7 +518,13 @@ func (i *MysqlDriverImpl) generateDeleteRollbackSql(stmt *ast.DeleteStmt) (strin
for _, name := range columnsName {
v := "NULL"
if record[name].Valid {
v = fmt.Sprintf("'%s'", record[name].String)
colDef := colNameDefMap[name]
if parserMysql.HasBinaryFlag(colDef.Tp.Flag) {
hexStr := getHexStrFromBytesStr(record[name].String)
v = fmt.Sprintf("X'%s'", hexStr)
} else {
v = fmt.Sprintf("'%s'", record[name].String)
}
}
vs = append(vs, v)
}
Expand Down Expand Up @@ -583,8 +599,10 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
}
columnsName := []string{}
rollbackSql := ""
colNameDefMap := make(map[string]*ast.ColumnDef)
for _, col := range createTableStmt.Cols {
columnsName = append(columnsName, col.Name.Name.String())
colNameDefMap[col.Name.Name.String()] = col
}
for _, record := range records {
if len(record) != len(columnsName) {
Expand All @@ -610,7 +628,13 @@ func (i *MysqlDriverImpl) generateUpdateRollbackSql(stmt *ast.UpdateStmt) (strin
name := col.Name.Name.O
v := "NULL"
if record[name].Valid {
v = fmt.Sprintf("'%s'", record[name].String)
colDef := colNameDefMap[name]
if parserMysql.HasBinaryFlag(colDef.Tp.Flag) {
hexStr := getHexStrFromBytesStr(record[name].String)
v = fmt.Sprintf("X'%s'", hexStr)
} else {
v = fmt.Sprintf("'%s'", record[name].String)
}
}

if colChanged {
Expand Down

0 comments on commit 91bc0ca

Please sign in to comment.