Skip to content

Commit

Permalink
caching!
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 13, 2024
1 parent e01fae9 commit c0214eb
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ function ReactantCore.traced_call(f, args...)
track_numbers=(), # TODO: track_numbers?
)

if haskey(Reactant.Compiler.callcache[], cache_key) && false
@info "Cache hit"
else
Reactant.Compiler.callcache[][cache_key] = nothing
@warn Reactant.Compiler.callcache[]
if haskey(Reactant.Compiler.callcache[], cache_key)
# Determine `linear_args`, the vector containing `MLIR.IR.Value`s
# to be passed to the function:
N = length(args)
seen_args = Reactant.OrderedIdDict()
traced_args = ntuple(N) do i
Expand All @@ -163,26 +161,29 @@ function ReactantCore.traced_call(f, args...)
v isa TracedType || continue
push!(linear_args, v.mlir_data)
end
end

f_name = String(gensym(Symbol(f)))
temp = Reactant.make_mlir_fn(
f,
args,
(),
f_name,
false;
no_args_in_result=true,
)

@warn temp


traced_result, ret, linear_result = temp[[3, 6, 9]]
# cache lookup:
(; f_name, mlir_result_types, traced_result) = Reactant.Compiler.callcache[][cache_key]
else
f_name = String(gensym(Symbol(f)))
temp = Reactant.make_mlir_fn(
f,
args,
(),
f_name,
false;
no_args_in_result=true,
)
traced_result, ret, linear_args = temp[[3, 6, 7]]
linear_args = MLIR.IR.Value[v.mlir_data for v in linear_args]
mlir_result_types = [MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)]

Reactant.Compiler.callcache[][cache_key] = (; f_name, mlir_result_types, traced_result)
end

call_op = MLIR.Dialects.func.call(
linear_args;
result_0=[MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)],
result_0=mlir_result_types,
callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
)

Expand All @@ -195,10 +196,10 @@ function ReactantCore.traced_call(f, args...)
toscalar=false,
track_numbers=(),
)
linear_results = TracedType[]
i = 1
for (k, v) in seen_results
v isa TracedType || continue
# this mutates `traced_result`, which is what we want:
v.mlir_data = MLIR.IR.result(call_op, i)
i += 1
end
Expand Down

0 comments on commit c0214eb

Please sign in to comment.