diff --git a/ast/mapper.go b/ast/mapper.go index ed1446e..bb7e705 100644 --- a/ast/mapper.go +++ b/ast/mapper.go @@ -66,3 +66,20 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) { } return strings.TrimSuffix(buff.String(), "\n"), nil } + +func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) { + var stmts []string + ctx.Sqls = m.SqlNodes + for _, a := range m.QueryNodes { + data, err := a.GetStmt(ctx) + if err == nil { + stmts = append(stmts, data) + continue + } + if skipErrorQuery { + continue + } + return nil, err + } + return stmts, nil +} diff --git a/parser.go b/parser.go index 5a0c7b5..a9dc29a 100644 --- a/parser.go +++ b/parser.go @@ -2,12 +2,14 @@ package parser import ( "encoding/xml" + "fmt" "io" "strings" "github.com/actiontech/mybatis-mapper-2-sql/ast" ) +// ParseXML is a parser for parse all query in XML to string. func ParseXML(data string) (string, error) { r := strings.NewReader(data) d := xml.NewDecoder(r) @@ -25,6 +27,29 @@ func ParseXML(data string) (string, error) { return stmt, nil } +// ParseXMLQuery is a parser for parse all query in XML to []string one by one; +// you can set `skipErrorQuery` true to ignore invalid query. +func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) { + r := strings.NewReader(data) + d := xml.NewDecoder(r) + n, err := parse(d, nil) + if err != nil { + return nil, err + } + if n == nil { + return nil, nil + } + m, ok := n.(*ast.Mapper) + if !ok { + return nil, fmt.Errorf("the mapper is not found") + } + stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery) + if err != nil { + return nil, err + } + return stmts, nil +} + func parse(d *xml.Decoder, start *xml.StartElement) (node ast.Node, err error) { if start != nil { node, err = scan(start) diff --git a/parser_test.go b/parser_test.go index 3427155..08164ec 100644 --- a/parser_test.go +++ b/parser_test.go @@ -572,3 +572,263 @@ func TestParserSQLRefIdNotFound(t *testing.T) { t.Errorf("actual error is [%s]", err.Error()) } } + +func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) { + actual, err := ParseXMLQuery(xmlData, skipError) + if err != nil { + t.Errorf("parse error: %v", err) + return + } + if len(actual) != len(expect) { + t.Errorf("the length of actual is not the same as the length of expected, actual length is %d, expect is %d", + len(actual), len(expect)) + return + } + for i := range actual { + if actual[i] != expect[i] { + t.Errorf("\nexpect[%d]: [%s]\nactual[%d]: [%s]", i, expect, i, actual) + } + } + +} + +func TestParserQueryFullFile(t *testing.T) { + testParserQuery(t, false, + ` + + + + + fruits + + + WHERE + category = #{category} + + + FROM + + + + + + + + + + UPDATE + fruits + + + category = #{category}, + + + price = ${price}, + + + WHERE + name = #{name} + + + + + INSERT INTO + fruits + ( + name, + category, + price + ) + VALUES + + ( + #{fruit.name}, + #{fruit.category}, + ${fruit.price} + ) + + + +`, + []string{ + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=? AND `price`>?", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=?", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE 1=1 AND `category`=? AND `price`=? AND `name`=\"Fuji\"", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" OR `price`=200", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND `price`=?", + "UPDATE `fruits` SET `category`=?, `price`=? WHERE `name`=?", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name`=? AND `category`=? AND `price`=? AND `category`=\"apple\"", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND (`name`=? OR `name`=?)", + "INSERT INTO `fruits` (`name`,`category`,`price`) VALUES (?,?,?),(?,?,?)", + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?", + }) +} + +func TestParserQueryHasInvalidQuery(t *testing.T) { + _, err := ParseXMLQuery( + ` + + + * + + + +`, false) + if err == nil { + t.Errorf("expect has error, but no error") + } + if err.Error() != "sql someinclude2 is not exist" { + t.Errorf("actual error is [%s]", err.Error()) + } +} + +func TestParserQueryHasInvalidQueryButSkip(t *testing.T) { + testParserQuery(t, true, + ` + + + * + + + +`, []string{ + "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?", + }) +}