Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kernel-supporting jll #389

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
1 change: 0 additions & 1 deletion deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/service/cpu/simple_orc_jit.h"

#include "xla/python/ifrt/hlo/hlo_program.h"
#include "llvm/MC/TargetRegistry.h"
Expand Down
4 changes: 4 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ cc_library(
"-Wl,-exported_symbol,_ifrt_*",
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
"-Wl,-exported_symbol,_EnzymeGPUCustomCall",
]}),
deps = [
"@enzyme//:EnzymeMLIR",
Expand Down Expand Up @@ -469,6 +470,9 @@ cc_library(
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/cpu:cpu_client",

"@xla//xla/service:metrics_proto_cc",
"@xla//xla/service:metrics_proto_cc_impl",

"@xla//xla/service/cpu:cpu_compiler",
"@xla//xla/stream_executor/tpu:tpu_on_demand_compiler",
"@xla//xla/stream_executor/tpu:tpu_executor",
Expand Down
36 changes: 22 additions & 14 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "f6587e37ff7298f2a1a273b08c24d69fca7ff30f"
ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand All @@ -19,6 +19,27 @@ http_archive(
urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)],
)


# Hedron's Compile Commands Extractor for Bazel
# https://github.com/hedronvision/bazel-compile-commands-extractor
http_archive(
name = "hedron_compile_commands",

# Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here.
# Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README).
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
# When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..."
)
load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
hedron_compile_commands_setup()
load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive")
hedron_compile_commands_setup_transitive()
load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive")
hedron_compile_commands_setup_transitive_transitive()
load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive")
hedron_compile_commands_setup_transitive_transitive_transitive()

load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256", "XLA_PATCHES")

XLA_PATCHES = XLA_PATCHES + [
Expand Down Expand Up @@ -57,19 +78,6 @@ sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.
# """,
]

http_archive(
name = "rules_cc",
sha256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4",
strip_prefix = "rules_cc-c8c38f8c710cbbf834283e4777916b68261b359c",
urls = [
"https://github.com/bazelbuild/rules_cc/archive/c8c38f8c710cbbf834283e4777916b68261b359c.tar.gz",
],
)

load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")

rules_cc_dependencies()

LLVM_TARGETS = select({
"@bazel_tools//src/conditions:windows": ["AMDGPU", "NVPTX"],
"@bazel_tools//src/conditions:darwin": [],
Expand Down
33 changes: 29 additions & 4 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ function transpose_val(val)
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
end

Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
Comment on lines +320 to 321
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
args...;
convert=Val(false),
blocks::CuDim=1,
threads::CuDim=1,
cooperative::Bool=false,
shmem::Integer=0,
call_kwargs...,
) where {F,tt}

@show call_kwargs

Expand Down Expand Up @@ -352,10 +352,35 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c

fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
# Force public for now while we don't have real users
MLIR.IR.rmattr!(func.entry, "sym_visibility")
# MLIR.IR.rmattr!(func.entry, "sym_visibility")

op_ty_results = IR.Type[result_0...,]
operands = MLIR.IR.Value[]
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z)
push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data)
end
for arg in mlir_ir_args
push!(operands, arg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(operands, arg)
push!(operands, arg)

end
owned_regions = MLIR.IR.Region[]
successors = MLIR.IR.Block[]
attributes = MLIR.IR.NamedAttribute[
MLIR.IR.namedattribute("fn", fname),
MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases)
]

location = MLIR.IR.Location()
call = MLIR.IR.create_operation(
"stablehlo.custom_call",
location;
operands,
owned_regions,
successors,
attributes,
results=restys,
result_inference=false,
)

call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
for (i, res) in enumerate(rarrays)
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
res.mlir_data = transpose_val(MLIR.IR.result(call, i))

end
Expand Down
19 changes: 18 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
optimize isa Bool && (optimize = ifelse(optimize, :all, :none))

if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
join(
[
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
"lower-kernel"
],
',',
),
)
elseif optimize === :before_kernel
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
Expand Down Expand Up @@ -340,6 +356,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
"lower-kernel"
],
',',
),
Expand All @@ -348,7 +365,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math"
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,lower-kernel"
)
elseif optimize !== :none
error("Invalid optimize option: $(Meta.quot(optimize))")
Expand Down
1 change: 1 addition & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget(
"enzymexla_gpu"::Cstring,
cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid},
"CUDA"::Cstring,
)::Cvoid

return nothing
end

Expand Down
34 changes: 1 addition & 33 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,4 @@ end

const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))

@testset "Reactant.jl Tests" begin
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core"
@safetestset "Layout" include("layout.jl")
@safetestset "Tracing" include("tracing.jl")
@safetestset "Basic" include("basic.jl")
@safetestset "Autodiff" include("autodiff.jl")
@safetestset "Complex" include("complex.jl")
@safetestset "Broadcast" include("bcast.jl")
@safetestset "Struct" include("struct.jl")
@safetestset "Closure" include("closure.jl")
@safetestset "Compile" include("compile.jl")
@safetestset "Buffer Donation" include("buffer_donation.jl")
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
@testset "Neural Networks" begin
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
@safetestset "Flux.jl Integration" include("nn/flux.jl")
if Sys.islinux()
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
@safetestset "Lux Integration" include("nn/lux.jl")
end
end
end
end
include("cuda.jl")
Loading