Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 13, 2024
1 parent 998ef8a commit e01fae9
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,36 @@ end
function ReactantCore.traced_call(f, args...)
# TODO: caching!
cache_key = make_tracer(
OrderedIdDict(),
Reactant.OrderedIdDict(),
(f, args...),
(),
CallCache;
toscalar=false,
track_numbers=(), # TODO: track_numbers?
)

if haskey(Reactant.Compiler.callcache[], cache_key)
if haskey(Reactant.Compiler.callcache[], cache_key) && false
@info "Cache hit"
else
Reactant.Compiler.callcache[][cache_key] = nothing
@warn Reactant.Compiler.callcache[]
N = length(args)
seen_args = Reactant.OrderedIdDict()
traced_args = ntuple(N) do i
return make_tracer(
seen_args,
args[i],
(),
TracedTrack;
toscalar=false,
track_numbers=(),
)
end
linear_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_args
v isa TracedType || continue
push!(linear_args, v.mlir_data)
end
end

f_name = String(gensym(Symbol(f)))
Expand All @@ -158,20 +175,32 @@ function ReactantCore.traced_call(f, args...)
no_args_in_result=true,
)

@warn temp


traced_result, ret, linear_result = temp[[3, 6, 9]]

call_op = MLIR.Dialects.func.call(
[a.mlir_data for a in args];
linear_args;
result_0=[MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)],
callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
)

@assert length(linear_result) == MLIR.IR.noperands(ret) "Expected $(MLIR.IR.noperands(ret)) results, got $(length(linear_result))."

for i in 1:length(linear_result)
# mutating the TracedRArrays in linear_result, changes
# them in traced_result as well:
linear_result[i].mlir_data = MLIR.IR.result(call_op, i)
linear_result[i].paths=()
seen_results = Reactant.OrderedIdDict()
traced_result = make_tracer(
seen_results,
traced_result,
(),
TracedSetPath;
toscalar=false,
track_numbers=(),
)
linear_results = TracedType[]
i = 1
for (k, v) in seen_results
v isa TracedType || continue
v.mlir_data = MLIR.IR.result(call_op, i)
i += 1
end

return traced_result
Expand Down

0 comments on commit e01fae9

Please sign in to comment.