diff --git a/ast/mapper.go b/ast/mapper.go index 9343eb3..7588f30 100644 --- a/ast/mapper.go +++ b/ast/mapper.go @@ -13,6 +13,7 @@ type Mapper struct { SqlNodes map[string]*SqlNode QueryNodeIndex map[string]*QueryNode QueryNodes []*QueryNode + FilePath string } func NewMapper() *Mapper { diff --git a/ast/mappers.go b/ast/mappers.go index b808659..894d264 100644 --- a/ast/mappers.go +++ b/ast/mappers.go @@ -23,9 +23,14 @@ func (s *Mappers) AddMapper(ms ...*Mapper) error { return nil } -func (s *Mappers) GetStmts(skipErrorQuery bool) ([]string, error) { +type StmtInfo struct { + FilePath string + SQL string +} + +func (s *Mappers) GetStmts(skipErrorQuery bool) ([]StmtInfo, error) { ctx := NewContext() - stmts := []string{} + stmts := []StmtInfo{} for _, m := range s.mappers { for id, node := range m.SqlNodes { ctx.Sqls[fmt.Sprintf("%v.%v", m.NameSpace, id)] = node @@ -38,7 +43,12 @@ func (s *Mappers) GetStmts(skipErrorQuery bool) ([]string, error) { if err != nil { return nil, fmt.Errorf("get sqls from mapper failed, namespace: %v, err: %v", m.NameSpace, err) } - stmts = append(stmts, stmt...) + for _, sql := range stmt { + stmts = append(stmts, StmtInfo{ + FilePath: m.FilePath, + SQL: sql, + }) + } } return stmts, nil } diff --git a/parser.go b/parser.go index 1604a1b..cc73398 100644 --- a/parser.go +++ b/parser.go @@ -28,12 +28,17 @@ func ParseXML(data string) (string, error) { return stmt, nil } -// ParseXMLs is a parser for parse all query in several XML files to []string one by one; +type XmlFile struct { + FilePath string + Content string +} + +// ParseXMLs is a parser for parse all query in several XML files to []ast.StmtInfo one by one; // you can set `skipErrorQuery` true to ignore invalid query. -func ParseXMLs(data []string, skipErrorQuery bool) ([]string, error) { +func ParseXMLs(data []XmlFile, skipErrorQuery bool) ([]ast.StmtInfo, error) { ms := ast.NewMappers() - for i := range data { - r := strings.NewReader(data[i]) + for _, data := range data { + r := strings.NewReader(data.Content) d := xml.NewDecoder(r) n, err := parse(d) if err != nil { @@ -56,6 +61,7 @@ func ParseXMLs(data []string, skipErrorQuery bool) ([]string, error) { return nil, errors.New("the mapper is not found") } } + m.FilePath = data.FilePath err = ms.AddMapper(m) if err != nil && !skipErrorQuery { return nil, fmt.Errorf("add mapper failed: %v", err)