-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Kernel-supporting jll #389
base: main
Are you sure you want to change the base?
Changes from 11 commits
22db5df
b6d0615
f6b2238
6121847
d62d01e
c3192d9
e86fac6
6cea3df
e3f4df2
019a690
7a4a403
8fef9ff
fe43909
be19519
b0f679f
1cd21ff
8b5d131
3072670
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -317,7 +317,7 @@ function transpose_val(val) | |||||||||||||||||
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, | ||||||||||||||||||
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, | ||||||||||||||||||
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} | ||||||||||||||||||
@show call_kwargs | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -352,10 +352,34 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c | |||||||||||||||||
|
||||||||||||||||||
fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") | ||||||||||||||||||
# Force public for now while we don't have real users | ||||||||||||||||||
MLIR.IR.rmattr!(func.entry, "sym_visibility") | ||||||||||||||||||
# MLIR.IR.rmattr!(func.entry, "sym_visibility") | ||||||||||||||||||
|
||||||||||||||||||
operands = MLIR.IR.Value[] | ||||||||||||||||||
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) | ||||||||||||||||||
push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) | ||||||||||||||||||
Comment on lines
+356
to
+357
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
for arg in mlir_args | ||||||||||||||||||
push!(operands, arg) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
owned_regions = MLIR.IR.Region[] | ||||||||||||||||||
successors = MLIR.IR.Block[] | ||||||||||||||||||
attributes = MLIR.IR.NamedAttribute[ | ||||||||||||||||||
MLIR.IR.NamedAttribute("fn", fname), | ||||||||||||||||||
MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) | ||||||||||||||||||
] | ||||||||||||||||||
|
||||||||||||||||||
location = MLIR.IR.Location() | ||||||||||||||||||
call = MLIR.IR.create_operation( | ||||||||||||||||||
"enzymexla.kern_call", | ||||||||||||||||||
location; | ||||||||||||||||||
operands, | ||||||||||||||||||
owned_regions, | ||||||||||||||||||
successors, | ||||||||||||||||||
attributes, | ||||||||||||||||||
results=restys, | ||||||||||||||||||
result_inference=false, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname)) | ||||||||||||||||||
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) | ||||||||||||||||||
for (i, res) in enumerate(rarrays) | ||||||||||||||||||
res.mlir_data = transpose_val(MLIR.IR.result(call, i)) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
|
@@ -379,7 +403,7 @@ function compiler_cache(ctx::MLIR.IR.Context) | |||||||||||||||||
return cache | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | ||||||||||||||||||
Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
res = Base.@lock CUDA.cufunction_lock begin | ||||||||||||||||||
# compile the function | ||||||||||||||||||
cache = compiler_cache(MLIR.IR.context()) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -131,6 +131,7 @@ function __init__() | |||||||||||||
end | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
@ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
return nothing | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,7 @@ end | |||||||||
oA = collect(1:1:64) | ||||||||||
A = Reactant.to_rarray(oA) | ||||||||||
@show @code_hlo optimize=false square!(A) | ||||||||||
@show @code_hlo optimize=:before_kernel square!(A) | ||||||||||
Comment on lines
21
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
@show @code_hlo square!(A) | ||||||||||
func = @compile square!(A) | ||||||||||
@test all(Array(A) .≈ (oA .* oA)) | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶