diff --git a/ast/mapper.go b/ast/mapper.go
index ed1446e..bb7e705 100644
--- a/ast/mapper.go
+++ b/ast/mapper.go
@@ -66,3 +66,20 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) {
}
return strings.TrimSuffix(buff.String(), "\n"), nil
}
+
+func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
+ var stmts []string
+ ctx.Sqls = m.SqlNodes
+ for _, a := range m.QueryNodes {
+ data, err := a.GetStmt(ctx)
+ if err == nil {
+ stmts = append(stmts, data)
+ continue
+ }
+ if skipErrorQuery {
+ continue
+ }
+ return nil, err
+ }
+ return stmts, nil
+}
diff --git a/parser.go b/parser.go
index 5a0c7b5..a9dc29a 100644
--- a/parser.go
+++ b/parser.go
@@ -2,12 +2,14 @@ package parser
import (
"encoding/xml"
+ "fmt"
"io"
"strings"
"github.com/actiontech/mybatis-mapper-2-sql/ast"
)
+// ParseXML is a parser for parse all query in XML to string.
func ParseXML(data string) (string, error) {
r := strings.NewReader(data)
d := xml.NewDecoder(r)
@@ -25,6 +27,29 @@ func ParseXML(data string) (string, error) {
return stmt, nil
}
+// ParseXMLQuery is a parser for parse all query in XML to []string one by one;
+// you can set `skipErrorQuery` true to ignore invalid query.
+func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
+ r := strings.NewReader(data)
+ d := xml.NewDecoder(r)
+ n, err := parse(d, nil)
+ if err != nil {
+ return nil, err
+ }
+ if n == nil {
+ return nil, nil
+ }
+ m, ok := n.(*ast.Mapper)
+ if !ok {
+ return nil, fmt.Errorf("the mapper is not found")
+ }
+ stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery)
+ if err != nil {
+ return nil, err
+ }
+ return stmts, nil
+}
+
func parse(d *xml.Decoder, start *xml.StartElement) (node ast.Node, err error) {
if start != nil {
node, err = scan(start)
diff --git a/parser_test.go b/parser_test.go
index 3427155..08164ec 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -572,3 +572,263 @@ func TestParserSQLRefIdNotFound(t *testing.T) {
t.Errorf("actual error is [%s]", err.Error())
}
}
+
+func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) {
+ actual, err := ParseXMLQuery(xmlData, skipError)
+ if err != nil {
+ t.Errorf("parse error: %v", err)
+ return
+ }
+ if len(actual) != len(expect) {
+ t.Errorf("the length of actual is not the same as the length of expected, actual length is %d, expect is %d",
+ len(actual), len(expect))
+ return
+ }
+ for i := range actual {
+ if actual[i] != expect[i] {
+ t.Errorf("\nexpect[%d]: [%s]\nactual[%d]: [%s]", i, expect, i, actual)
+ }
+ }
+
+}
+
+func TestParserQueryFullFile(t *testing.T) {
+ testParserQuery(t, false,
+ `
+
+
+
+
+ fruits
+
+
+ WHERE
+ category = #{category}
+
+
+ FROM
+
+
+
+
+
+
+
+
+
+ UPDATE
+ fruits
+
+
+ category = #{category},
+
+
+ price = ${price},
+
+
+ WHERE
+ name = #{name}
+
+
+
+
+ INSERT INTO
+ fruits
+ (
+ name,
+ category,
+ price
+ )
+ VALUES
+
+ (
+ #{fruit.name},
+ #{fruit.category},
+ ${fruit.price}
+ )
+
+
+
+`,
+ []string{
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=? AND `price`>?",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=?",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE 1=1 AND `category`=? AND `price`=? AND `name`=\"Fuji\"",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" OR `price`=200",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND `price`=?",
+ "UPDATE `fruits` SET `category`=?, `price`=? WHERE `name`=?",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name`=? AND `category`=? AND `price`=? AND `category`=\"apple\"",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND (`name`=? OR `name`=?)",
+ "INSERT INTO `fruits` (`name`,`category`,`price`) VALUES (?,?,?),(?,?,?)",
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
+ })
+}
+
+func TestParserQueryHasInvalidQuery(t *testing.T) {
+ _, err := ParseXMLQuery(
+ `
+
+
+ *
+
+
+
+`, false)
+ if err == nil {
+ t.Errorf("expect has error, but no error")
+ }
+ if err.Error() != "sql someinclude2 is not exist" {
+ t.Errorf("actual error is [%s]", err.Error())
+ }
+}
+
+func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
+ testParserQuery(t, true,
+ `
+
+
+ *
+
+
+
+`, []string{
+ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
+ })
+}