Skip to content

Commit

Permalink
Merge pull request #2533 from actiontech/issue-2356-1
Browse files Browse the repository at this point in the history
xml解析不再进行format格式化;返回原始sql字符串
  • Loading branch information
sjjian authored Aug 9, 2024
2 parents e5204f7 + 8235bce commit 5b2de05
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 35 deletions.
22 changes: 7 additions & 15 deletions sqle/api/controller/v1/sql_audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ import (

javaParser "github.com/actiontech/java-sql-extractor/parser"
xmlParser "github.com/actiontech/mybatis-mapper-2-sql"
"github.com/actiontech/mybatis-mapper-2-sql/ast"
"github.com/actiontech/sqle/sqle/api/controller"
"github.com/actiontech/sqle/sqle/common"
"github.com/actiontech/sqle/sqle/dms"
"github.com/actiontech/sqle/sqle/driver"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/errors"
"github.com/actiontech/sqle/sqle/log"
"github.com/actiontech/sqle/sqle/model"
Expand Down Expand Up @@ -106,7 +104,7 @@ func CreateSQLAuditRecord(c echo.Context) error {
SQLsFromFormData: req.Sqls,
}
} else {
sqls, err = getSQLFromFile(c, req.DbType)
sqls, err = getSQLFromFile(c)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
Expand Down Expand Up @@ -324,7 +322,7 @@ func buildOfflineTaskForAudit(userId uint64, dbType string, sqls getSQLFromFileR
}

// todo 此处跳过了不支持的编码格式文件
func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
file, err := c.FormFile(InputZipFileName)
if err == http.ErrMissingFile {
return nil, nil, false, nil
Expand Down Expand Up @@ -390,7 +388,7 @@ func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQ
// parse xml content
// xml文件需要把所有文件内容同时解析,否则会无法解析跨namespace引用的SQL
{
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents, dbType)
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents)
if err != nil {
return nil, nil, false, err
}
Expand All @@ -399,14 +397,8 @@ func getSqlsFromZip(c echo.Context, dbType string) (sqlsFromSQLFile []SQLsFromSQ

return sqlsFromSQLFile, sqlsFromXML, true, nil
}
func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile, dbType string) ([]SQLFromXML, error) {
var allStmtsFromXml []ast.StmtInfo
var err error
if dbType == driverV2.DriverTypePostgreSQL || dbType == driverV2.DriverTypeTBase {
allStmtsFromXml, err = xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery, xmlParser.RestoreOriginSql)
} else {
allStmtsFromXml, err = xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery)
}
func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile) ([]SQLFromXML, error) {
allStmtsFromXml, err := xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery, xmlParser.RestoreOriginSql)
if err != nil {
return nil, fmt.Errorf("parse sqls from xml failed: %v", err)
}
Expand All @@ -423,7 +415,7 @@ func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile, dbType string) ([]SQ
}

// todo 此处跳过了不支持的编码格式文件
func getSqlsFromGit(c echo.Context, dbType string) (sqlsFromSQLFiles, sqlsFromJavaFiles []SQLsFromSQLFile, sqlsFromXMLs []SQLFromXML, exist bool, err error) {
func getSqlsFromGit(c echo.Context) (sqlsFromSQLFiles, sqlsFromJavaFiles []SQLsFromSQLFile, sqlsFromXMLs []SQLFromXML, exist bool, err error) {
// make a temp dir and clean up befor return
dir, err := os.MkdirTemp("./", "git-repo-")
if err != nil {
Expand Down Expand Up @@ -528,7 +520,7 @@ func getSqlsFromGit(c echo.Context, dbType string) (sqlsFromSQLFiles, sqlsFromJa

// parse xml content
// xml文件需要把所有文件内容同时解析,否则会无法解析跨namespace引用的SQL
sqlsFromXMLs, err = parseXMLsWithFilePath(xmlContents, dbType)
sqlsFromXMLs, err = parseXMLsWithFilePath(xmlContents)
if err != nil {
return nil, nil, nil, false, err
}
Expand Down
26 changes: 6 additions & 20 deletions sqle/api/controller/v1/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ import (

dmsV1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
mybatis_parser "github.com/actiontech/mybatis-mapper-2-sql"
"github.com/actiontech/mybatis-mapper-2-sql/ast"
"github.com/actiontech/sqle/sqle/api/controller"
"github.com/actiontech/sqle/sqle/common"
"github.com/actiontech/sqle/sqle/config"
"github.com/actiontech/sqle/sqle/dms"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/errors"
"github.com/actiontech/sqle/sqle/log"
"github.com/actiontech/sqle/sqle/model"
Expand Down Expand Up @@ -109,7 +107,7 @@ const (
ZIPFileExtension = ".zip"
)

func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
func getSQLFromFile(c echo.Context) (getSQLFromFileResp, error) {
// Read it from sql file.
fileName, sqlsFromSQLFile, exist, err := controller.ReadFile(c, InputSQLFileName)
if err != nil {
Expand All @@ -130,13 +128,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
return getSQLFromFileResp{}, err
}
if exist {
var sqls []ast.StmtInfo
var err error
if dbType == driverV2.DriverTypePostgreSQL || dbType == driverV2.DriverTypeTBase {
sqls, err = mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery, mybatis_parser.RestoreOriginSql)
} else {
sqls, err = mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery)
}
sqls, err := mybatis_parser.ParseXMLs([]mybatis_parser.XmlFile{{Content: data}}, mybatis_parser.SkipErrorQuery, mybatis_parser.RestoreOriginSql)
if err != nil {
return getSQLFromFileResp{}, errors.New(errors.ParseMyBatisXMLFileError, err)
}
Expand All @@ -155,7 +147,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
}

// If mybatis xml file is not exist, read it from zip file.
sqlsFromSQLFiles, sqlsFromXML, exist, err := getSqlsFromZip(c, dbType)
sqlsFromSQLFiles, sqlsFromXML, exist, err := getSqlsFromZip(c)
if err != nil {
return getSQLFromFileResp{}, err
}
Expand All @@ -168,7 +160,7 @@ func getSQLFromFile(c echo.Context, dbType string) (getSQLFromFileResp, error) {
}

// If zip file is not exist, read it from git repository
sqlsFromSQLFiles, sqlsFromJavaFiles, sqlsFromXMLs, exist, err := getSqlsFromGit(c, dbType)
sqlsFromSQLFiles, sqlsFromJavaFiles, sqlsFromXMLs, exist, err := getSqlsFromGit(c)
if err != nil {
return getSQLFromFileResp{}, err
}
Expand Down Expand Up @@ -314,20 +306,14 @@ func CreateAndAuditTask(c echo.Context) error {
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
instance, exist, err := dms.GetInstanceInProjectByName(c.Request().Context(), projectUid, req.InstanceName)
if !exist {
return controller.JSONBaseErrorReq(c, ErrInstanceNotExist)
} else if err != nil {
return controller.JSONBaseErrorReq(c, errors.New(errors.DataConflict, err))
}

if req.Sql != "" {
sqls = getSQLFromFileResp{
SourceType: model.TaskSQLSourceFromFormData,
SQLsFromFormData: req.Sql,
}
} else {
sqls, err = getSQLFromFile(c, instance.DbType)
sqls, err = getSQLFromFile(c)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
Expand Down Expand Up @@ -1001,7 +987,7 @@ func AuditTaskGroupV1(c echo.Context) error {
SQLsFromFormData: req.Sql,
}
} else {
sqls, err = getSQLFromFile(c, dbType)
sqls, err = getSQLFromFile(c)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
Expand Down

0 comments on commit 5b2de05

Please sign in to comment.