Skip to content

Commit

Permalink
Toolchainize //jmh/toolchain:jmh_toolchain_type
Browse files Browse the repository at this point in the history
Adds the jmh toolchain to `scala_toolchains()` and moves
`jmh_repositories()` to `jmh/toolchain/toolchain.bzl` for `rules_java` 8
compatibility. Part of bazelbuild#1482 and bazelbuild#1652.
  • Loading branch information
mbland committed Dec 10, 2024
1 parent 6b3e954 commit 15c8e02
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 116 deletions.
5 changes: 1 addition & 4 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ load("//scala:toolchains.bzl", "scala_toolchains")

scala_toolchains(
fetch_sources = True,
jmh = True,
scala_proto = True,
scalafmt = True,
testing = True,
Expand Down Expand Up @@ -65,10 +66,6 @@ load("//twitter_scrooge:twitter_scrooge.bzl", "twitter_scrooge")

twitter_scrooge()

load("//jmh:jmh.bzl", "jmh_repositories")

jmh_repositories()

# needed for the cross repo proto test
local_repository(
name = "proto_cross_repo_boundary",
Expand Down
87 changes: 10 additions & 77 deletions jmh/BUILD
Original file line number Diff line number Diff line change
@@ -1,77 +1,10 @@
load("//scala:providers.bzl", "declare_deps_provider")
load("//jmh/toolchain:toolchain.bzl", "export_toolchain_deps", "jmh_toolchain")

jmh_toolchain(
name = "jmh_toolchain_impl",
visibility = ["//visibility:public"],
)

toolchain(
name = "jmh_toolchain",
toolchain = ":jmh_toolchain_impl",
toolchain_type = "@io_bazel_rules_scala//jmh/toolchain:jmh_toolchain_type",
)

declare_deps_provider(
name = "jmh_core_provider",
deps_id = "jmh_core",
visibility = ["//visibility:public"],
deps = [
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
],
)

declare_deps_provider(
name = "jmh_classpath_provider",
deps_id = "jmh_classpath",
visibility = ["//visibility:public"],
deps = [
"@io_bazel_rules_scala_net_sf_jopt_simple_jopt_simple",
"@io_bazel_rules_scala_org_apache_commons_commons_math3",
],
)

declare_deps_provider(
name = "benchmark_generator_provider",
deps_id = "benchmark_generator",
visibility = ["//visibility:public"],
deps = [
"//src/java/io/bazel/rulesscala/jar",
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
],
)

declare_deps_provider(
name = "benchmark_generator_runtime_provider",
deps_id = "benchmark_generator_runtime",
visibility = ["//visibility:public"],
deps = [
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
],
)

export_toolchain_deps(
name = "jmh_classpath",
deps_id = "jmh_classpath",
visibility = ["//visibility:public"],
)

export_toolchain_deps(
name = "jmh_core",
deps_id = "jmh_core",
visibility = ["//visibility:public"],
)

export_toolchain_deps(
name = "benchmark_generator",
deps_id = "benchmark_generator",
visibility = ["//visibility:public"],
)

export_toolchain_deps(
name = "benchmark_generator_runtime",
deps_id = "benchmark_generator_runtime",
visibility = ["//visibility:public"],
)
load("//jmh/toolchain:toolchain.bzl", "DEP_PROVIDERS", "export_toolchain_deps")

[
export_toolchain_deps(
name = provider,
deps_id = provider,
visibility = ["//visibility:public"],
)
for provider in DEP_PROVIDERS
]
25 changes: 0 additions & 25 deletions jmh/jmh.bzl
Original file line number Diff line number Diff line change
@@ -1,30 +1,5 @@
load("//scala/private:rules/scala_binary.bzl", "scala_binary")
load("//scala/private:rules/scala_library.bzl", "scala_library")
load(
"//scala:scala_cross_version.bzl",
"default_maven_server_urls",
)
load("//third_party/repositories:repositories.bzl", "repositories")

def jmh_repositories(
maven_servers = default_maven_server_urls(),
overriden_artifacts = {}):
repositories(
for_artifact_ids = [
"io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
"io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
"io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
"io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
"io_bazel_rules_scala_org_ow2_asm_asm",
"io_bazel_rules_scala_net_sf_jopt_simple_jopt_simple",
"io_bazel_rules_scala_org_apache_commons_commons_math3",
],
fetch_sources = False,
maven_servers = maven_servers,
overriden_artifacts = {},
)

native.register_toolchains("@io_bazel_rules_scala//jmh:jmh_toolchain")

def _scala_generate_benchmark(ctx):
# we use required providers to ensure JavaInfo exists
Expand Down
106 changes: 97 additions & 9 deletions jmh/toolchain/toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,41 @@
load("//scala/private/toolchain_deps:toolchain_deps.bzl", "expose_toolchain_deps")
load("@io_bazel_rules_scala//scala:providers.bzl", _DepsInfo = "DepsInfo")
load("//scala:providers.bzl", "declare_deps_provider", _DepsInfo = "DepsInfo")
load(
"//scala:scala_cross_version.bzl",
"default_maven_server_urls",
_versioned_repositories = "repositories",
)
load("//third_party/repositories:repositories.bzl", "repositories")
load("@io_bazel_rules_scala_config//:config.bzl", "SCALA_VERSION")

DEP_PROVIDERS = [
"jmh_classpath",
"jmh_core",
"benchmark_generator",
"benchmark_generator_runtime",
]

def jmh_artifact_ids():
return [
"io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
"io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
"io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
"io_bazel_rules_scala_org_ow2_asm_asm",
"io_bazel_rules_scala_net_sf_jopt_simple_jopt_simple",
"io_bazel_rules_scala_org_apache_commons_commons_math3",
]

def jmh_repositories(
maven_servers = default_maven_server_urls(),
overriden_artifacts = {}):
repositories(
scala_version = SCALA_VERSION,
for_artifact_ids = jmh_artifact_ids(),
fetch_sources = False,
maven_servers = maven_servers,
overriden_artifacts = overriden_artifacts,
)
native.register_toolchains("@io_bazel_rules_scala_toolchains//jmh:all")

def _jmh_toolchain_impl(ctx):
toolchain = platform_common.ToolchainInfo(
Expand All @@ -11,19 +47,16 @@ jmh_toolchain = rule(
_jmh_toolchain_impl,
attrs = {
"dep_providers": attr.label_list(
default = [
"@io_bazel_rules_scala//jmh:jmh_classpath_provider",
"@io_bazel_rules_scala//jmh:jmh_core_provider",
"@io_bazel_rules_scala//jmh:benchmark_generator_provider",
"@io_bazel_rules_scala//jmh:benchmark_generator_runtime_provider",
],
default = [":%s_provider" % p for p in DEP_PROVIDERS],
providers = [_DepsInfo],
),
},
)

_toolchain_type = "//jmh/toolchain:jmh_toolchain_type"

def _export_toolchain_deps_impl(ctx):
return expose_toolchain_deps(ctx, "@io_bazel_rules_scala//jmh/toolchain:jmh_toolchain_type")
return expose_toolchain_deps(ctx, _toolchain_type)

export_toolchain_deps = rule(
_export_toolchain_deps_impl,
Expand All @@ -32,6 +65,61 @@ export_toolchain_deps = rule(
mandatory = True,
),
},
toolchains = ["@io_bazel_rules_scala//jmh/toolchain:jmh_toolchain_type"],
toolchains = [_toolchain_type],
incompatible_use_toolchain_transition = True,
)

def setup_jmh_toolchain(name):
jmh_toolchain(
name = "%s_impl" % name,
dep_providers = [":%s_provider" % p for p in DEP_PROVIDERS],
visibility = ["//visibility:public"],
)

native.toolchain(
name = name,
toolchain = ":%s_impl" % name,
toolchain_type = Label(_toolchain_type),
visibility = ["//visibility:public"],
)

declare_deps_provider(
name = "jmh_core_provider",
deps_id = "jmh_core",
visibility = ["//visibility:public"],
deps = _versioned_repositories(SCALA_VERSION, [
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
]),
)

declare_deps_provider(
name = "jmh_classpath_provider",
deps_id = "jmh_classpath",
visibility = ["//visibility:public"],
deps = _versioned_repositories(SCALA_VERSION, [
"@io_bazel_rules_scala_net_sf_jopt_simple_jopt_simple",
"@io_bazel_rules_scala_org_apache_commons_commons_math3",
]),
)

declare_deps_provider(
name = "benchmark_generator_provider",
deps_id = "benchmark_generator",
visibility = ["//visibility:public"],
deps = [
"@io_bazel_rules_scala//src/java/io/bazel/rulesscala/jar",
] + _versioned_repositories(SCALA_VERSION, [
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
]),
)

declare_deps_provider(
name = "benchmark_generator_runtime_provider",
deps_id = "benchmark_generator_runtime",
visibility = ["//visibility:public"],
deps = _versioned_repositories(SCALA_VERSION, [
"@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
]),
)
11 changes: 10 additions & 1 deletion scala/toolchains.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Macros to instantiate and register @io_bazel_rules_scala_toolchains"""

load("//jmh/toolchain:toolchain.bzl", "jmh_artifact_ids")
load("//junit:junit.bzl", "junit_artifact_ids")
load("//scala/private:macros/scala_repositories.bzl", "scala_repositories")
load(
Expand Down Expand Up @@ -31,7 +32,8 @@ def scala_toolchains(
scalafmt = False,
scalafmt_default_config_path = ".scalafmt.conf",
scala_proto = False,
scala_proto_enable_all_options = False):
scala_proto_enable_all_options = False,
jmh = False):
"""Instantiates @io_bazel_rules_scala_toolchains and all its dependencies.
Provides a unified interface to configuring rules_scala both directly in a
Expand Down Expand Up @@ -83,6 +85,7 @@ def scala_toolchains(
scala_proto_enable_all_options: whether to instantiate the scala_proto
toolchain with all options enabled; `scala_proto` must also be
`True` for this to take effect
jmh: whether to instantiate the jmh toolchain
"""
scala_repositories(
maven_servers = maven_servers,
Expand Down Expand Up @@ -122,6 +125,11 @@ def scala_toolchains(
id: True
for id in specs2_artifact_ids() + specs2_junit_artifact_ids()
})
if jmh:
artifact_ids_to_fetch_sources.update({
id: False
for id in jmh_artifact_ids()
})

for scala_version in SCALA_VERSIONS:
version_specific_artifact_ids = {}
Expand Down Expand Up @@ -159,6 +167,7 @@ def scala_toolchains(
scalafmt = scalafmt,
scala_proto = scala_proto,
scala_proto_enable_all_options = scala_proto_enable_all_options,
jmh = jmh,
)

def scala_register_toolchains():
Expand Down
9 changes: 9 additions & 0 deletions scala/toolchains_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def _scala_toolchains_repo_impl(repository_ctx):
toolchains["scala"] = _SCALA_TOOLCHAIN_BUILD
if repo_attr.scala_proto:
toolchains["scala_proto"] = _SCALA_PROTO_TOOLCHAIN_BUILD
if repo_attr.jmh:
toolchains["jmh"] = _JMH_TOOLCHAIN_BUILD

testing_build_args = _generate_testing_toolchain_build_file_args(repo_attr)
if testing_build_args != None:
Expand Down Expand Up @@ -78,6 +80,7 @@ _scala_toolchains_repo = repository_rule(
"scalafmt": attr.bool(),
"scala_proto": attr.bool(),
"scala_proto_enable_all_options": attr.bool(),
"jmh": attr.bool(),
},
)

Expand Down Expand Up @@ -201,3 +204,9 @@ declare_deps_provider(
deps = DEFAULT_SCALAPB_WORKER_DEPS,
)
"""

_JMH_TOOLCHAIN_BUILD = """
load("@@{rules_scala_repo}//jmh/toolchain:toolchain.bzl", "setup_jmh_toolchain")
setup_jmh_toolchain(name = "jmh_toolchain")
"""

0 comments on commit 15c8e02

Please sign in to comment.