Interp2 #365

Merged: 79 commits, Dec 14, 2024
Changes from 75 commits

Commits
7159756
WIP: kernels
Nov 29, 2024
f891bde
more files
Nov 29, 2024
14174db
fix
Nov 29, 2024
e401e4a
wip
Dec 5, 2024
e7bc318
wqtmp
Dec 5, 2024
ae6c7c6
wip
Dec 6, 2024
b8f206f
inc
Dec 7, 2024
c414601
continuing
Dec 7, 2024
9138a36
wip
Dec 8, 2024
c2ca4cb
more work
Dec 9, 2024
da328d3
inf rec
Dec 9, 2024
ad4d05b
fix
Dec 9, 2024
750661a
overload working
Dec 9, 2024
12dec6c
continuing
Dec 10, 2024
a6cd104
continuing
Dec 10, 2024
db6e37b
push
Dec 10, 2024
8831f4d
fix `call_with_reactant_generator` for Julia 1.11 (#359)
jumerckx Dec 10, 2024
e2ffe87
conversion
Dec 10, 2024
364823a
continuing
wsmoses Dec 11, 2024
ff729ce
Cleanup
wsmoses Dec 11, 2024
3bd5608
Apply suggestions from code review
wsmoses Dec 11, 2024
5e33afb
Delete test/cuda.jl
wsmoses Dec 11, 2024
17c2f72
fixup
wsmoses Dec 11, 2024
4807a79
Apply suggestions from code review
wsmoses Dec 11, 2024
51286bd
fix apply
wsmoses Dec 11, 2024
bd40b69
indep of change
wsmoses Dec 11, 2024
5b3329c
minor fix in name
wsmoses Dec 11, 2024
4af4a00
Update utils.jl
wsmoses Dec 11, 2024
8379f05
Interp take 2
wsmoses Dec 11, 2024
246ec4e
continuing adentures
wsmoses Dec 11, 2024
9a669ef
delcode
wsmoses Dec 11, 2024
623ff38
fix
wsmoses Dec 12, 2024
df3e27c
tmp
wsmoses Dec 12, 2024
bda8912
make
wsmoses Dec 12, 2024
fd92864
fix
wsmoses Dec 12, 2024
e74173b
cleanup
wsmoses Dec 12, 2024
c71942c
continuing
wsmoses Dec 12, 2024
1fa3c93
more working
wsmoses Dec 13, 2024
07fb856
further simplify
wsmoses Dec 13, 2024
72533ff
fx
wsmoses Dec 13, 2024
503f1ff
more improvements
wsmoses Dec 13, 2024
d302cd9
minus show
wsmoses Dec 13, 2024
59a648a
less prints
wsmoses Dec 13, 2024
d55e4e8
even fewer
wsmoses Dec 13, 2024
db50a37
confusion
wsmoses Dec 13, 2024
284094a
tmp
wsmoses Dec 13, 2024
82584f8
force clean
wsmoses Dec 13, 2024
00776da
force oc
wsmoses Dec 13, 2024
ad784b3
clean
wsmoses Dec 13, 2024
3aba6a2
Merge branch 'main' into interp2
wsmoses Dec 13, 2024
e90096b
Rewrite
wsmoses Dec 14, 2024
9e1fe6c
fixup
wsmoses Dec 14, 2024
0982c09
fix
wsmoses Dec 14, 2024
996d60a
fix
wsmoses Dec 14, 2024
681107f
fix
wsmoses Dec 14, 2024
3a35f5c
fixup
wsmoses Dec 14, 2024
eea1dba
fix
wsmoses Dec 14, 2024
c65948d
wip
wsmoses Dec 14, 2024
2ff00ad
safe prints
wsmoses Dec 14, 2024
caad928
fix
wsmoses Dec 14, 2024
7547699
fix
wsmoses Dec 14, 2024
c425ccb
stackoverflow
wsmoses Dec 14, 2024
a6b52f5
cleanup
wsmoses Dec 14, 2024
b17e75f
dyindex
wsmoses Dec 14, 2024
2cff76e
rt
wsmoses Dec 14, 2024
1d0cb8e
continue
wsmoses Dec 14, 2024
f4349a9
clean
wsmoses Dec 14, 2024
a4ae31a
fix
wsmoses Dec 14, 2024
0dbe20f
fix
wsmoses Dec 14, 2024
1c45d7e
fix
wsmoses Dec 14, 2024
1ffb366
fix
wsmoses Dec 14, 2024
3887575
fixup
wsmoses Dec 14, 2024
873e46b
fix
wsmoses Dec 14, 2024
21244db
fix
wsmoses Dec 14, 2024
70c3951
capture oc
wsmoses Dec 14, 2024
f839a0b
compile perf
wsmoses Dec 14, 2024
8073ecd
v1.11 fix
jumerckx Dec 14, 2024
585c485
other way 'round
jumerckx Dec 14, 2024
0c56d35
formatting
jumerckx Dec 14, 2024
Files changed
4 changes: 2 additions & 2 deletions Project.toml
@@ -41,14 +41,14 @@ Adapt = "4"
ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13.21"
Enzyme = "0.13.22"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1.2"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Statistics = "1.10"
10 changes: 10 additions & 0 deletions deps/ReactantExtra/API.cpp
@@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
return wrap(res);
}

#include "llvm/IRReader/IRReader.h"
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
LLVMContext Context;
SMDiagnostic Err;
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
mlir::MLIRContext &context = *unwrap(cctx);
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
return wrap(res);
}


/* Note that this */
extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) {
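The new ConvertLLVMStrToMLIR entry point parses textual LLVM IR and translates it into an MLIR module inside the supplied MLIR context. A minimal, hypothetical sketch of driving it from Julia follows; the shared-library name, the hand-rolled MlirContext/MlirModule wrapper structs, and the helper name are illustrative assumptions, not part of this PR.

# Hypothetical sketch only: Reactant's own bindings may expose this differently.
struct MlirContext
    ptr::Ptr{Cvoid}
end

struct MlirModule
    ptr::Ptr{Cvoid}
end

llvm_ir = """
define i64 @square(i64 %x) {
  %y = mul i64 %x, %x
  ret i64 %y
}
"""

# `ctx` is assumed to be a live MLIR context handle obtained elsewhere.
function convert_llvm_str_to_mlir(ir::String, ctx::MlirContext)
    return @ccall "libReactantExtra".ConvertLLVMStrToMLIR(
        ir::Cstring, ctx::MlirContext
    )::MlirModule
end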
2 changes: 2 additions & 0 deletions deps/ReactantExtra/BUILD
@@ -450,6 +450,8 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",

"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:AArch64AsmParser",
"@llvm-project//llvm:AArch64CodeGen",
19 changes: 11 additions & 8 deletions ext/ReactantNNlibExt.jl
@@ -7,11 +7,14 @@ using Reactant:
Ops,
TracedRArray,
AnyTracedRArray,
materialize_traced_array,
MLIR,
TracedRNumber,
TracedRNumber

using Reactant.TracedUtils:
materialize_traced_array,
get_mlir_data,
set_mlir_data!

using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

@@ -238,9 +241,9 @@ function NNlib.batched_mul!(
if size(x, 3) != size(y, 3)
B = max(size(x, 3), size(y, 3))
if size(x, 3) == 1
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
elseif size(y, 3) == 1
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
end
end

@@ -250,9 +253,9 @@
if size(x, 1) != size(y, 1)
B = max(size(x, 1), size(y, 1))
if size(x, 1) == 1
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
elseif size(y, 1) == 1
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
end
end

@@ -270,7 +273,7 @@ end
function NNlib.pad_constant(
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
value = Reactant.promote_to(TracedRNumber{T}, value)
value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value)
low = [i[1] for i in pad]
high = [i[2] for i in pad]
interior = [0 for i in pad]
Expand Down Expand Up @@ -329,7 +332,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
res isa TracedRNumber && (res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
res isa TracedRNumber && (res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
res isa TracedRNumber &&
(res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))

return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
3 changes: 2 additions & 1 deletion ext/ReactantStatisticsExt.jl
@@ -1,6 +1,7 @@
module ReactantStatisticsExt

using Reactant: AnyTracedRArray, materialize_traced_array
using Reactant: AnyTracedRArray
using Reactant.TracedUtils: materialize_traced_array
using Statistics: Statistics

function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
7 changes: 4 additions & 3 deletions ext/ReactantYaoBlocksExt.jl
@@ -1,12 +1,13 @@
module ReactantYaoBlocksExt

using Reactant
using Reactant.TracedUtils: broadcast_to_size
using YaoBlocks

function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = -im * sin(R.theta / 2)
M[1, 1] = c
@@ -19,7 +20,7 @@ end
function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = sin(R.theta / 2)
M[1, 1] = c
@@ -32,7 +33,7 @@ end
function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
x = exp(im * R.theta / 2)
M[1, 1] = conj(x)
M[2, 2] = x
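A change repeated across these extensions is that tracing helpers such as materialize_traced_array, broadcast_to_size, and promote_to now live in the Reactant.TracedUtils submodule instead of at the top level of Reactant. A small sketch of the new import style follows; the function body is a made-up example, not code from this PR.

using Reactant
using Reactant.TracedUtils: broadcast_to_size, materialize_traced_array, promote_to

# Made-up helper purely to show the relocated names in use on traced values.
function shift_by_one(x::Reactant.TracedRArray{Float32,2})
    one_scalar = promote_to(Reactant.TracedRNumber{Float32}, 1.0)
    ones_array = broadcast_to_size(one_scalar, size(x))
    return materialize_traced_array(x) .+ ones_array
end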
2 changes: 1 addition & 1 deletion lib/ReactantCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
version = "0.1.2"
version = "0.1.3"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
4 changes: 2 additions & 2 deletions lib/ReactantCore/src/ReactantCore.jl
@@ -153,15 +153,15 @@ function trace_for(mod, expr)

all_syms = Expr(:tuple, counter, external_syms...)
args_init = Expr(
:tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
:tuple, :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:tuple, :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
:tuple,
:(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)),
external_syms...,

)

reactant_code_block = quote
let args = $(args_init)
cond_fn =
$(all_syms) -> begin
local num_iters = div($limit - $start, $step, RoundDown)
local num_iters = Reactant.promote_to(
local num_iters = Reactant.TracedUtils.promote_to(
Reactant.TracedRNumber{Int64}, num_iters
)
$counter < num_iters + 1
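For context, trace_for is what ReactantCore.@trace expands for loops into; the generated condition closure above now promotes its iteration count through Reactant.TracedUtils.promote_to. A rough, hedged usage sketch follows (the exact set of loop forms @trace accepts is defined by ReactantCore, not by this snippet):

using Reactant, ReactantCore

# The loop-carried variable `acc` is updated on every iteration; @trace
# rewrites the loop into the while-style tracing machinery shown above.
function accumulate10(x)
    acc = x
    @trace for i in 1:10
        acc = acc .+ i
    end
    return acc
end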
7 changes: 6 additions & 1 deletion src/Compiler.jl
@@ -292,7 +292,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
return Reactant.make_mlir_fn(f, args, (), "main", true)
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
end

@@ -779,6 +779,11 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(f, args; client=nothing, optimize=true, sync=false)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Reactant.@reactant_override function Reactant.Compiler.compile(f, args; client=nothing, optimize=true, sync=false)
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)

return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

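The @reactant_override above makes nested compilation a no-op: if compile is reached while a function is already being traced, the inner call simply hands back the original callable rather than recursing into the compiler. A hedged sketch of the behavior this enables (the user functions here are made up):

using Reactant

f(x) = sum(x .* x)

function g(x)
    # While `g` is being traced, this call is intercepted by the override and
    # evaluates to `f` itself, so tracing continues straight into `f(x)`.
    inner = Reactant.Compiler.compile(f, (x,))
    return inner(x) + 1
end

x = Reactant.ConcreteRArray(rand(Float32, 16))
compiled_g = Reactant.Compiler.compile(g, (x,))
compiled_g(x)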
10 changes: 9 additions & 1 deletion src/ConcreteRArray.jl
@@ -99,7 +99,7 @@
function Base.convert(
::Type{T}, X::WrappedConcreteRArray{ElType,N}
) where {T<:Array,ElType,N}
fn = compile(materialize_traced_array, (X,))
fn = compile(TracedUtils.materialize_traced_array, (X,))
return convert(Array, fn(X))
end
Base.Array(x::AnyConcreteRArray) = convert(Array, x)
@@ -345,3 +345,11 @@ end

buffer_on_cpu(::Any) = true
buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer)

function Ops.constant(x::ConcreteRArray; kwargs...)
return Ops.constant(Base.convert(Array, x); kwargs...)
end

function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T}
return Ops.constant(Base.convert(T, x); kwargs...)
end
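These new methods let an already-realized ConcreteRArray or ConcreteRNumber be embedded into a traced program as a literal constant, by first converting it back to a host Array or scalar. A small hedged sketch (variable names are made up for illustration):

using Reactant

weights = Reactant.ConcreteRArray(Float32[1.0 2.0; 3.0 4.0])

# Inside a traced function the concrete value is baked into the IR as a
# constant rather than being passed in as a runtime argument.
function add_weights(x)
    w = Reactant.Ops.constant(weights)
    return w .+ x
end

x = Reactant.ConcreteRArray(ones(Float32, 2, 2))
add_weights_compiled = Reactant.Compiler.compile(add_weights, (x,))
add_weights_compiled(x)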
16 changes: 8 additions & 8 deletions src/ControlFlow.jl
@@ -1,7 +1,7 @@
function ReactantCore.traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.make_mlir_fn(
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn(
true_fn,
args,
(),
@@ -12,7 +12,7 @@ function ReactantCore.traced_if(
construct_function_without_args=true,
)

(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.make_mlir_fn(
(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn(
false_fn,
args,
(),
@@ -36,16 +36,16 @@ function ReactantCore.traced_if(
returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.")
elseif tr isa MissingTracedValue
push!(result_types, MLIR.IR.type(fr.mlir_data))
push!(linear_results, new_traced_value(false_linear_results[i]))
push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i]))
push!(true_block_insertions, (i => linear_results[end]))
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(true_linear_results[i]))
push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i]))
push!(false_block_insertions, (i => linear_results[end]))
end
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(tr))
push!(linear_results, TracedUtils.new_traced_value(tr))
end
end

@@ -82,13 +82,13 @@ function ReactantCore.traced_while(
# We promote all incoming args (is there a better way to do this?)
traced_args = [
if v isa Number && !(v isa TracedType)
Reactant.promote_to(TracedRNumber{typeof(v)}, v)
Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v)
else
v
end for v in args
]

(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn(
(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
@@ -99,7 +99,7 @@
do_transpose=false,
)

(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn(
(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),