diff --git a/tools/override-generator/main.go b/tools/override-generator/main.go index b93fa7e44..976ec5950 100644 --- a/tools/override-generator/main.go +++ b/tools/override-generator/main.go @@ -31,18 +31,15 @@ const ( // attribute constants that are used multiple times. const ( - _buildFileGenerationAttr = "build_file_generation" - _buildFileProtoModeAttr = "build_file_proto_mode" - _patchArgsAttr = "patch_args" - _buildDirectivesAttr = "build_directives" - _directivesAttr = "directives" + _buildFileGenerationAttr = "build_file_generation" + _buildFileProtoModeAttr = "build_file_proto_mode" + _patchArgsAttr = "patch_args" + _buildDirectivesAttr = "build_directives" + _directivesAttr = "directives" + _buildFileProtoModeDefault = "default" + _buildFileGenerationModeDefault = "auto" ) -var _defaultValues = map[string]string{ - _buildFileGenerationAttr: "auto", - _buildFileProtoModeAttr: "default", -} - var _mapAttrToOverride = map[string]string{ _buildDirectivesAttr: _gazelleOverride, _buildFileGenerationAttr: _gazelleOverride, @@ -66,6 +63,9 @@ type mainArgs struct { defName string outputFile string gazelleRepoName string + + defaultBuildFileGeneration string + defaultBuildFileProtoMode string } func main() { @@ -91,6 +91,8 @@ func parseArgs(stderr io.Writer, osArgs []string) (*mainArgs, error) { flag.StringVar(&a.defName, "def_name", "", "name of the macro definition") flag.StringVar(&a.outputFile, "output", "", "path to the output file") flag.StringVar(&a.gazelleRepoName, "gazelle_repo_name", "@bazel_gazelle", "name of the gazelle repo to load go_deps, (default: @bazel_gazelle)") + flag.StringVar(&a.defaultBuildFileGeneration, "default_build_file_generation", "auto", "the default value for build_file_generation attribute") + flag.StringVar(&a.defaultBuildFileProtoMode, "default_build_file_proto_mode", "default", "the default value for build_file_proto_mode attribute") flag.Parse(osArgs) if a.macroPath != "" && a.workspace != "" { @@ -137,7 +139,7 @@ func run(a mainArgs, stderr io.Writer) error { // will be deterministic. for _, r := range repos { if r.Kind() == "go_repository" { - repoOverrides := goRepositoryToOverrideSet(r) + repoOverrides := goRepositoryToOverrideSet(r, a.defaultBuildFileGeneration, a.defaultBuildFileProtoMode) outputOverrides = append(outputOverrides, setToOverridesSlice(repoOverrides)...) } } @@ -162,7 +164,7 @@ func run(a mainArgs, stderr io.Writer) error { return nil } -func goRepositoryToOverrideSet(r *rule.Rule) overrideSet { +func goRepositoryToOverrideSet(r *rule.Rule, defaultBuildFileGeneration, defaultBuildFileProtoMode string) overrideSet { // each repo has its own override set, and can't have multiple // duplicate overrides. This set is created to be populated and read set := make(overrideSet) @@ -176,7 +178,9 @@ func goRepositoryToOverrideSet(r *rule.Rule) overrideSet { } attrValue := r.Attr(attr) - if attrValue == nil || attr == _buildFileProtoModeAttr { + + // proto mode and build file generation require special handling. + if attrValue == nil || attr == _buildFileProtoModeAttr || attr == _buildFileGenerationAttr { continue } @@ -194,10 +198,6 @@ func goRepositoryToOverrideSet(r *rule.Rule) overrideSet { attr = k } - if def, ok := _defaultValues[attr]; def == r.AttrString(attr) && ok { - continue - } - if val != nil { switch v := val.(type) { case *build.StringExpr: @@ -216,31 +216,91 @@ func goRepositoryToOverrideSet(r *rule.Rule) overrideSet { set[kind] = override } - // Since "build_file_proto_mode" is added to the "directives", we need - // to run it after the fact to make sure that "directives" is set. - applyBuildFileProtoMode(r, set) + // If the user default doesn't match the global default, but there's a gazelle override, we need to still apply + // it to the individual overrides. + // Also, since "build_file_proto_mode" is added to the "directives", we need + // to apply it last to make sure "directives" is set. + applyBuildFileGeneration(r, set, defaultBuildFileGeneration) + applyBuildFileProtoMode(r, set, defaultBuildFileProtoMode, defaultBuildFileGeneration) return set } -func applyBuildFileProtoMode(r *rule.Rule, set overrideSet) { - if r.Attr(_buildFileProtoModeAttr) == nil { +func applyBuildFileGeneration(r *rule.Rule, set overrideSet, userDefaultGeneration string) { + ruleGeneration := r.AttrString(_buildFileGenerationAttr) + o, ok := set[_gazelleOverride] + if !ok { + if ruleGeneration == "" || ruleGeneration == userDefaultGeneration { + return + } + set[_gazelleOverride] = newGenerationOverride(r.AttrString("importpath"), ruleGeneration) return } - if def, ok := _defaultValues[_buildFileProtoModeAttr]; def == r.AttrString(_buildFileProtoModeAttr) && ok { + if ruleGeneration == "" { + ruleGeneration = userDefaultGeneration + } + + o.SetAttr(_buildFileGenerationAttr, ruleGeneration) + set[_gazelleOverride] = o + return +} + + +func newGenerationOverride(path, ruleGeneration string) *rule.Rule { + override := rule.NewRule(_gazelleOverride, "") + override.SetAttr("path", path) + override.SetAttr(_buildFileGenerationAttr, ruleGeneration) + return override +} + +func applyBuildFileProtoMode(r *rule.Rule, set overrideSet, userDefaultProtoMode, userDefaultGeneration string) { + protoMode := r.AttrString(_buildFileProtoModeAttr) + + // If the gazelle_override doesn't exist. We only need to apply the proto mode + // if it does not match the user default proto mode. + gazelleOverride, ok := set[_gazelleOverride] + if !ok { + if protoMode == "" || protoMode == userDefaultProtoMode { + return + } + + set[_gazelleOverride] = newProtoOverride(r.AttrString("importpath"), protoMode) + + // Since it's a new override, we need to apply build_file_generation again. + applyBuildFileGeneration(r, set, userDefaultGeneration) return } - directive := "gazelle:proto " + r.AttrString(_buildFileProtoModeAttr) - kind := _gazelleOverride - override := rule.NewRule(kind, "") - if o, ok := set[kind]; ok { - override = o + // If the gazelle_override exists, we should apply the override anyway since + // the tag overwrites the defaults. + if protoMode == "" { + protoMode = userDefaultProtoMode } - directives := override.AttrStrings(_directivesAttr) - directives = append(directives, directive) + + safeAppendDirective(gazelleOverride, "gazelle:proto " + protoMode) + set[_gazelleOverride] = gazelleOverride + return +} + +func newProtoOverride(path, protoMode string) *rule.Rule { + override := rule.NewRule(_gazelleOverride, "") + override.SetAttr("path", path) + directives := []string{"gazelle:proto " + protoMode} override.SetAttr(_directivesAttr, directives) - set[kind] = override + return override +} + +func safeAppendDirective(gazelleOverride *rule.Rule, directive string) { + directives := gazelleOverride.AttrStrings(_directivesAttr) + directiveMap := make(map[string]struct{}) + for _, d := range directives{ + directiveMap[d] = struct{}{} + } + if _, ok := directiveMap[directive]; ok { + return + } + directives = append(directives, directive) + gazelleOverride.SetAttr(_directivesAttr, directives) } func setPatchArgs(patchArgs []string, override *rule.Rule) { diff --git a/tools/override-generator/main_test.go b/tools/override-generator/main_test.go index 6eed30513..b0f3fd19a 100644 --- a/tools/override-generator/main_test.go +++ b/tools/override-generator/main_test.go @@ -170,8 +170,236 @@ func TestBzlmodOverride(t *testing.T) { } args := &mainArgs{ - workspace: testWorkspace, - outputFile: filepath.Join(w, "output.bzl"), + workspace: testWorkspace, + outputFile: filepath.Join(w, "output.bzl"), + defaultBuildFileGeneration: "auto", + defaultBuildFileProtoMode: "default", + } + + if err := run(*args, io.Discard); err != nil { + t.Errorf("run() error = %v, want no error", err) + } + + if tt.want == "" { + return + } + + content, err := os.ReadFile(args.outputFile) + if err != nil { + t.Errorf("error reading output file: %v", err) + } + + if !isEqualContent(string(content), tt.want) { + fmt.Fprintf(os.Stderr, "output = %v, want %v", string(content), tt.want) + t.Errorf("output = %v, want %v", string(content), tt.want) + } + }) + } +} + +func TestBzlmodOverrideNewDefaults(t *testing.T) { + tests := []struct { + name string + give string + want string + }{ + { + name: "simple no override", + give: `load("@bazel_gazelle//:deps.bzl", "go_repository") + + go_repository( + name = "com_github_apache_thrift", + build_file_generation = "on", + build_file_proto_mode = "disable", + importpath = "github.com/apache/thrift", + sum = "h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo=", + version = "v0.17.0", + )`, + want: "", + }, + { + name: "simple override", + give: `load("@bazel_gazelle//:deps.bzl", "go_repository") + + go_repository( + name = "com_github_apache_thrift", + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "on", + build_file_proto_mode = "disable_global", + importpath = "github.com/apache/thrift", + sum = "h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo=", + version = "v0.17.0", + )`, + want: `go_deps = use_extension("//:extensions.bzl", "go_deps") + + go_deps.gazelle_override( + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "on", + directives = ["gazelle:proto disable_global"], + path = "github.com/apache/thrift", + )`, + }, + { + name: "module override and gazelle", + give: `load("@bazel_gazelle//:deps.bzl", "go_repository") + + go_repository( + name = "com_github_bazelbuild_bazel_watcher", + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "off", # keep + build_file_proto_mode = "disable", + importpath = "github.com/bazelbuild/bazel-watcher", + patch_args = ["-p1"], + patches = [ + # Remove it after they release this PR https://github.com/bazelbuild/bazel-watcher/pull/627 + "//patches:com_github_bazelbuild_bazel_watcher-go-embed.patch", + ], + sum = "h1:EfJzkMxJuNBGMVdEvkhiW7pAMwhaegbmAMaFCjLjyTw=", + version = "v0.23.7", + )`, + want: `go_deps = use_extension("//:extensions.bzl", "go_deps") + + go_deps.gazelle_override( + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "off", + directives = ["gazelle:proto disable"], + path = "github.com/bazelbuild/bazel-watcher", + ) + + go_deps.module_override( + patch_strip = 1, + patches = [ + # Remove it after they release this PR https://github.com/bazelbuild/bazel-watcher/pull/627 + "//patches:com_github_bazelbuild_bazel_watcher-go-embed.patch", + ], + path = "github.com/bazelbuild/bazel-watcher", + )`, + }, + { + name: "directives and proto args", + give: `go_repository( + name = "com_github_clickhouse_clickhouse_go_v2", + build_directives = [ + "gazelle:resolve go github.com/ClickHouse/clickhouse-go/v2/external @com_github_clickhouse_clickhouse_go_v2//external", + ], + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "auto", + build_file_proto_mode = "disable", + importpath = "github.com/ClickHouse/clickhouse-go/v2", + sum = "h1:Nbl/NZwoM6LGJm7smNBgvtdr/rxjlIssSW3eG/Nmb9E=", + version = "v2.0.12", + )`, + want: `go_deps = use_extension("//:extensions.bzl", "go_deps") + + go_deps.gazelle_override( + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "auto", + directives = [ + "gazelle:resolve go github.com/ClickHouse/clickhouse-go/v2/external @com_github_clickhouse_clickhouse_go_v2//external", + "gazelle:proto disable", + ], + path = "github.com/ClickHouse/clickhouse-go/v2", + )`, + }, + { + name: "archive overrides", + give: `go_repository( + name = "org_golang_x_tools_cmd_goimports", + build_extra_args = [ + "-go_prefix=golang.org/x/tools", + "-exclude=**/testdata", + ], + build_file_generation = "on", + build_file_proto_mode = "disable", + importpath = "golang.org/x/tools/cmd/goimports", + patch_args = ["-p1"], + strip_prefix = "golang.org/x/tools@v0.0.0-20200512131952-2bc93b1c0c88", + sha256 = "4a6497e0bf1f19c8089dd02e7ba1351ba787f434d62971ff14fb627e57914939", + patches = [ + "//patches:org_golang_x_tools_cmd_goimports.patch", + ], + urls = [ + "https://goproxy.uberinternal.com/golang.org/x/tools/@v/v0.0.0-20200512131952-2bc93b1c0c88.zip", + ], + )`, + want: `go_deps = use_extension("//:extensions.bzl", "go_deps") + + go_deps.archive_override( + patch_strip = 1, + patches = [ + "//patches:org_golang_x_tools_cmd_goimports.patch", + ], + path = "golang.org/x/tools/cmd/goimports", + sha256 = "4a6497e0bf1f19c8089dd02e7ba1351ba787f434d62971ff14fb627e57914939", + strip_prefix = "golang.org/x/tools@v0.0.0-20200512131952-2bc93b1c0c88", + urls = [ + "https://goproxy.uberinternal.com/golang.org/x/tools/@v/v0.0.0-20200512131952-2bc93b1c0c88.zip", + ], + ) + + go_deps.gazelle_override( + build_extra_args = [ + "-go_prefix=golang.org/x/tools", + "-exclude=**/testdata", + ], + build_file_generation = "on", + directives = ["gazelle:proto disable"], + path = "golang.org/x/tools/cmd/goimports", + )`, + }, + { + name: "removed duplicate proto mode", + give: `go_repository( + name = "org_golang_x_xerrors", + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "on", + build_directives = [ + "gazelle:proto disable", + ], + importpath = "golang.org/x/xerrors", + patch_args = ["-p1"], + patches = [ + # exposes go_tool_library + "//patches:org_golang_x_xerrors_tool.patch", + ], + sum = "h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=", + version = "v0.0.0-20220907171357-04be3eba64a2", + )`, + want: `go_deps = use_extension("//:extensions.bzl", "go_deps") + + go_deps.gazelle_override( + build_extra_args = ["-go_naming_convention_external=go_default_library"], + build_file_generation = "on", + directives = [ + "gazelle:proto disable", + ], + path = "golang.org/x/xerrors", + ) + + go_deps.module_override( + patch_strip = 1, + patches = [ + # exposes go_tool_library + "//patches:org_golang_x_xerrors_tool.patch", + ], + path = "golang.org/x/xerrors", + )`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := t.TempDir() + testWorkspace := filepath.Join(w, "WORKSPACE") + if err := os.WriteFile(testWorkspace, []byte(removeTabsAndTrimLines(tt.give)), 0644); err != nil { + t.Errorf("error writing test workspace file: %v", err) + } + + args := &mainArgs{ + workspace: testWorkspace, + outputFile: filepath.Join(w, "output.bzl"), + defaultBuildFileGeneration: "on", + defaultBuildFileProtoMode: "disable", } if err := run(*args, io.Discard); err != nil {