From 42482fc0e56b2d4998312135ae9d5b00541976c4 Mon Sep 17 00:00:00 2001 From: Yaz Saito Date: Sat, 4 Dec 2021 09:18:17 -0800 Subject: [PATCH] Rewrite sqlvet using the analyzer framework. Use golang.org/x/tools/go/analysis framework. It's simpler, faster. It also allows analyzing individual packages, not the whole program. As a simple benchmark, running the sqlvet (old or new) on a 100kloc source takes about 10s. Analyzing a single file in that source takes <1s. This PR introduces a few incompatible changes: - The commandline format changes. It's now sqlvet [-f sqlvet.toml] packages... The sqlvet.toml file must be in ".", or its location must be explicitly specified by the new "-f" flag. This commandline format is compatible with by most other analyzers, including govet and staticcheck. The old `sqlvet ` command becomes: cd && sqlvet ./... - It removes the support for string concatenation. db.Query("SELECT" + " 1") won't work any more. I personally think this isn't that big deal, we can easily rewrite it using a `raw string`. --- main.go | 65 ++-- pkg/config/config.go | 5 +- pkg/config/config_test.go | 4 +- pkg/vet/gosource.go | 523 +++++++++--------------------- pkg/vet/gosource_internal_test.go | 113 ------- pkg/vet/gosource_test.go | 125 ++----- 6 files changed, 215 insertions(+), 620 deletions(-) delete mode 100644 pkg/vet/gosource_internal_test.go diff --git a/main.go b/main.go index b2da51c..77aec26 100644 --- a/main.go +++ b/main.go @@ -27,9 +27,9 @@ type SQLVet struct { QueryCnt int32 ErrCnt int32 - Cfg config.Config - ProjectRoot string - Schema *schema.Db + Cfg config.Config + Paths []string + Schema *schema.Db } func (s *SQLVet) reportError(format string, a ...interface{}) { @@ -39,33 +39,19 @@ func (s *SQLVet) reportError(format string, a ...interface{}) { // Vet performs static analysis func (s *SQLVet) Vet() { - queries, err := vet.CheckDir( - vet.VetContext{ - Schema: s.Schema, - }, - s.ProjectRoot, - s.Cfg.SqlFuncMatchers, - ) - if err != nil { - cli.Exit(err) - } - - for _, q := range queries { + handleResult := func(q *vet.QuerySite) { atomic.AddInt32(&s.QueryCnt, 1) if q.Err == nil { if cli.Verbose { cli.Show("query detected at %s", q.Position) } - continue + return } // an error in the query is detected if flagErrFormat { - relFilePath, err := filepath.Rel(s.ProjectRoot, q.Position.Filename) - if err != nil { - relFilePath = s.ProjectRoot - } + relFilePath := q.Position.Filename // format ref: https://github.com/reviewdog/reviewdog#errorformat cli.Show( "%s:%d:%d: %v", @@ -84,6 +70,18 @@ func (s *SQLVet) Vet() { cli.Show("") } } + + err := vet.CheckPackages( + vet.VetContext{ + Schema: s.Schema, + }, + s.Paths, + s.Cfg.SqlFuncMatchers, + handleResult, + ) + if err != nil { + cli.Exit(err) + } } // PrintSummary dumps analysis stats into stdout @@ -97,15 +95,19 @@ func (s *SQLVet) PrintSummary() { } // NewSQLVet creates SQLVet for a given project dir -func NewSQLVet(projectRoot string) (*SQLVet, error) { - cfg, err := config.Load(projectRoot) +func NewSQLVet(configPath string, paths []string) (*SQLVet, error) { + cfg, err := config.Load(configPath) if err != nil { return nil, err } var dbSchema *schema.Db if cfg.SchemaPath != "" { - dbSchema, err = schema.NewDbSchema(filepath.Join(projectRoot, cfg.SchemaPath)) + schemaPath := cfg.SchemaPath + if !filepath.IsAbs(cfg.SchemaPath) { + schemaPath = filepath.Join(filepath.Dir(configPath), schemaPath) + } + dbSchema, err = schema.NewDbSchema(schemaPath) if err != nil { return nil, err } @@ -122,17 +124,18 @@ func NewSQLVet(projectRoot string) (*SQLVet, error) { } return &SQLVet{ - Cfg: cfg, - ProjectRoot: projectRoot, - Schema: dbSchema, + Cfg: cfg, + Paths: paths, + Schema: dbSchema, }, nil } func main() { + var configPath string var rootCmd = &cobra.Command{ Use: "sqlvet PATH", Short: "Go fearless SQL", - Args: cobra.ExactArgs(1), + Args: cobra.MinimumNArgs(1), Version: fmt.Sprintf("%s (%s)", version, gitCommit), PreRun: func(cmd *cobra.Command, args []string) { if cli.Verbose { @@ -140,8 +143,7 @@ func main() { } }, Run: func(cmd *cobra.Command, args []string) { - projectRoot := args[0] - s, err := NewSQLVet(projectRoot) + s, err := NewSQLVet(configPath, args) if err != nil { cli.Exit(err) } @@ -157,9 +159,8 @@ func main() { }, } - - rootCmd.PersistentFlags().BoolVarP( - &cli.Verbose, "verbose", "v", false, "verbose output") + rootCmd.PersistentFlags().StringVarP( + &configPath, "config", "f", "./sqlvet.toml", "Path of the config file.") rootCmd.PersistentFlags().BoolVarP( &flagErrFormat, "errorformat", "e", false, "output error in errorformat fromat for easier integration") diff --git a/pkg/config/config.go b/pkg/config/config.go index b5605a1..515ba99 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,7 +3,6 @@ package config import ( "io/ioutil" "os" - "path/filepath" "github.com/pelletier/go-toml" @@ -18,9 +17,7 @@ type Config struct { } // Load sqlvet config from project root -func Load(searchPath string) (conf Config, err error) { - configPath := filepath.Join(searchPath, "sqlvet.toml") - +func Load(configPath string) (conf Config, err error) { if _, e := os.Stat(configPath); os.IsNotExist(e) { conf.DbEngine = "postgres" // return default config if not found diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 2eaa65a..b0222f5 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -53,7 +53,7 @@ func (s *ConfigTests) SubTestMultipleMatchers(t *testing.T, fixtures struct { `), 0644) assert.NoError(t, err) - cfg, err := config.Load(fixtures.TmpDir) + cfg, err := config.Load(configPath) assert.NoError(t, err) assert.Equal(t, 2, len(cfg.SqlFuncMatchers)) @@ -73,7 +73,7 @@ func (s *ConfigTests) SubTestNoConfigFile(t *testing.T, fixtures struct { _, e := os.Stat(configPath) assert.True(t, os.IsNotExist(e)) - cfg, err := config.Load(fixtures.TmpDir) + cfg, err := config.Load(configPath) assert.NoError(t, err) assert.Equal(t, config.Config{DbEngine: "postgres"}, cfg) } diff --git a/pkg/vet/gosource.go b/pkg/vet/gosource.go index 6f16116..58791f0 100644 --- a/pkg/vet/gosource.go +++ b/pkg/vet/gosource.go @@ -2,22 +2,18 @@ package vet import ( "errors" - "fmt" "go/ast" "go/constant" "go/token" "go/types" "os" - "path/filepath" - "reflect" "sort" "strings" - "golang.org/x/tools/go/callgraph" - "golang.org/x/tools/go/packages" - "golang.org/x/tools/go/pointer" - "golang.org/x/tools/go/ssa" - "golang.org/x/tools/go/ssa/ssautil" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/analysis/singlechecker" + "golang.org/x/tools/go/ast/inspector" log "github.com/sirupsen/logrus" @@ -38,11 +34,6 @@ type QuerySite struct { Err error } -type MatchedSqlFunc struct { - SSA *ssa.Function - QueryArgPos int -} - type SqlFuncMatchRule struct { FuncName string `toml:"func_name"` // zero indexed @@ -53,83 +44,6 @@ type SqlFuncMatchRule struct { type SqlFuncMatcher struct { PkgPath string `toml:"pkg_path"` Rules []SqlFuncMatchRule `toml:"rules"` - - pkg *packages.Package -} - -func (s *SqlFuncMatcher) SetGoPackage(p *packages.Package) { - s.pkg = p -} - -func (s *SqlFuncMatcher) PackageImported() bool { - return s.pkg != nil -} - -func (s *SqlFuncMatcher) IterPackageExportedFuncs(cb func(*types.Func)) { - scope := s.pkg.Types.Scope() - for _, scopeName := range scope.Names() { - obj := scope.Lookup(scopeName) - if !obj.Exported() { - continue - } - - fobj, ok := obj.(*types.Func) - if ok { - cb(fobj) - } else { - // check for exported struct methods - switch otype := obj.Type().(type) { - case *types.Signature: - case *types.Named: - for i := 0; i < otype.NumMethods(); i++ { - m := otype.Method(i) - if !m.Exported() { - continue - } - cb(m) - } - case *types.Basic: - default: - log.Debugf("Skipped pkg scope: %s (%s)", otype, reflect.TypeOf(otype)) - } - } - } -} - -func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc { - sqlfuncs := []MatchedSqlFunc{} - - s.IterPackageExportedFuncs(func(fobj *types.Func) { - for _, rule := range s.Rules { - if rule.FuncName != "" && fobj.Name() == rule.FuncName { - sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ - SSA: prog.FuncValue(fobj), - QueryArgPos: rule.QueryArgPos, - }) - // callable matched one rule, no need to go through the rest - break - } - - if rule.QueryArgName != "" { - sigParams := fobj.Type().(*types.Signature).Params() - if sigParams.Len()-1 < rule.QueryArgPos { - continue - } - param := sigParams.At(rule.QueryArgPos) - if param.Name() != rule.QueryArgName { - continue - } - sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ - SSA: prog.FuncValue(fobj), - QueryArgPos: rule.QueryArgPos, - }) - // callable matched one rule, no need to go through the rest - break - } - } - }) - - return sqlfuncs } func handleQuery(ctx VetContext, qs *QuerySite) { @@ -137,13 +51,15 @@ func handleQuery(ctx VetContext, qs *QuerySite) { // e.g. for sqlx, only apply to NamedExec and NamedQuery qs.Query, _, qs.Err = parseutil.CompileNamedQuery( []byte(qs.Query), parseutil.BindType("postgres")) + if qs.Err != nil { return } - var queryParams []QueryParam queryParams, qs.Err = ValidateSqlQuery(ctx, qs.Query) + // log.Printf("QQQQ: %v err=%v", qs.Query, qs.Err) + if qs.Err != nil { return } @@ -213,114 +129,15 @@ func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher { }, }, } - if extraMatchers != nil { - for _, m := range extraMatchers { - tmpm := m - matchers = append(matchers, &tmpm) - } + for _, m := range extraMatchers { + tmpm := m + matchers = append(matchers, &tmpm) } return matchers } -func loadGoPackages(dir string) ([]*packages.Package, error) { - cfg := &packages.Config{ - Mode: packages.NeedName | - packages.NeedFiles | - packages.NeedImports | - packages.NeedDeps | - packages.NeedTypes | - packages.NeedSyntax | - packages.NeedTypesInfo, - Dir: dir, - Env: append(os.Environ(), "GO111MODULE=auto"), - } - dirAbs, err := filepath.Abs(dir) - if err != nil { - return nil, fmt.Errorf("Invalid path: %w", err) - } - pkgPath := dirAbs + "/..." - pkgs, err := packages.Load(cfg, pkgPath) - if err != nil { - return nil, err - } - // return early if any syntax error - for _, pkg := range pkgs { - if len(pkg.Errors) > 0 { - return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0]) - } - } - return pkgs, nil -} - -func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) { - queryStr := "" - - switch queryArg := argVal.(type) { - case *ssa.Const: - queryStr = constant.StringVal(queryArg.Value) - case *ssa.Phi: - // TODO: resolve all phi options - // for _, edge := range queryArg.Edges { - // } - log.Debug("TODO(callgraph) support ssa.Phi") - return "", ErrQueryArgTODO - case *ssa.BinOp: - // only support string concat - switch queryArg.Op { - case token.ADD: - lstr, err := extractQueryStrFromSsaValue(queryArg.X) - if err != nil { - return "", err - } - rstr, err := extractQueryStrFromSsaValue(queryArg.Y) - if err != nil { - return "", err - } - queryStr = lstr + rstr - default: - return "", ErrQueryArgUnsupportedType - } - case *ssa.Parameter: - // query call is wrapped in a helper function, query string is passed - // in as function parameter - // TODO: need to trace the caller or add wrapper function to - // matcher config - return "", ErrQueryArgTODO - case *ssa.Extract: - // query string is from one of the multi return values - // need to figure out how to trace string from function returns - return "", ErrQueryArgTODO - case *ssa.Call: - // return value from a function call - // TODO: trace caller function - return "", ErrQueryArgUnsafe - case *ssa.MakeInterface: - // query function takes interface as input - // check to see if interface is converted from a string - switch interfaceFrom := queryArg.X.(type) { - case *ssa.Const: - queryStr = constant.StringVal(interfaceFrom.Value) - default: - return "", ErrQueryArgUnsupportedType - } - case *ssa.Slice: - // function takes var arg as input - - // Type() returns string if the type of X was string, otherwise a - // *types.Slice with the same element type as X. - if _, ok := queryArg.Type().(*types.Slice); ok { - log.Debug("TODO(callgraph) support slice for vararg") - } - return "", ErrQueryArgTODO - default: - return "", ErrQueryArgUnsupportedType - } - - return queryStr, nil -} - -func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool { +func shouldIgnoreNode(pass *analysis.Pass, ignoreNodes []ast.Node, callSitePos token.Pos) bool { if len(ignoreNodes) == 0 { return false } @@ -334,7 +151,7 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool { } for _, n := range ignoreNodes { - if callSitePos < n.End() && callSitePos > n.Pos() { + if callSitePos < n.End() && callSitePos >= n.Pos() { return true } } @@ -342,136 +159,29 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool { return false } -func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite { - queries := []*QuerySite{} - - for _, inEdge := range cgNode.In { - callerFunc := inEdge.Caller.Func - if callerFunc.Pkg == nil { - // skip calls from dependencies - continue - } - - callSite := inEdge.Site - callSitePos := callSite.Pos() - if shouldIgnoreNode(ignoreNodes, callSitePos) { - continue - } - - callSitePosition := prog.Fset.Position(callSitePos) - log.Debugf("Validating %s @ %s", sqlfunc.SSA, callSitePosition) - - callArgs := callSite.Common().Args - - absArgPos := sqlfunc.QueryArgPos - if callSite.Common().IsInvoke() { - // interface method invocation. - // In this mode, Value is the interface value and Method is the - // interface's abstract method. Note: an abstract method may be - // shared by multiple interfaces due to embedding; Value.Type() - // provides the specific interface used for this call. - } else { - // "call" mode: when Method is nil (!IsInvoke), a CallCommon - // represents an ordinary function call of the value in Value, - // which may be a *Builtin, a *Function or any other value of - // kind 'func'. - if sqlfunc.SSA.Signature.Recv() != nil { - // it's a struct method call, plus 1 to take receiver into - // account - absArgPos += 1 - } - } - queryArg := callArgs[absArgPos] - - qs := &QuerySite{ - Called: inEdge.Callee.Func.Name(), - Position: callSitePosition, - Err: nil, - } - - if len(callArgs) > absArgPos+1 { - // query function accepts query parameters - paramArg := callArgs[absArgPos+1] - // only support query param as variadic argument for now - switch params := paramArg.(type) { - case *ssa.Const: - // likely nil - case *ssa.Slice: - sliceType := params.X.Type() - switch t := sliceType.(type) { - case *types.Pointer: - elem := t.Elem() - switch e := elem.(type) { - case *types.Array: - // query parameters are passed in as vararg: an array - // of interface - qs.ParameterArgCount = int(e.Len()) - } - } - } - } - - qs.Query, qs.Err = extractQueryStrFromSsaValue(queryArg) - if qs.Err != nil { - switch qs.Err { - case ErrQueryArgUnsupportedType: - log.WithFields(log.Fields{ - "type": reflect.TypeOf(queryArg), - "pos": prog.Fset.Position(callSite.Pos()), - "caller": callerFunc, - "callerPkg": callerFunc.Pkg, - }).Debug(fmt.Errorf("unsupported type in callgraph: %w", qs.Err)) - case ErrQueryArgTODO: - log.WithFields(log.Fields{ - "type": reflect.TypeOf(queryArg), - "pos": prog.Fset.Position(callSite.Pos()), - "caller": callerFunc, - "callerPkg": callerFunc.Pkg, - }).Debug(fmt.Errorf("TODO(callgraph) %w", qs.Err)) - // skip to be supported query type - continue - default: - queries = append(queries, qs) - continue - } - } - - if qs.Query == "" { - continue - } - handleQuery(ctx, qs) - queries = append(queries, qs) - } - - return queries -} - -func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node { +func getSortedIgnoreNodes(pass *analysis.Pass) []ast.Node { ignoreNodes := []ast.Node{} + for _, file := range pass.Files { + cmap := ast.NewCommentMap(pass.Fset, file, file.Comments) + for node, cglist := range cmap { + for _, cg := range cglist { + // Remove `//` and spaces from comment line to get the + // actual comment text. We can't use cg.Text() directly + // here due to change introduced in + // https://github.com/golang/go/issues/37974 + ctext := cg.List[0].Text + if !strings.HasPrefix(ctext, "//") { + continue + } + ctext = strings.TrimSpace(ctext[2:]) - for _, p := range pkgs { - for _, s := range p.Syntax { - cmap := ast.NewCommentMap(p.Fset, s, s.Comments) - for node, cglist := range cmap { - for _, cg := range cglist { - // Remove `//` and spaces from comment line to get the - // actual comment text. We can't use cg.Text() directly - // here due to change introduced in - // https://github.com/golang/go/issues/37974 - ctext := cg.List[0].Text - if !strings.HasPrefix(ctext, "//") { - continue - } - ctext = strings.TrimSpace(ctext[2:]) - - anno, err := ParseComment(ctext) - if err != nil { - continue - } - if anno.Ignore { - ignoreNodes = append(ignoreNodes, node) - log.Tracef("Ignore ast node from %d to %d", node.Pos(), node.End()) - } + anno, err := ParseComment(ctext) + if err != nil { + continue + } + if anno.Ignore { + ignoreNodes = append(ignoreNodes, node) + log.Tracef("Ignore ast node from %d to %d", node.Pos(), node.End()) } } } @@ -484,70 +194,129 @@ func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node { return ignoreNodes } -func CheckDir(ctx VetContext, dir string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) { - _, err := os.Stat(filepath.Join(dir, "go.mod")) - if os.IsNotExist(err) { - return nil, errors.New("sqlvet only supports projects using go modules for now.") - } +type sqlFunc struct { + QueryArgPos int +} - pkgs, err := loadGoPackages(dir) - if err != nil { - return nil, err +func isSQLFunc(fobj *types.Func, matchers []*SqlFuncMatcher) *sqlFunc { + if fobj.Pkg() == nil { + // Parse error? + return nil } - log.Debugf("Loaded %d packages: %s", len(pkgs), pkgs) - - ignoreNodes := getSortedIgnoreNodes(pkgs) - log.Debugf("Identified %d queries to ignore", len(ignoreNodes)) - - // check to see if loaded packages imported any package that matches our rules - matchers := getMatchers(extraMatchers) - log.Debugf("Loaded %d matchers, checking imported SQL packages...", len(matchers)) - for _, matcher := range matchers { - for _, p := range pkgs { - v, ok := p.Imports[matcher.PkgPath] - if !ok { - continue - } - // package is imported by at least of the loaded packages - matcher.SetGoPackage(v) - log.Debugf("\t%s imported", matcher.PkgPath) - break + fpkgPath := fobj.Pkg().Path() + for _, m := range matchers { + if m.PkgPath != fpkgPath { + continue } - } - - prog, ssaPkgs := ssautil.Packages(pkgs, 0) - log.Debug("Performaing whole-program analysis...") - prog.Build() + for _, rule := range m.Rules { + if rule.FuncName != "" && fobj.Name() == rule.FuncName { + return &sqlFunc{QueryArgPos: rule.QueryArgPos} + } - // find ssa.Function for matched sqlfuncs from program - sqlfuncs := []MatchedSqlFunc{} - for _, matcher := range matchers { - if !matcher.PackageImported() { - // if package is not imported, then no sqlfunc should be matched - continue + if rule.QueryArgName != "" { + sigParams := fobj.Type().(*types.Signature).Params() + if sigParams.Len()-1 < rule.QueryArgPos { + continue + } + param := sigParams.At(rule.QueryArgPos) + if param.Name() != rule.QueryArgName { + continue + } + return &sqlFunc{QueryArgPos: rule.QueryArgPos} + } } - sqlfuncs = append(sqlfuncs, matcher.MatchSqlFuncs(prog)...) } - log.Debugf("Matched %d sqlfuncs", len(sqlfuncs)) + return nil +} - log.Debugf("Locating main packages from %d packages.", len(ssaPkgs)) - mains := ssautil.MainPackages(ssaPkgs) +// NewAnalyzer creates an analysis.Analyzer for sqlvet. Unlike typical +// Analyzers, it does not report errors directly. Instead it invokes "result" +// for every detected sql query/exec site. +func NewAnalyzer(ctx VetContext, extraMatchers []SqlFuncMatcher, result func(qs *QuerySite)) *analysis.Analyzer { + matchers := getMatchers(extraMatchers) - log.Debug("Building call graph...") - anaRes, err := pointer.Analyze(&pointer.Config{ - Mains: mains, - BuildCallGraph: true, - }) + pos := func(pass *analysis.Pass, pos token.Pos) token.Position { return pass.Fset.Position(pos) } + + run := func(pass *analysis.Pass) (interface{}, error) { + funcCache := map[*types.Func]*sqlFunc{} + ignoredNodes := getSortedIgnoreNodes(pass) + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + inspect.Preorder( + []ast.Node{(*ast.CallExpr)(nil)}, + func(n ast.Node) { + call := n.(*ast.CallExpr) + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + var fobj *types.Func + if s := pass.TypesInfo.Selections[sel]; s != nil { + if s.Kind() == types.MethodVal { + // method (e.g. foo.String()) + fobj = s.Obj().(*types.Func) + } + } else { + // package-qualified function (e.g. fmt.Errorf) + obj := pass.TypesInfo.Uses[sel.Sel] + if obj, ok := obj.(*types.Func); ok { + fobj = obj + } + } + if fobj == nil { + // A cast operator. + return + } - queries := []*QuerySite{} + match, ok := funcCache[fobj] + if !ok { + match = isSQLFunc(fobj, matchers) + funcCache[fobj] = match + } + if match == nil { + return + } - cg := anaRes.CallGraph - for _, sqlfunc := range sqlfuncs { - cgNode := cg.CreateNode(sqlfunc.SSA) - queries = append( - queries, - iterCallGraphNodeCallees(ctx, cgNode, prog, sqlfunc, ignoreNodes)...) + args := call.Args + if len(args) < match.QueryArgPos { + log.Printf("%v: cannot extract extract the query arg #%d", pos(pass, n.Pos()), match.QueryArgPos) + return + } + queryTypeVal, ok := pass.TypesInfo.Types[call.Args[match.QueryArgPos]] + if !ok { + log.Printf("%v: cannot get query type info", pos(pass, n.Pos())) + return + } + if shouldIgnoreNode(pass, ignoredNodes, n.Pos()) { + return + } + qs := &QuerySite{ + Called: fobj.Name(), + Position: pos(pass, n.Pos()), + ParameterArgCount: len(args) - match.QueryArgPos - 1, + } + if queryTypeVal.Value == nil || queryTypeVal.Value.Kind() != constant.String { + qs.Err = ErrQueryArgUnsafe + } else { + qs.Query = constant.Val(queryTypeVal.Value).(string) + } + handleQuery(ctx, qs) + result(qs) + }) + return nil, nil } + return &analysis.Analyzer{ + Name: "sqlvet", // name of the analyzer + Doc: "todo", // documentation + Run: run, // perform your analysis here + Requires: []*analysis.Analyzer{inspect.Analyzer}, // a set of analyzers which must run before the current one. + } +} - return queries, nil +// CheckPackages runs the sqlvet analyzer on the given set of packages. Function +// "result" is invoked for each SQL query or exec site detected by the analyzer. +func CheckPackages(ctx VetContext, paths []string, extraMatchers []SqlFuncMatcher, result func(qs *QuerySite)) error { + analyzer := NewAnalyzer(ctx, extraMatchers, result) + os.Args = append([]string{"unused"}, paths...) + singlechecker.Main(analyzer) + return nil } diff --git a/pkg/vet/gosource_internal_test.go b/pkg/vet/gosource_internal_test.go deleted file mode 100644 index 318cb67..0000000 --- a/pkg/vet/gosource_internal_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package vet - -import ( - "testing" - - "go/constant" - "go/token" - - "github.com/houqp/gtest" - "github.com/stretchr/testify/assert" - "golang.org/x/tools/go/ssa" -) - -type ExtractQueryStrTests struct{} - -func (s *ExtractQueryStrTests) Setup(t *testing.T) {} -func (s *ExtractQueryStrTests) Teardown(t *testing.T) {} -func (s *ExtractQueryStrTests) BeforeEach(t *testing.T) {} -func (s *ExtractQueryStrTests) AfterEach(t *testing.T) {} - -// func (s *ExtractQueryStrTests) SubTestVarArg(t *testing.T) { -// // vararg is parsed as *ssa.Slice -// // eg: (*xorm.io/xorm.Session).Exec -// argVal := ssa.Slice{} -// s, err := extractQueryStrFromArg(ssa) -// assert.NoError(t, err) -// assert.Equal(t, "SELECT name FROM foo WHERE id=1", s) -// } - -func (s *ExtractQueryStrTests) SubTestQueryStringAsInterface(t *testing.T) { - // query string constant is passed in as interface to match query function - // signature - expectedQs := "SELECT name FROM foo WHERE id=2" - argVal := &ssa.MakeInterface{ - X: &ssa.Const{ - Value: constant.MakeString(expectedQs), - }, - } - - qs, err := extractQueryStrFromSsaValue(argVal) - assert.NoError(t, err) - assert.Equal(t, expectedQs, qs) -} - -func (s *ExtractQueryStrTests) SubTestQueryStringAsConstant(t *testing.T) { - expectedQs := "SELECT name FROM foo WHERE id=1" - argVal := &ssa.Const{ - Value: constant.MakeString(expectedQs), - } - - qs, err := extractQueryStrFromSsaValue(argVal) - assert.NoError(t, err) - assert.Equal(t, expectedQs, qs) -} - -func (s *ExtractQueryStrTests) SubTestQueryStringThroughAddBinOp(t *testing.T) { - expectedQs := "SELECT id FROM table" - argVal := &ssa.BinOp{ - Op: token.ADD, - X: &ssa.Const{ - Value: constant.MakeString("SELECT "), - }, - Y: &ssa.Const{ - Value: constant.MakeString("id FROM table"), - }, - } - - qs, err := extractQueryStrFromSsaValue(argVal) - assert.NoError(t, err) - assert.Equal(t, expectedQs, qs) -} - -func (s *ExtractQueryStrTests) SubTestQueryStringThroughNestedAddBinOp(t *testing.T) { - expectedQs := "SELECT id FROM table WHERE id = 1" - argVal := &ssa.BinOp{ - Op: token.ADD, - X: &ssa.BinOp{ - Op: token.ADD, - X: &ssa.Const{ - Value: constant.MakeString("SELECT "), - }, - Y: &ssa.Const{ - Value: constant.MakeString("id FROM table"), - }, - }, - Y: &ssa.Const{ - Value: constant.MakeString(" WHERE id = 1"), - }, - } - - qs, err := extractQueryStrFromSsaValue(argVal) - assert.NoError(t, err) - assert.Equal(t, expectedQs, qs) -} - -func (s *ExtractQueryStrTests) SubTestQueryStringThroughUnsupportedBinOp(t *testing.T) { - argVal := &ssa.BinOp{ - Op: token.AND, - X: &ssa.Const{ - Value: constant.MakeString("SELECT "), - }, - Y: &ssa.Const{ - Value: constant.MakeString("id FROM table"), - }, - } - qs, err := extractQueryStrFromSsaValue(argVal) - assert.Error(t, err) - assert.Equal(t, "", qs) -} - -func TestGoSource(t *testing.T) { - gtest.RunSubTests(t, &ExtractQueryStrTests{}) -} diff --git a/pkg/vet/gosource_test.go b/pkg/vet/gosource_test.go index 1fde3f0..5bfc5fa 100644 --- a/pkg/vet/gosource_test.go +++ b/pkg/vet/gosource_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/houqp/sqlvet/pkg/vet" + "golang.org/x/tools/go/analysis/analysistest" ) type GoSourceTmpDir struct{} @@ -46,32 +47,21 @@ func (s *GoSourceTests) Teardown(t *testing.T) {} func (s *GoSourceTests) BeforeEach(t *testing.T) {} func (s *GoSourceTests) AfterEach(t *testing.T) {} -func (s *GoSourceTests) SubTestInvalidSyntax(t *testing.T, fixtures struct { - TmpDir string `fixture:"GoSourceTmpDir"` -}) { - dir := fixtures.TmpDir - - fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, []byte(` -package main - -func main() { - return 1 -} -`), 0644) - assert.NoError(t, err) - - _, err = vet.CheckDir(vet.VetContext{}, dir, nil) - assert.Error(t, err) +func (s *GoSourceTests) runAnalyzer(t *testing.T, dir, packageName string) []*vet.QuerySite { + var result []*vet.QuerySite + analyzer := vet.NewAnalyzer(vet.VetContext{}, nil, func(r *vet.QuerySite) { + result = append(result, r) + }) + analysistest.Run(t, dir, analyzer, packageName) + return result } func (s *GoSourceTests) SubTestSkipNoneDbQueryCall(t *testing.T, fixtures struct { TmpDir string `fixture:"GoSourceTmpDir"` }) { - dir := fixtures.TmpDir - source := []byte(` -package main + dir, cleanup, err := analysistest.WriteFiles(map[string]string{ + "main/main.go": `package main type Parameter struct {} @@ -99,15 +89,11 @@ func main() { }() }() } - `) - - fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + `}) assert.NoError(t, err) + defer cleanup() - queries, err := vet.CheckDir(vet.VetContext{}, dir, nil) - assert.NoError(t, err) - assert.Equal(t, 0, len(queries)) + _ = s.runAnalyzer(t, dir, "main") } func (s *GoSourceTests) SubTestPkgDatabaseSql(t *testing.T, fixtures struct { @@ -115,8 +101,8 @@ func (s *GoSourceTests) SubTestPkgDatabaseSql(t *testing.T, fixtures struct { }) { dir := fixtures.TmpDir - source := []byte(` -package main + dir, cleanup, err := analysistest.WriteFiles(map[string]string{ + "main/main.go": `package main import ( "context" @@ -151,23 +137,18 @@ func main() { var userInput string tx.Query(fmt.Sprintf("SELECT %s", userInput)) - // string concat - tx.Exec("SELECT " + "7") - staticUserId := "id" - tx.Exec("SELECT " + staticUserId + " FROM foo") + // const string + tx, _ = db.Begin() + const query = "SELECT 7" + tx.ExecContext(ctx, query) } - `) - - fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + `}) assert.NoError(t, err) + defer cleanup() + + queries := s.runAnalyzer(t, dir, "main") - queries, err := vet.CheckDir(vet.VetContext{}, dir, nil) - if err != nil { - t.Fatalf("Failed to load package: %s", err.Error()) - return - } - assert.Equal(t, 6, len(queries)) + assert.Equal(t, 5, len(queries)) sort.Slice(queries, func(i, j int) bool { return queries[i].Position.Offset < queries[j].Position.Offset }) @@ -184,54 +165,14 @@ func main() { // unsafe string assert.Error(t, queries[3].Err) - // string concat assert.NoError(t, queries[4].Err) assert.Equal(t, "SELECT 7", queries[4].Query) - assert.NoError(t, queries[5].Err) - assert.Equal(t, "SELECT id FROM foo", queries[5].Query) -} - -// run sqlvet from parent dir -func (s *GoSourceTests) SubTestCheckRelativeDir(t *testing.T, fixtures struct { - TmpDir string `fixture:"GoSourceTmpDir"` -}) { - dir := fixtures.TmpDir - - source := []byte(` -package main - -func main() { -} - `) - - fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) - assert.NoError(t, err) - - cwd, err := os.Getwd() - assert.NoError(t, err) - parentDir := filepath.Dir(dir) - os.Chdir(parentDir) - defer os.Chdir(cwd) - - queries, err := vet.CheckDir(vet.VetContext{}, filepath.Base(dir), nil) - if err != nil { - t.Fatalf("Failed to load package: %s", err.Error()) - return - } - assert.Equal(t, 0, len(queries)) -} - -func TestGoSource(t *testing.T) { - gtest.RunSubTests(t, &GoSourceTests{}) } func (s *GoSourceTests) SubTestQueryParam(t *testing.T, fixtures struct { TmpDir string `fixture:"GoSourceTmpDir"` }) { - dir := fixtures.TmpDir - - source := []byte(` + const source = ` package main import ( @@ -252,17 +193,13 @@ func main() { db.Query("SELECT 2 FROM foo WHERE id=$1 OR value=$1", 1) } - `) + ` - fpath := filepath.Join(dir, "main.go") - err := ioutil.WriteFile(fpath, source, 0644) + dir, cleanup, err := analysistest.WriteFiles(map[string]string{"main/main.go": source}) assert.NoError(t, err) + defer cleanup() + queries := s.runAnalyzer(t, dir, "main") - queries, err := vet.CheckDir(vet.VetContext{}, dir, nil) - if err != nil { - t.Fatalf("Failed to load package: %s", err.Error()) - return - } assert.Equal(t, 4, len(queries)) sort.Slice(queries, func(i, j int) bool { return queries[i].Position.Offset < queries[j].Position.Offset @@ -284,3 +221,7 @@ func main() { assert.Equal(t, "SELECT 2 FROM foo WHERE id=$1 OR value=$1", queries[3].Query) assert.Equal(t, 1, queries[3].ParameterArgCount) } + +func TestGoSource(t *testing.T) { + gtest.RunSubTests(t, &GoSourceTests{}) +}