-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Kernel-supporting jll #389
base: main
Are you sure you want to change the base?
Conversation
@@ -131,6 +131,7 @@ function __init__() | |||
end | |||
end | |||
|
|||
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid | |
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget( | |
"enzymexla_gpu"::Cstring, | |
cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, | |
"CUDA"::Cstring, | |
)::Cvoid |
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, | ||
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, | |
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} | |
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |
args...; | |
convert=Val(false), | |
blocks::CuDim=1, | |
threads::CuDim=1, | |
cooperative::Bool=false, | |
shmem::Integer=0, | |
call_kwargs..., | |
) where {F,tt} |
@show @code_hlo optimize=false square!(A) | ||
@show @code_hlo optimize=:before_kernel square!(A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@show @code_hlo optimize=false square!(A) | |
@show @code_hlo optimize=:before_kernel square!(A) | |
@show @code_hlo optimize = false square!(A) | |
@show @code_hlo optimize = :before_kernel square!(A) |
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) | ||
push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) | |
push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) | |
for idx in | |
(blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) | |
push!( | |
operands, | |
Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data, | |
) |
push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) | ||
end | ||
for arg in mlir_args | ||
push!(operands, arg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
push!(operands, arg) | |
push!(operands, arg) |
MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), | ||
MLIR.IR.NamedAttribute("output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), | |
MLIR.IR.NamedAttribute("output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases)) | |
MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), | |
MLIR.IR.NamedAttribute( | |
"output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases) | |
), |
|
||
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname)) | ||
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) | ||
for (i, res) in enumerate(rarrays) | ||
res.mlir_data = transpose_val(MLIR.IR.result(call, i)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
res.mlir_data = transpose_val(MLIR.IR.result(call, i)) | |
res.mlir_data = transpose_val(MLIR.IR.result(call, i)) |
@@ -379,7 +394,7 @@ function compiler_cache(ctx::MLIR.IR.Context) | |||
return cache | |||
end | |||
|
|||
Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | |||
Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | |
Reactant.@reactant_overlay @noinline function CUDA.cufunction( | |
f::F, tt::TT=Tuple{}; kwargs... | |
) where {F,TT} |
@@ -379,7 +394,7 @@ function compiler_cache(ctx::MLIR.IR.Context) | |||
return cache | |||
end | |||
|
|||
Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | |||
Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | |||
res = Base.@lock CUDA.cufunction_lock begin | |||
# compile the function | |||
cache = compiler_cache(MLIR.IR.context()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
cache = compiler_cache(MLIR.IR.context()) | |
cache = compiler_cache(MLIR.IR.context()) |
@@ -304,7 +306,28 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) | |||
|
|||
optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) | |||
|
|||
toolkit = "" | |||
if isdefined(Reactant_jll, :ptxas_path) | |||
toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] | |
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] |
"remove-unnecessary-enzyme-ops", | ||
"enzyme-simplify-math", | ||
opt_passes, | ||
kern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kern | |
kern, |
@@ -340,6 +363,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) | |||
"remove-unnecessary-enzyme-ops", | |||
"enzyme-simplify-math", | |||
opt_passes, | |||
kern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
kern | |
kern, |
@@ -348,7 +372,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) | |||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) | |||
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) | |||
run_pass_pipeline!( | |||
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math" | |||
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,"*kern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,"*kern | |
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern |
No description provided.